From 14c38acf97f4a8521c46a20de9f540ec888d5d50 Mon Sep 17 00:00:00 2001
From: ehhuang
Date: Mon, 24 Feb 2025 12:38:37 -0800
Subject: [PATCH] fix: set default tool_prompt_format in inference api (#1214)

Summary:
Currently we don't set the best tool_prompt_format according to the model, as promised.

Test Plan:
Added prints around the raw model input and inspected it manually.

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with
[ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1214).
* #1234
* __->__ #1214
---
 llama_stack/distribution/routers/routers.py       |  3 +++
 .../providers/utils/inference/prompt_adapter.py   | 17 +++++++++++++++++
 2 files changed, 20 insertions(+)

diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index df4ed03d3..a7c0d63e5 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -52,6 +52,7 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
+from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
 
 
 class VectorIORouter(VectorIO):
@@ -158,6 +159,8 @@ class InferenceRouter(Inference):
             params["tool_prompt_format"] = tool_prompt_format
         tool_config = ToolConfig(**params)
 
+        tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
+
         tools = tools or []
         if tool_config.tool_choice == ToolChoice.none:
             tools = []
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index ca6fe04fd..7e7ab3a1d 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -456,3 +456,20 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
     else:
         # specific tool
         return f"You MUST use the tool `{tool_choice}` to answer the user query."
+
+
+def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
+    llama_model = resolve_model(model)
+    if llama_model is None:
+        return ToolPromptFormat.json
+
+    if llama_model.model_family == ModelFamily.llama3_1 or (
+        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
+    ):
+        # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
+        return ToolPromptFormat.json
+    elif llama_model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
+        # llama3.2 and llama3.3 models follow the same tool prompt format
+        return ToolPromptFormat.python_list
+    else:
+        return ToolPromptFormat.json
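
For reviewers, a minimal sketch of what the new default resolution is expected to look like in practice. Only `get_default_tool_prompt_format` comes from this patch; the model ID strings below are illustrative assumptions that depend on what `resolve_model` can resolve in the local registry, not identifiers guaranteed by this change.

```python
# Illustrative sketch only, not part of this patch. It calls the new helper
# with a few example model IDs to show the expected defaults. The model ID
# strings are assumptions; any ID that resolve_model() cannot resolve falls
# back to ToolPromptFormat.json.
from llama_stack.providers.utils.inference.prompt_adapter import (
    get_default_tool_prompt_format,
)

for model_id in [
    "meta-llama/Llama-3.1-8B-Instruct",          # assumed ID: llama3.1 -> json
    "meta-llama/Llama-3.2-3B-Instruct",          # assumed ID: llama3.2 text -> python_list
    "meta-llama/Llama-3.2-11B-Vision-Instruct",  # assumed ID: llama3.2 multimodal -> json
    "not-a-llama-model",                         # unresolved -> json fallback
]:
    print(model_id, get_default_tool_prompt_format(model_id))
```

With the router change above, this default only applies when the caller leaves `tool_prompt_format` unset; an explicitly requested format still wins.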