Including tool call in chat

Include the tool call details with the chat when doing
Rag with Remote vllm

Fixes: #1929

Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
Derek Higgins 2025-04-10 13:20:09 +00:00
parent 45e08ff417
commit 6fe64ee169
3 changed files with 106 additions and 2 deletions

View file

@ -28,12 +28,15 @@ from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionMessage,
SystemMessage,
ToolChoice,
ToolConfig,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import (
VLLMInferenceAdapter,
@ -135,6 +138,49 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
assert request.tool_config.tool_choice == ToolChoice.none
@pytest.mark.asyncio
async def test_tool_call_response(vllm_inference_adapter):
"""Verify that tool call arguments from a CompletionMessage are correctly converted
into the expected JSON format."""
# Patch the call to vllm so we can inspect the arguments sent were correct
with patch.object(
vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
) as mock_nonstream_completion:
messages = [
SystemMessage(content="You are a helpful assistant"),
UserMessage(content="How many?"),
CompletionMessage(
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
call_id="foo",
tool_name="knowledge_search",
arguments={"query": "How many?"},
arguments_json='{"query": "How many?"}',
)
],
),
ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."),
]
await vllm_inference_adapter.chat_completion(
"mock-model",
messages,
stream=False,
tools=[],
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
{
"id": "foo",
"type": "function",
"function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
}
]
@pytest.mark.asyncio
async def test_tool_call_delta_empty_tool_call_buf():
"""