feat: better use of get_default_tool_prompt_format

Summary:

Test Plan:
Eric Huang 2025-03-03 12:09:07 -08:00
parent 7d111c7510
commit 5f3ec93a35
2 changed files with 6 additions and 7 deletions

llama_stack/distribution/routers/routers.py

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import copy
 from typing import Any, AsyncGenerator, Dict, List, Optional
 from llama_stack import logcat
@@ -54,7 +53,6 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
-from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
 class VectorIORouter(VectorIO):
@@ -181,9 +179,6 @@ class InferenceRouter(Inference):
 params["tool_prompt_format"] = tool_prompt_format
 tool_config = ToolConfig(**params)
-tool_config = copy.copy(tool_config)
-tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
 tools = tools or []
 if tool_config.tool_choice == ToolChoice.none:
 tools = []
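
Note on the router-side change: before this commit the router shallow-copied the ToolConfig and wrote the model's default tool_prompt_format back into it; afterwards the config is passed through untouched and the default is resolved only where the prompt is actually encoded. A minimal sketch of the two approaches, using a stand-in dataclass rather than llama_stack's real ToolConfig:

import copy
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToolConfig:  # illustrative stand-in for llama_stack's ToolConfig
    tool_prompt_format: Optional[str] = None


def router_style_before(tool_config: ToolConfig, default_fmt: str) -> ToolConfig:
    # old approach: copy the config, then write the default into the copy
    tool_config = copy.copy(tool_config)
    tool_config.tool_prompt_format = tool_config.tool_prompt_format or default_fmt
    return tool_config


def prompt_adapter_style_after(tool_config: ToolConfig, default_fmt: str) -> str:
    # new approach: resolve the default only at the point of use; nothing is mutated
    return tool_config.tool_prompt_format or default_fmt


cfg = ToolConfig()
print(router_style_before(cfg, "json").tool_prompt_format)  # "json", via a copied config
print(prompt_adapter_style_after(cfg, "json"))               # "json", cfg left untouched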

llama_stack/providers/utils/inference/prompt_adapter.py

@@ -15,6 +15,7 @@ from typing import List, Optional, Tuple, Union
 import httpx
 from PIL import Image as PIL_Image
+from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
 ImageContentItem,
 InterleavedContent,
@@ -253,7 +254,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
 formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
 model_input = formatter.encode_dialog_prompt(
-request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+request.messages,
+tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
 )
 return formatter.tokenizer.decode(model_input.tokens)
@@ -267,7 +269,8 @@ async def chat_completion_request_to_model_input_info(
 formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
 model_input = formatter.encode_dialog_prompt(
-request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+request.messages,
+tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
 )
 return (
 formatter.tokenizer.decode(model_input.tokens),
@@ -461,6 +464,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
 def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
 llama_model = resolve_model(model)
 if llama_model is None:
+logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format")
 return ToolPromptFormat.json
 if llama_model.model_family == ModelFamily.llama3_1 or (
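
For reference, the fallback added in prompt_adapter applies the per-model default only when the request did not pin a format, and get_default_tool_prompt_format now warns before defaulting to json for models it cannot resolve. A self-contained sketch of the or-fallback idiom; the enum values, model id, and helper body below are illustrative assumptions, not llama_stack's actual definitions:

from enum import Enum
from typing import Optional


class ToolPromptFormat(Enum):  # illustrative stand-in for llama_stack's enum
    json = "json"
    function_tag = "function_tag"
    python_list = "python_list"


def get_default_tool_prompt_format(model_id: str) -> ToolPromptFormat:
    # stand-in: the real helper resolves the model and picks a family-specific
    # default, logging a warning and returning json when resolution fails
    return ToolPromptFormat.json


def effective_tool_prompt_format(
    requested: Optional[ToolPromptFormat], model_id: str
) -> ToolPromptFormat:
    # the idiom added to chat_completion_request_to_prompt and
    # chat_completion_request_to_model_input_info: an explicit value wins,
    # otherwise fall back to the model's default
    return requested or get_default_tool_prompt_format(model_id)


print(effective_tool_prompt_format(None, "some-model-id"))                           # default applied
print(effective_tool_prompt_format(ToolPromptFormat.python_list, "some-model-id"))   # explicit wins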