Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-28 02:53:30 +00:00
fix: Including tool call in chat (#1931)
Include the tool call details with the chat when doing RAG with remote vLLM. Fixes: #1929. With this PR the tool call is included in the chat returned to vLLM, and the model (meta-llama/Llama-3.1-8B-Instruct) then returns the answer as expected. Signed-off-by: Derek Higgins <derekh@redhat.com>
Parent: 7ed137e963
Commit: c8797f1125
3 changed files with 106 additions and 2 deletions
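With this change, convert_message_to_openai_dict attaches any tool calls to the converted message under a tool_calls key, matching the OpenAI chat-completions schema. A minimal sketch of the resulting shape, using the values exercised by the new unit tests below (the sketch itself is illustrative, not part of the diff):

# Shape of the dict the conversion now returns for an assistant message
# that carried a tool call; values are taken from the new unit test.
expected = {
    "role": "assistant",
    "content": [{"type": "text", "text": ""}],
    "tool_calls": [
        {
            "id": "123",
            "type": "function",
            "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'},
        }
    ],
}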
llama_stack/providers/utils/inference/openai_compat.py

@@ -524,11 +524,26 @@ async def convert_message_to_openai_dict(message: Message, download: bool = False
     else:
         content = [await _convert_content(message.content)]
 
-    return {
+    result = {
         "role": message.role,
         "content": content,
     }
 
+    if hasattr(message, "tool_calls") and message.tool_calls:
+        result["tool_calls"] = []
+        for tc in message.tool_calls:
+            result["tool_calls"].append(
+                {
+                    "id": tc.call_id,
+                    "type": "function",
+                    "function": {
+                        "name": tc.tool_name,
+                        "arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
+                    },
+                }
+            )
+    return result
+
 
 class UnparseableToolCall(BaseModel):
     """
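One detail in the hunk above: the arguments are passed through as the pre-serialized arguments_json when the ToolCall carries one, and serialized with json.dumps otherwise. A standalone sketch of that fallback, where FakeToolCall is a hypothetical stand-in for the real ToolCall type:

import json


class FakeToolCall:
    # Hypothetical stand-in: has arguments but no arguments_json attribute.
    call_id = "call_1"
    tool_name = "knowledge_search"
    arguments = {"query": "How many?"}


tc = FakeToolCall()
# Same expression as in the diff: prefer the pre-serialized JSON if present.
args = tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments)
print(args)  # {"query": "How many?"} -- serialized on the fly here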
@@ -28,12 +28,15 @@ from openai.types.model import Model as OpenAIModel
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
+    CompletionMessage,
+    SystemMessage,
     ToolChoice,
     ToolConfig,
+    ToolResponseMessage,
     UserMessage,
 )
 from llama_stack.apis.models import Model
-from llama_stack.models.llama.datatypes import StopReason
+from llama_stack.models.llama.datatypes import StopReason, ToolCall
 from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
 from llama_stack.providers.remote.inference.vllm.vllm import (
     VLLMInferenceAdapter,
@@ -135,6 +138,49 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
     assert request.tool_config.tool_choice == ToolChoice.none
 
 
+@pytest.mark.asyncio
+async def test_tool_call_response(vllm_inference_adapter):
+    """Verify that tool call arguments from a CompletionMessage are correctly converted
+    into the expected JSON format."""
+
+    # Patch the call to vllm so we can verify that the arguments sent were correct
+    with patch.object(
+        vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
+    ) as mock_nonstream_completion:
+        messages = [
+            SystemMessage(content="You are a helpful assistant"),
+            UserMessage(content="How many?"),
+            CompletionMessage(
+                content="",
+                stop_reason=StopReason.end_of_turn,
+                tool_calls=[
+                    ToolCall(
+                        call_id="foo",
+                        tool_name="knowledge_search",
+                        arguments={"query": "How many?"},
+                        arguments_json='{"query": "How many?"}',
+                    )
+                ],
+            ),
+            ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."),
+        ]
+        await vllm_inference_adapter.chat_completion(
+            "mock-model",
+            messages,
+            stream=False,
+            tools=[],
+            tool_config=ToolConfig(tool_choice=ToolChoice.auto),
+        )
+
+        assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
+            {
+                "id": "foo",
+                "type": "function",
+                "function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
+            }
+        ]
+
+
 @pytest.mark.asyncio
 async def test_tool_call_delta_empty_tool_call_buf():
     """
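The assertion in the new test recovers the exact keyword arguments the adapter forwarded to vLLM via unittest.mock's call_args. A tiny self-contained illustration of that pattern (the Greeter class is hypothetical, standing in for the patched completions.create):

import asyncio
from unittest.mock import AsyncMock, patch


class Greeter:
    async def greet(self, **kwargs):
        return "hi"


async def main() -> None:
    greeter = Greeter()
    with patch.object(greeter, "greet", new_callable=AsyncMock) as mock_greet:
        await greeter.greet(messages=[{"role": "user", "content": "hello"}])
    # call_args.kwargs holds the keyword arguments of the most recent call,
    # which is what the test above inspects on the patched completions.create.
    assert mock_greet.call_args.kwargs["messages"] == [{"role": "user", "content": "hello"}]


asyncio.run(main())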
tests/unit/providers/utils/inference/test_openai_compat.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from llama_stack.apis.common.content_types import TextContentItem
+from llama_stack.apis.inference.inference import CompletionMessage, UserMessage
+from llama_stack.models.llama.datatypes import StopReason, ToolCall
+from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict
+
+
+@pytest.mark.asyncio
+async def test_convert_message_to_openai_dict():
+    message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
+    assert await convert_message_to_openai_dict(message) == {
+        "role": "user",
+        "content": [{"type": "text", "text": "Hello, world!"}],
+    }
+
+
+# Test convert_message_to_openai_dict with a tool call
+@pytest.mark.asyncio
+async def test_convert_message_to_openai_dict_with_tool_call():
+    message = CompletionMessage(
+        content="",
+        tool_calls=[
+            ToolCall(call_id="123", tool_name="test_tool", arguments_json='{"foo": "bar"}', arguments={"foo": "bar"})
+        ],
+        stop_reason=StopReason.end_of_turn,
+    )
+
+    openai_dict = await convert_message_to_openai_dict(message)
+
+    assert openai_dict == {
+        "role": "assistant",
+        "content": [{"type": "text", "text": ""}],
+        "tool_calls": [
+            {"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
+        ],
+    }
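To see the conversion end to end outside the test suite, a minimal driver is sketched below; it assumes an installed llama_stack and simply reuses the objects from the second test above (the asyncio entry point is illustrative):

import asyncio

from llama_stack.apis.inference.inference import CompletionMessage
from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict


async def main() -> None:
    # Mirrors test_convert_message_to_openai_dict_with_tool_call above.
    message = CompletionMessage(
        content="",
        tool_calls=[
            ToolCall(call_id="123", tool_name="test_tool", arguments_json='{"foo": "bar"}', arguments={"foo": "bar"})
        ],
        stop_reason=StopReason.end_of_turn,
    )
    print(await convert_message_to_openai_dict(message))


if __name__ == "__main__":
    asyncio.run(main())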