From cfa752fc922cdf479699f7c69c66ba778eeec963 Mon Sep 17 00:00:00 2001
From: ehhuang
Date: Thu, 20 Feb 2025 21:38:35 -0800
Subject: [PATCH] fix: pass tool_prompt_format to chat_formatter (#1198)

Summary:
Need this to format the completion message with tool_calls correctly.
See added unittest.

Test Plan:
python -m unittest llama_stack.providers.tests.inference.test_prompt_adapter
---
 .../tests/inference/test_prompt_adapter.py    | 44 +++++++++++++++++++
 .../utils/inference/prompt_adapter.py         |  8 +++-
 2 files changed, 50 insertions(+), 2 deletions(-)

diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py
index 323c6cb6a..2a6dbb561 100644
--- a/llama_stack/providers/tests/inference/test_prompt_adapter.py
+++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py
@@ -8,7 +8,10 @@ import unittest
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
+    CompletionMessage,
+    StopReason,
     SystemMessage,
+    ToolCall,
     ToolConfig,
     UserMessage,
 )
@@ -20,6 +23,7 @@ from llama_stack.models.llama.datatypes import (
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
+    chat_completion_request_to_prompt,
 )
 
 MODEL = "Llama3.1-8B-Instruct"
@@ -119,6 +123,46 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         self.assertTrue("Return function calls in JSON format" in messages[1].content)
         self.assertEqual(messages[-1].content, content)
 
+    async def test_completion_message_encoding(self):
+        request = ChatCompletionRequest(
+            model=MODEL3_2,
+            messages=[
+                UserMessage(content="hello"),
+                CompletionMessage(
+                    content="",
+                    stop_reason=StopReason.end_of_turn,
+                    tool_calls=[
+                        ToolCall(
+                            tool_name="custom1",
+                            arguments={"param1": "value1"},
+                            call_id="123",
+                        )
+                    ],
+                ),
+            ],
+            tools=[
+                ToolDefinition(
+                    tool_name="custom1",
+                    description="custom1 tool",
+                    parameters={
+                        "param1": ToolParamDefinition(
+                            param_type="str",
+                            description="param1 description",
+                            required=True,
+                        ),
+                    },
+                ),
+            ],
+            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
+        )
+        prompt = await chat_completion_request_to_prompt(request, request.model)
+        self.assertIn('[custom1(param1="value1")]', prompt)
+
+        request.model = MODEL
+        request.tool_config.tool_prompt_format = ToolPromptFormat.json
+        prompt = await chat_completion_request_to_prompt(request, request.model)
+        self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
+
     async def test_user_provided_system_message(self):
         content = "Hello !"
         system_prompt = "You are a pirate"
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 10fe442e8..ca6fe04fd 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -252,7 +252,9 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
     request = await convert_request_to_raw(request)
 
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
-    model_input = formatter.encode_dialog_prompt(request.messages)
+    model_input = formatter.encode_dialog_prompt(
+        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+    )
     return formatter.tokenizer.decode(model_input.tokens)
 
 
@@ -264,7 +266,9 @@ async def chat_completion_request_to_model_input_info(
     request = await convert_request_to_raw(request)
 
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
-    model_input = formatter.encode_dialog_prompt(request.messages)
+    model_input = formatter.encode_dialog_prompt(
+        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+    )
     return (
         formatter.tokenizer.decode(model_input.tokens),
         len(model_input.tokens),
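Below is a minimal usage sketch, not taken from the patch, of how the same fix is picked up by the sibling helper chat_completion_request_to_model_input_info, which is changed here but not covered by the new unit test. The two-argument call shape and the "Llama3.2-3B-Instruct" model id are assumptions based on the hunks above, not verified against the repository.

# Sketch: with the fix, the tool prompt format chosen in ToolConfig is what the
# ChatFormat encoder uses when rendering the dialog, for both updated helpers.
import asyncio

from llama_stack.apis.inference import ChatCompletionRequest, ToolConfig, UserMessage
from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_model_input_info,
)


async def main() -> None:
    request = ChatCompletionRequest(
        model="Llama3.2-3B-Instruct",  # assumed model id, analogous to MODEL3_2 in the test
        messages=[UserMessage(content="hello")],
        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
    )
    # Returns the decoded prompt and its token count; after this patch the prompt
    # reflects the requested tool_prompt_format instead of the formatter default.
    prompt, n_tokens = await chat_completion_request_to_model_input_info(request, request.model)
    print(n_tokens, prompt)


asyncio.run(main())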