fix: pass tool_prompt_format to chat_formatter (#1198)

Summary:

Need this to format the completion message with tool_calls correctly.
See the added unit test (a condensed, runnable sketch of it follows).
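Condensed into a standalone script, the behavior under test looks roughly like this. The identifiers and the expected rendering come from the test diff below; the import path for ToolDefinition/ToolParamDefinition/ToolPromptFormat is inferred from the hunk context, and the model ID reuses the test's MODEL constant, so treat this as a sketch rather than the exact test:

import asyncio

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    CompletionMessage,
    StopReason,
    ToolCall,
    ToolConfig,
    UserMessage,
)
# Assumed import path, based on the "@@ -20,6 +23,7 @@" hunk context below.
from llama_stack.models.llama.datatypes import (
    ToolDefinition,
    ToolParamDefinition,
    ToolPromptFormat,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)

MODEL = "Llama3.1-8B-Instruct"


async def main() -> None:
    # A dialog whose last turn is an assistant CompletionMessage carrying a
    # tool call; the fix makes chat_completion_request_to_prompt pass
    # tool_config.tool_prompt_format through to the chat formatter when this
    # message is rendered back into a prompt.
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            UserMessage(content="hello"),
            CompletionMessage(
                content="",
                stop_reason=StopReason.end_of_turn,
                tool_calls=[
                    ToolCall(tool_name="custom1", arguments={"param1": "value1"}, call_id="123")
                ],
            ),
        ],
        tools=[
            ToolDefinition(
                tool_name="custom1",
                description="custom1 tool",
                parameters={
                    "param1": ToolParamDefinition(
                        param_type="str", description="param1 description", required=True
                    )
                },
            )
        ],
        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
    )
    prompt = await chat_completion_request_to_prompt(request, request.model)
    # With the json format the tool call should appear in the prompt as:
    #   {"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}
    print(prompt)


if __name__ == "__main__":
    asyncio.run(main())

Before this change, encode_dialog_prompt was called without a tool_prompt_format argument, so tool calls in a CompletionMessage were always rendered in the formatter's default style regardless of what the request's tool_config asked for.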

Test Plan:

python -m unittest llama_stack.providers.tests.inference.test_prompt_adapter
Author: ehhuang (committed via GitHub)
Date: 2025-02-20 21:38:35 -08:00
Parent: 33a64eb5ec
Commit: cfa752fc92
2 changed files with 50 additions and 2 deletions

llama_stack/providers/tests/inference/test_prompt_adapter.py

@@ -8,7 +8,10 @@ import unittest
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
+    CompletionMessage,
+    StopReason,
     SystemMessage,
+    ToolCall,
     ToolConfig,
     UserMessage,
 )
@@ -20,6 +23,7 @@ from llama_stack.models.llama.datatypes import (
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
+    chat_completion_request_to_prompt,
 )

 MODEL = "Llama3.1-8B-Instruct"
@@ -119,6 +123,46 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         self.assertTrue("Return function calls in JSON format" in messages[1].content)
         self.assertEqual(messages[-1].content, content)

+    async def test_completion_message_encoding(self):
+        request = ChatCompletionRequest(
+            model=MODEL3_2,
+            messages=[
+                UserMessage(content="hello"),
+                CompletionMessage(
+                    content="",
+                    stop_reason=StopReason.end_of_turn,
+                    tool_calls=[
+                        ToolCall(
+                            tool_name="custom1",
+                            arguments={"param1": "value1"},
+                            call_id="123",
+                        )
+                    ],
+                ),
+            ],
+            tools=[
+                ToolDefinition(
+                    tool_name="custom1",
+                    description="custom1 tool",
+                    parameters={
+                        "param1": ToolParamDefinition(
+                            param_type="str",
+                            description="param1 description",
+                            required=True,
+                        ),
+                    },
+                ),
+            ],
+            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
+        )
+        prompt = await chat_completion_request_to_prompt(request, request.model)
+        self.assertIn('[custom1(param1="value1")]', prompt)
+
+        request.model = MODEL
+        request.tool_config.tool_prompt_format = ToolPromptFormat.json
+        prompt = await chat_completion_request_to_prompt(request, request.model)
+        self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
+
     async def test_user_provided_system_message(self):
         content = "Hello !"
         system_prompt = "You are a pirate"

llama_stack/providers/utils/inference/prompt_adapter.py

@@ -252,7 +252,9 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
     request = await convert_request_to_raw(request)
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
-    model_input = formatter.encode_dialog_prompt(request.messages)
+    model_input = formatter.encode_dialog_prompt(
+        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+    )
     return formatter.tokenizer.decode(model_input.tokens)
@@ -264,7 +266,9 @@ async def chat_completion_request_to_model_input_info(
     request = await convert_request_to_raw(request)
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
-    model_input = formatter.encode_dialog_prompt(request.messages)
+    model_input = formatter.encode_dialog_prompt(
+        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+    )
     return (
         formatter.tokenizer.decode(model_input.tokens),
         len(model_input.tokens),