fix: set default tool_prompt_format in inference api (#1214)

Summary:
Currently we don't set the best tool_prompt_format according to the model, as promised. This change has the inference router fall back to a model-specific default whenever the caller doesn't provide one.

Test Plan:
Added prints around the raw model input and inspected it manually.
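For a quick offline check of the same fallback, a sketch like this also works (the `ToolConfig` import path and the model ID are assumptions, not part of this change):

```python
from llama_stack.apis.inference import ToolConfig  # assumed import path
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format

tool_config = ToolConfig()  # caller did not specify a tool_prompt_format
model_id = "meta-llama/Llama-3.3-70B-Instruct"  # illustrative; resolution depends on the registry

# The same fallback the router now applies: keep an explicit format, else derive one.
tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
print(tool_config.tool_prompt_format)  # expected: python_list for a Llama 3.3 model
```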
---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with
[ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1214).
* #1234
* __->__ #1214
Author: ehhuang (committed by GitHub)
Date: 2025-02-24 12:38:37 -08:00
Commit: 14c38acf97 (parent: c4987bc349)
2 changed files with 20 additions and 0 deletions

In the router module (the file defining `VectorIORouter` and `InferenceRouter`):

```diff
@@ -52,6 +52,7 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
+from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format


 class VectorIORouter(VectorIO):
@@ -158,6 +159,8 @@ class InferenceRouter(Inference):
                 params["tool_prompt_format"] = tool_prompt_format
             tool_config = ToolConfig(**params)

+        tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
+
         tools = tools or []
         if tool_config.tool_choice == ToolChoice.none:
             tools = []
```
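Note the precedence here: an explicitly requested format still wins, because the default only fills a missing value. A minimal sketch of that property (import paths assumed as above):

```python
from llama_stack.apis.inference import ToolConfig, ToolPromptFormat  # assumed import paths
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format

# An explicit caller choice is preserved; the model default only fills gaps.
tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(
    "meta-llama/Llama-3.3-70B-Instruct"  # illustrative model ID
)
assert tool_config.tool_prompt_format == ToolPromptFormat.json
```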

In `llama_stack/providers/utils/inference/prompt_adapter.py`:

```diff
@@ -456,3 +456,20 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
     else:
         # specific tool
         return f"You MUST use the tool `{tool_choice}` to answer the user query."
+
+
+def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
+    llama_model = resolve_model(model)
+    if llama_model is None:
+        return ToolPromptFormat.json
+
+    if llama_model.model_family == ModelFamily.llama3_1 or (
+        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
+    ):
+        # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
+        return ToolPromptFormat.json
+    elif llama_model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
+        # llama3.2 and llama3.3 models follow the same tool prompt format
+        return ToolPromptFormat.python_list
+    else:
+        return ToolPromptFormat.json
```
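Taken together, the mapping can be spot-checked like this (model identifiers are illustrative and depend on what `resolve_model` recognizes; the `ToolPromptFormat` import path is an assumption):

```python
from llama_stack.apis.inference import ToolPromptFormat  # assumed import path
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format

# Identifiers are illustrative; unresolved models fall back to JSON.
assert get_default_tool_prompt_format("meta-llama/Llama-3.1-8B-Instruct") == ToolPromptFormat.json
assert get_default_tool_prompt_format("meta-llama/Llama-3.2-3B-Instruct") == ToolPromptFormat.python_list
assert get_default_tool_prompt_format("meta-llama/Llama-3.2-11B-Vision-Instruct") == ToolPromptFormat.json
assert get_default_tool_prompt_format("no-such-model") == ToolPromptFormat.json
```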