Fixed an error where a tool call fails when the arguments are passed in the chunk containing finish_reason.

ilya-kolchinsky 2025-05-08 09:46:35 +02:00
parent fe5f5e530c
commit 55da406471
2 changed files with 87 additions and 7 deletions
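
For context, the sketch below is an illustrative stand-in for the provider logic, not the actual code: the names ToolCallBuffer and process_chunks are made up, and plain dicts stand in for OpenAI chunk objects. It shows the control flow this fix establishes: tool-call deltas are folded into the buffer before the finish_reason check, so a final chunk that carries both the remaining arguments and the finish reason is no longer dropped.

# Illustrative sketch only -- simplified stand-in for the provider logic, with made-up
# names (ToolCallBuffer, process_chunks) and plain dicts instead of OpenAI chunk objects.
import json
from dataclasses import dataclass


@dataclass
class ToolCallBuffer:
    call_id: str = ""
    tool_name: str = ""
    arguments: str = ""


def process_chunks(chunks):
    """Accumulate streamed tool-call deltas and emit the parsed call on finish_reason."""
    buf = ToolCallBuffer()
    for chunk in chunks:
        delta = chunk.get("delta", {})
        # 1) Fold any tool-call delta into the buffer *before* looking at finish_reason.
        for tc in delta.get("tool_calls", []):
            buf.call_id += tc.get("id") or ""
            buf.tool_name += tc.get("function", {}).get("name") or ""
            buf.arguments += tc.get("function", {}).get("arguments") or ""
        # 2) Only then handle finish_reason; the buffer already contains any arguments
        #    that arrived in this same (final) chunk.
        if chunk.get("finish_reason"):
            yield {"tool_name": buf.tool_name, "arguments": json.loads(buf.arguments)}


# The arguments arrive only in the chunk that also carries finish_reason.
stream = [
    {"delta": {"tool_calls": [{"id": "c1", "function": {"name": "mock_tool", "arguments": None}}]}},
    {"delta": {"tool_calls": [{"function": {"arguments": '{"arg1": 0}'}}]}, "finish_reason": "tool_calls"},
]
assert list(process_chunks(stream)) == [{"tool_name": "mock_tool", "arguments": {"arg1": 0}}]

Before the fix, the tool-call accumulation lived in an elif branch after the finish_reason check (see the second hunk below), so arguments delivered in the final chunk never reached the buffer.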


@@ -168,6 +168,12 @@ async def _process_vllm_chat_completion_stream_response(
             log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
             continue
         choice = chunk.choices[0]
+        if choice.delta.tool_calls:
+            tool_call = convert_tool_call(choice.delta.tool_calls[0])
+            tool_call_buf.tool_name += str(tool_call.tool_name)
+            tool_call_buf.call_id += tool_call.call_id
+            # TODO: remove str() when dict type for 'arguments' is no longer allowed
+            tool_call_buf.arguments += str(tool_call.arguments)
         if choice.finish_reason:
             args_str = tool_call_buf.arguments
             args = None
@@ -208,13 +214,7 @@ async def _process_vllm_chat_completion_stream_response(
                     stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
                 )
             )
-        elif choice.delta.tool_calls:
-            tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += str(tool_call.tool_name)
-            tool_call_buf.call_id += tool_call.call_id
-            # TODO: remove str() when dict type for 'arguments' is no longer allowed
-            tool_call_buf.arguments += str(tool_call.arguments)
-        else:
+        elif not choice.delta.tool_calls:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=event_type,


@@ -28,6 +28,7 @@ from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponseEventType,
    CompletionMessage,
    SystemMessage,
    ToolChoice,
@@ -294,3 +295,82 @@ async def test_get_params_empty_tools(vllm_inference_adapter):
    )
    params = await vllm_inference_adapter._get_params(request)
    assert "tools" not in params


@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
    """
    Tests the edge case where the model returns the arguments for the tool call in the same chunk that
    contains the finish reason (i.e., the last one).
    We want to make sure the tool call is executed in this case, and the parameters are passed correctly.
    """
    mock_tool_name = "mock_tool"
    mock_tool_arguments = {"arg1": 0, "arg2": 100}
    mock_tool_arguments_str = json.dumps(mock_tool_arguments)

    async def mock_stream():
        mock_chunks = [
            OpenAIChatCompletionChunk(
                id="chunk-1",
                created=1,
                model="foo",
                object="chat.completion.chunk",
                choices=[
                    {
                        "delta": {
                            "content": None,
                            "tool_calls": [
                                {
                                    "index": 0,
                                    "id": "mock_id",
                                    "type": "function",
                                    "function": {
                                        "name": mock_tool_name,
                                        "arguments": None,
                                    },
                                }
                            ],
                        },
                        "finish_reason": None,
                        "logprobs": None,
                        "index": 0,
                    }
                ],
            ),
            OpenAIChatCompletionChunk(
                id="chunk-1",
                created=1,
                model="foo",
                object="chat.completion.chunk",
                choices=[
                    {
                        "delta": {
                            "content": None,
                            "tool_calls": [
                                {
                                    "index": 0,
                                    "id": None,
                                    "function": {
                                        "name": None,
                                        "arguments": mock_tool_arguments_str,
                                    },
                                }
                            ],
                        },
                        "finish_reason": "tool_calls",
                        "logprobs": None,
                        "index": 0,
                    }
                ],
            ),
        ]
        for chunk in mock_chunks:
            yield chunk

    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
    assert len(chunks) == 2
    assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
    assert chunks[-2].event.delta.type == "tool_call"
    assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
    assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments