fix: remote vLLM tool execution now works when the last chunk contains the call arguments (#2112)

# What does this PR do?
Closes #2111.
Fixes an error causing Llama Stack to just return `<tool_call>` and
complete the turn without actually executing the tool. See the issue
description for more detail.

## Test Plan
1) Ran existing unit tests
2) Added a dedicated test verifying correct behavior in this edge case
3) Ran the code snapshot from #2111
This commit is contained in:
Ilya Kolchinsky 2025-05-14 11:38:00 +02:00 committed by GitHub
parent 1de0dfaab5
commit 43d4447ff0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 87 additions and 7 deletions

View file

@ -168,6 +168,12 @@ async def _process_vllm_chat_completion_stream_response(
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
continue
choice = chunk.choices[0]
if choice.delta.tool_calls:
tool_call = convert_tool_call(choice.delta.tool_calls[0])
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
# TODO: remove str() when dict type for 'arguments' is no longer allowed
tool_call_buf.arguments += str(tool_call.arguments)
if choice.finish_reason:
args_str = tool_call_buf.arguments
args = None
@ -208,13 +214,7 @@ async def _process_vllm_chat_completion_stream_response(
stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
)
)
elif choice.delta.tool_calls:
tool_call = convert_tool_call(choice.delta.tool_calls[0])
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
# TODO: remove str() when dict type for 'arguments' is no longer allowed
tool_call_buf.arguments += str(tool_call.arguments)
else:
elif not choice.delta.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,

View file

@ -28,6 +28,7 @@ from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
CompletionMessage,
SystemMessage,
ToolChoice,
@ -294,3 +295,82 @@ async def test_get_params_empty_tools(vllm_inference_adapter):
)
params = await vllm_inference_adapter._get_params(request)
assert "tools" not in params
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
"""
Tests the edge case where the model returns the arguments for the tool call in the same chunk that
contains the finish reason (i.e., the last one).
We want to make sure the tool call is executed in this case, and the parameters are passed correctly.
"""
mock_tool_name = "mock_tool"
mock_tool_arguments = {"arg1": 0, "arg2": 100}
mock_tool_arguments_str = json.dumps(mock_tool_arguments)
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": None,
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {
"name": None,
"arguments": mock_tool_arguments_str,
},
}
],
},
"finish_reason": "tool_calls",
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments