diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 5a5046a35..350c3c997 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import copy
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
 from llama_stack import logcat
@@ -54,7 +53,6 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
-from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
 
 
 class VectorIORouter(VectorIO):
@@ -181,9 +179,6 @@ class InferenceRouter(Inference):
                 params["tool_prompt_format"] = tool_prompt_format
             tool_config = ToolConfig(**params)
 
-        tool_config = copy.copy(tool_config)
-        tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
-
         tools = tools or []
         if tool_config.tool_choice == ToolChoice.none:
             tools = []
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index aa98b2170..37b1a8160 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -15,6 +15,7 @@ from typing import List, Optional, Tuple, Union
 import httpx
 from PIL import Image as PIL_Image
 
+from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
@@ -253,7 +254,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
 
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
-        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+        request.messages,
+        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
     )
     return formatter.tokenizer.decode(model_input.tokens)
 
@@ -267,7 +269,8 @@ async def chat_completion_request_to_model_input_info(
 
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
-        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+        request.messages,
+        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
     )
     return (
         formatter.tokenizer.decode(model_input.tokens),
@@ -461,6 +464,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
 def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
     llama_model = resolve_model(model)
     if llama_model is None:
+        logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format")
         return ToolPromptFormat.json
     if llama_model.model_family == ModelFamily.llama3_1 or (
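
For reviewers, a minimal standalone sketch of the behavior this diff relocates: the router no longer copies a default tool_prompt_format onto the ToolConfig; the default is resolved at prompt-encoding time, and only when the caller did not specify one. The ToolPromptFormat values, the KNOWN_MODEL_DEFAULTS mapping, and encode_tool_prompt_format below are illustrative stand-ins, not the actual llama_stack model registry or formatter API.

```python
# Illustrative sketch only: stand-in names, not the llama_stack APIs touched by this diff.
from enum import Enum
from typing import Dict, Optional


class ToolPromptFormat(Enum):
    json = "json"
    function_tag = "function_tag"
    python_list = "python_list"


# Hypothetical per-model defaults; the real code resolves these via resolve_model()
# and the model family instead of a flat mapping.
KNOWN_MODEL_DEFAULTS: Dict[str, ToolPromptFormat] = {
    "llama3.1-8b": ToolPromptFormat.json,
    "llama3.2-3b": ToolPromptFormat.python_list,
}


def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
    default = KNOWN_MODEL_DEFAULTS.get(model)
    if default is None:
        # Mirrors the new warning path: an unresolved model falls back to json.
        print(f"warning: could not resolve model {model}, defaulting to json tool prompt format")
        return ToolPromptFormat.json
    return default


def encode_tool_prompt_format(model: str, requested: Optional[ToolPromptFormat]) -> ToolPromptFormat:
    # An explicit request wins; the default is applied here, at encoding time,
    # rather than being stamped onto the ToolConfig by the router.
    return requested or get_default_tool_prompt_format(model)


if __name__ == "__main__":
    print(encode_tool_prompt_format("llama3.2-3b", None))                   # python_list (model default)
    print(encode_tool_prompt_format("llama3.2-3b", ToolPromptFormat.json))  # json (explicit request wins)
    print(encode_tool_prompt_format("unknown-model", None))                 # json, after the warning
```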