forked from phoenix-oss/llama-stack-mirror
fix: set default tool_prompt_format in inference api (#1214)
Summary: Currently we don't set the best tool_prompt_format according to model as promisd. Test Plan: Added print around raw model input and inspected manually --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1214). * #1234 * __->__ #1214
This commit is contained in:
parent
c4987bc349
commit
14c38acf97
2 changed files with 20 additions and 0 deletions
|
@ -52,6 +52,7 @@ from llama_stack.apis.tools import (
|
|||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
|
||||
|
||||
|
||||
class VectorIORouter(VectorIO):
|
||||
|
@ -158,6 +159,8 @@ class InferenceRouter(Inference):
|
|||
params["tool_prompt_format"] = tool_prompt_format
|
||||
tool_config = ToolConfig(**params)
|
||||
|
||||
tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
|
||||
|
||||
tools = tools or []
|
||||
if tool_config.tool_choice == ToolChoice.none:
|
||||
tools = []
|
||||
|
|
|
@ -456,3 +456,20 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
|
|||
else:
|
||||
# specific tool
|
||||
return f"You MUST use the tool `{tool_choice}` to answer the user query."
|
||||
|
||||
|
||||
def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
|
||||
llama_model = resolve_model(model)
|
||||
if llama_model is None:
|
||||
return ToolPromptFormat.json
|
||||
|
||||
if llama_model.model_family == ModelFamily.llama3_1 or (
|
||||
llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
|
||||
):
|
||||
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
|
||||
return ToolPromptFormat.json
|
||||
elif llama_model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
|
||||
# llama3.2 and llama3.3 models follow the same tool prompt format
|
||||
return ToolPromptFormat.python_list
|
||||
else:
|
||||
return ToolPromptFormat.json
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue