mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
feat: better using get_default_tool_prompt_format (#1360)
Summary: https://github.com/meta-llama/llama-stack/pull/1214 introduced `get_default_tool_prompt_format` but tried to use it on the raw identifier. Here we move calling this func later in the stack and rely on the inference provider to resolve the raw identifier into llama model, then call get_default_tool_prompt_format. Test Plan: ``` LLAMA_STACK_CONFIG=ollama pytest -s -v tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming --inference-model=llama3.2:3b-instruct-fp16 --vision-inference-model="" ``` Before: <img width="1288" alt="image" src="https://github.com/user-attachments/assets/918c7839-1f45-4540-864e-4b842cc367df" /> After: <img width="1522" alt="image" src="https://github.com/user-attachments/assets/447d78af-b3b9-4837-8cb7-6ac549005efe" />
This commit is contained in:
parent
386c806c70
commit
ee5e9b935a
2 changed files with 6 additions and 7 deletions
|
@ -15,6 +15,7 @@ from typing import List, Optional, Tuple, Union
|
|||
import httpx
|
||||
from PIL import Image as PIL_Image
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
|
@ -253,7 +254,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
|
|||
|
||||
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
|
||||
model_input = formatter.encode_dialog_prompt(
|
||||
request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
|
||||
request.messages,
|
||||
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
|
||||
)
|
||||
return formatter.tokenizer.decode(model_input.tokens)
|
||||
|
||||
|
@ -267,7 +269,8 @@ async def chat_completion_request_to_model_input_info(
|
|||
|
||||
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
|
||||
model_input = formatter.encode_dialog_prompt(
|
||||
request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
|
||||
request.messages,
|
||||
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
|
||||
)
|
||||
return (
|
||||
formatter.tokenizer.decode(model_input.tokens),
|
||||
|
@ -461,6 +464,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
|
|||
def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
|
||||
llama_model = resolve_model(model)
|
||||
if llama_model is None:
|
||||
logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format")
|
||||
return ToolPromptFormat.json
|
||||
|
||||
if llama_model.model_family == ModelFamily.llama3_1 or (
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue