From ee5e9b935ae2e61187878c2236c263aa87037d24 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Mon, 3 Mar 2025 14:50:06 -0800 Subject: [PATCH] feat: better using get_default_tool_prompt_format (#1360) Summary: https://github.com/meta-llama/llama-stack/pull/1214 introduced `get_default_tool_prompt_format` but tried to use it on the raw identifier. Here we move calling this func later in the stack and rely on the inference provider to resolve the raw identifier into llama model, then call get_default_tool_prompt_format. Test Plan: ``` LLAMA_STACK_CONFIG=ollama pytest -s -v tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming --inference-model=llama3.2:3b-instruct-fp16 --vision-inference-model="" ``` Before: image After: image --- llama_stack/distribution/routers/routers.py | 5 ----- llama_stack/providers/utils/inference/prompt_adapter.py | 8 ++++++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 5a5046a35..350c3c997 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import copy from typing import Any, AsyncGenerator, Dict, List, Optional from llama_stack import logcat @@ -54,7 +53,6 @@ from llama_stack.apis.tools import ( ) from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import RoutingTable -from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format class VectorIORouter(VectorIO): @@ -181,9 +179,6 @@ class InferenceRouter(Inference): params["tool_prompt_format"] = tool_prompt_format tool_config = ToolConfig(**params) - tool_config = copy.copy(tool_config) - tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id) - tools = tools or [] if tool_config.tool_choice == ToolChoice.none: tools = [] diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index aa98b2170..37b1a8160 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -15,6 +15,7 @@ from typing import List, Optional, Tuple, Union import httpx from PIL import Image as PIL_Image +from llama_stack import logcat from llama_stack.apis.common.content_types import ( ImageContentItem, InterleavedContent, @@ -253,7 +254,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) model_input = formatter.encode_dialog_prompt( - request.messages, tool_prompt_format=request.tool_config.tool_prompt_format + request.messages, + tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), ) return formatter.tokenizer.decode(model_input.tokens) @@ -267,7 +269,8 @@ async def chat_completion_request_to_model_input_info( formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) model_input = formatter.encode_dialog_prompt( - request.messages, tool_prompt_format=request.tool_config.tool_prompt_format + request.messages, + tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), ) return ( formatter.tokenizer.decode(model_input.tokens), @@ -461,6 +464,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin def get_default_tool_prompt_format(model: str) -> ToolPromptFormat: llama_model = resolve_model(model) if llama_model is None: + logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format") return ToolPromptFormat.json if llama_model.model_family == ModelFamily.llama3_1 or (