From c8797f1125cfded745f0688944d783355b4cfc07 Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Fri, 25 Apr 2025 00:59:10 +0100 Subject: [PATCH] fix: Including tool call in chat (#1931) Include the tool call details with the chat when doing Rag with Remote vllm Fixes: #1929 With this PR the tool call is included in the chat returned to vllm, the model (meta-llama/Llama-3.1-8B-Instruct) the returns the answer as expected. Signed-off-by: Derek Higgins --- .../utils/inference/openai_compat.py | 17 ++++++- .../providers/inference/test_remote_vllm.py | 48 ++++++++++++++++++- .../utils/inference/test_openai_compat.py | 43 +++++++++++++++++ 3 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 tests/unit/providers/utils/inference/test_openai_compat.py diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index f91e7d7dc..4d690287b 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -524,11 +524,26 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals else: content = [await _convert_content(message.content)] - return { + result = { "role": message.role, "content": content, } + if hasattr(message, "tool_calls") and message.tool_calls: + result["tool_calls"] = [] + for tc in message.tool_calls: + result["tool_calls"].append( + { + "id": tc.call_id, + "type": "function", + "function": { + "name": tc.tool_name, + "arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments), + }, + } + ) + return result + class UnparseableToolCall(BaseModel): """ diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index 88399198d..b3172cad4 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -28,12 +28,15 @@ from openai.types.model import Model as OpenAIModel from llama_stack.apis.inference import ( ChatCompletionRequest, + CompletionMessage, + SystemMessage, ToolChoice, ToolConfig, + ToolResponseMessage, UserMessage, ) from llama_stack.apis.models import Model -from llama_stack.models.llama.datatypes import StopReason +from llama_stack.models.llama.datatypes import StopReason, ToolCall from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig from llama_stack.providers.remote.inference.vllm.vllm import ( VLLMInferenceAdapter, @@ -135,6 +138,49 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter): assert request.tool_config.tool_choice == ToolChoice.none +@pytest.mark.asyncio +async def test_tool_call_response(vllm_inference_adapter): + """Verify that tool call arguments from a CompletionMessage are correctly converted + into the expected JSON format.""" + + # Patch the call to vllm so we can inspect the arguments sent were correct + with patch.object( + vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock + ) as mock_nonstream_completion: + messages = [ + SystemMessage(content="You are a helpful assistant"), + UserMessage(content="How many?"), + CompletionMessage( + content="", + stop_reason=StopReason.end_of_turn, + tool_calls=[ + ToolCall( + call_id="foo", + tool_name="knowledge_search", + arguments={"query": "How many?"}, + arguments_json='{"query": "How many?"}', + ) + ], + ), + ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."), + ] + await vllm_inference_adapter.chat_completion( + "mock-model", + messages, + stream=False, + tools=[], + tool_config=ToolConfig(tool_choice=ToolChoice.auto), + ) + + assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [ + { + "id": "foo", + "type": "function", + "function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'}, + } + ] + + @pytest.mark.asyncio async def test_tool_call_delta_empty_tool_call_buf(): """ diff --git a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py new file mode 100644 index 000000000..eb02f8203 --- /dev/null +++ b/tests/unit/providers/utils/inference/test_openai_compat.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.apis.common.content_types import TextContentItem +from llama_stack.apis.inference.inference import CompletionMessage, UserMessage +from llama_stack.models.llama.datatypes import StopReason, ToolCall +from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict + + +@pytest.mark.asyncio +async def test_convert_message_to_openai_dict(): + message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user") + assert await convert_message_to_openai_dict(message) == { + "role": "user", + "content": [{"type": "text", "text": "Hello, world!"}], + } + + +# Test convert_message_to_openai_dict with a tool call +@pytest.mark.asyncio +async def test_convert_message_to_openai_dict_with_tool_call(): + message = CompletionMessage( + content="", + tool_calls=[ + ToolCall(call_id="123", tool_name="test_tool", arguments_json='{"foo": "bar"}', arguments={"foo": "bar"}) + ], + stop_reason=StopReason.end_of_turn, + ) + + openai_dict = await convert_message_to_openai_dict(message) + + assert openai_dict == { + "role": "assistant", + "content": [{"type": "text", "text": ""}], + "tool_calls": [ + {"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}} + ], + }