diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index df4ed03d3..a7c0d63e5 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -52,6 +52,7 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
+from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
 
 
 class VectorIORouter(VectorIO):
@@ -158,6 +159,8 @@ class InferenceRouter(Inference):
                 params["tool_prompt_format"] = tool_prompt_format
             tool_config = ToolConfig(**params)
 
+        tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
+
         tools = tools or []
         if tool_config.tool_choice == ToolChoice.none:
             tools = []
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index ca6fe04fd..7e7ab3a1d 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -456,3 +456,20 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
     else:
         # specific tool
         return f"You MUST use the tool `{tool_choice}` to answer the user query."
+
+
+def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
+    llama_model = resolve_model(model)
+    if llama_model is None:
+        return ToolPromptFormat.json
+
+    if llama_model.model_family == ModelFamily.llama3_1 or (
+        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
+    ):
+        # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
+        return ToolPromptFormat.json
+    elif llama_model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
+        # llama3.2 and llama3.3 models follow the same tool prompt format
+        return ToolPromptFormat.python_list
+    else:
+        return ToolPromptFormat.json