feat: better using get_default_tool_prompt_format (#1360)

Summary:
https://github.com/meta-llama/llama-stack/pull/1214 introduced
`get_default_tool_prompt_format` but tried to use it on the raw
identifier.

Here we move calling this func later in the stack and rely on the
inference provider to resolve the raw identifier into llama model, then
call get_default_tool_prompt_format.

Test Plan:
```
LLAMA_STACK_CONFIG=ollama pytest -s -v tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming --inference-model=llama3.2:3b-instruct-fp16 --vision-inference-model=""
```

Before:

<img width="1288" alt="image"
src="https://github.com/user-attachments/assets/918c7839-1f45-4540-864e-4b842cc367df"
/>

After:
<img width="1522" alt="image"
src="https://github.com/user-attachments/assets/447d78af-b3b9-4837-8cb7-6ac549005efe"
/>
This commit is contained in:
ehhuang 2025-03-03 14:50:06 -08:00 committed by GitHub
parent 386c806c70
commit ee5e9b935a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 6 additions and 7 deletions

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import copy
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack import logcat from llama_stack import logcat
@ -54,7 +53,6 @@ from llama_stack.apis.tools import (
) )
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import RoutingTable from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
class VectorIORouter(VectorIO): class VectorIORouter(VectorIO):
@ -181,9 +179,6 @@ class InferenceRouter(Inference):
params["tool_prompt_format"] = tool_prompt_format params["tool_prompt_format"] = tool_prompt_format
tool_config = ToolConfig(**params) tool_config = ToolConfig(**params)
tool_config = copy.copy(tool_config)
tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
tools = tools or [] tools = tools or []
if tool_config.tool_choice == ToolChoice.none: if tool_config.tool_choice == ToolChoice.none:
tools = [] tools = []

View file

@ -15,6 +15,7 @@ from typing import List, Optional, Tuple, Union
import httpx import httpx
from PIL import Image as PIL_Image from PIL import Image as PIL_Image
from llama_stack import logcat
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
ImageContentItem, ImageContentItem,
InterleavedContent, InterleavedContent,
@ -253,7 +254,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt( model_input = formatter.encode_dialog_prompt(
request.messages, tool_prompt_format=request.tool_config.tool_prompt_format request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
) )
return formatter.tokenizer.decode(model_input.tokens) return formatter.tokenizer.decode(model_input.tokens)
@ -267,7 +269,8 @@ async def chat_completion_request_to_model_input_info(
formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt( model_input = formatter.encode_dialog_prompt(
request.messages, tool_prompt_format=request.tool_config.tool_prompt_format request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
) )
return ( return (
formatter.tokenizer.decode(model_input.tokens), formatter.tokenizer.decode(model_input.tokens),
@ -461,6 +464,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
def get_default_tool_prompt_format(model: str) -> ToolPromptFormat: def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
llama_model = resolve_model(model) llama_model = resolve_model(model)
if llama_model is None: if llama_model is None:
logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format")
return ToolPromptFormat.json return ToolPromptFormat.json
if llama_model.model_family == ModelFamily.llama3_1 or ( if llama_model.model_family == ModelFamily.llama3_1 or (