From 9f2a7e6a743c4164c38f8cec01907f9c2f85f7f5 Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Wed, 14 May 2025 07:00:53 -0400
Subject: [PATCH] fix: multiple tool calls in remote-vllm chat_completion

This fixes an issue in how we used the tool_call_buf from streaming tool
calls in the remote-vllm provider, where it would end up concatenating
parameters from multiple different tool calls instead of aggregating the
results from each tool call separately.

It also fixes an issue, found while digging into that, where we were
accidentally mixing the JSON string form of tool call parameters with the
string representation of the Python form, which meant we'd end up with
single quotes in what should be double-quoted JSON strings.

The following tests now pass 100% for the remote-vllm provider, where some
of the test_text_inference tests were failing before this change:

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/integration/inference/test_text_inference.py --text-model "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic"

VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/integration/inference/test_vision_inference.py --vision-model "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic"
```

Many of the agent tests are passing, although some are failing due to bugs
in vLLM's pythonic tool parser for Llama models. See the PR at
https://github.com/vllm-project/vllm/pull/17917 and a gist at
https://gist.github.com/bbrowning/b5007709015cb2aabd85e0bd08e6d60f for the
changes needed there, which will have to be made upstream in vLLM.

Agent tests:

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/integration/agents/test_agents.py --text-model "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic"
```

Signed-off-by: Ben Browning
---
 .../providers/remote/inference/vllm/vllm.py   |  36 ++--
 .../utils/inference/openai_compat.py          |   8 +-
 .../providers/inference/test_remote_vllm.py   | 169 +++++++++++++++++-
 3 files changed, 196 insertions(+), 17 deletions(-)

diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 070d94df8..d00218dd5 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -162,7 +162,7 @@ def _process_vllm_chat_completion_end_of_stream(
     finish_reason: str | None,
     last_chunk_content: str | None,
     current_event_type: ChatCompletionResponseEventType,
-    tool_call_buf: UnparseableToolCall,
+    tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
 ) -> list[OpenAIChatCompletionChunk]:
     chunks = []
 
@@ -171,9 +171,8 @@ def _process_vllm_chat_completion_end_of_stream(
     else:
         stop_reason = StopReason.end_of_message
 
-    if tool_call_buf.tool_name:
-        # at least one tool call request is received
-
+    tool_call_bufs = tool_call_bufs or {}
+    for _index, tool_call_buf in sorted(tool_call_bufs.items()):
         args_str = tool_call_buf.arguments or "{}"
         try:
             args = json.loads(args_str)
@@ -225,8 +224,14 @@
 async def _process_vllm_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
 ) -> AsyncGenerator:
-    event_type = ChatCompletionResponseEventType.start
-    tool_call_buf = UnparseableToolCall()
+    yield ChatCompletionResponseStreamChunk(
+        event=ChatCompletionResponseEvent(
+            event_type=ChatCompletionResponseEventType.start,
+            delta=TextDelta(text=""),
+        )
+    )
+    event_type = ChatCompletionResponseEventType.progress
+    tool_call_bufs: dict[str, UnparseableToolCall] = {}
     end_of_stream_processed = False
 
     async for chunk in stream:
@@ -235,17 +240,22 @@ async def _process_vllm_chat_completion_stream_response(
             return
         choice = chunk.choices[0]
         if choice.delta.tool_calls:
-            tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += str(tool_call.tool_name)
-            tool_call_buf.call_id += tool_call.call_id
-            # TODO: remove str() when dict type for 'arguments' is no longer allowed
-            tool_call_buf.arguments += str(tool_call.arguments)
+            for delta_tool_call in choice.delta.tool_calls:
+                tool_call = convert_tool_call(delta_tool_call)
+                if delta_tool_call.index not in tool_call_bufs:
+                    tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
+                tool_call_buf = tool_call_bufs[delta_tool_call.index]
+                tool_call_buf.tool_name += str(tool_call.tool_name)
+                tool_call_buf.call_id += tool_call.call_id
+                tool_call_buf.arguments += (
+                    tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
+                )
         if choice.finish_reason:
             chunks = _process_vllm_chat_completion_end_of_stream(
                 finish_reason=choice.finish_reason,
                 last_chunk_content=choice.delta.content,
                 current_event_type=event_type,
-                tool_call_buf=tool_call_buf,
+                tool_call_bufs=tool_call_bufs,
             )
             for c in chunks:
                 yield c
@@ -266,7 +276,7 @@ async def _process_vllm_chat_completion_stream_response(
         # the stream ended without a chunk containing finish_reason - we have to generate the
         # respective completion chunks manually
         chunks = _process_vllm_chat_completion_end_of_stream(
-            finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_buf=tool_call_buf
+            finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
         )
         for c in chunks:
             yield c
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index e2314d44f..cc0000528 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -531,13 +531,19 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
             tool_name = tc.tool_name
             if isinstance(tool_name, BuiltinTool):
                 tool_name = tool_name.value
+
+            # arguments_json can be None, so attempt it first and fall back to arguments
+            if hasattr(tc, "arguments_json") and tc.arguments_json:
+                arguments = tc.arguments_json
+            else:
+                arguments = json.dumps(tc.arguments)
             result["tool_calls"].append(
                 {
                     "id": tc.call_id,
                     "type": "function",
                     "function": {
                         "name": tool_name,
-                        "arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
+                        "arguments": arguments,
                     },
                 }
             )
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 6e1623131..f452d9fd9 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -24,6 +24,12 @@ from openai.types.chat.chat_completion_chunk import (
 from openai.types.chat.chat_completion_chunk import (
     ChoiceDelta as OpenAIChoiceDelta,
 )
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
+)
 from openai.types.model import Model as OpenAIModel
 
 from llama_stack.apis.inference import (
@@ -206,8 +212,164 @@ async def test_tool_call_delta_empty_tool_call_buf():
         yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 1
-    assert chunks[0].event.stop_reason == StopReason.end_of_turn
+    assert len(chunks) == 2
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "complete"
+    assert chunks[1].event.stop_reason == StopReason.end_of_turn
+
+
+@pytest.mark.asyncio
+async def test_tool_call_delta_streaming_arguments_dict():
+    async def mock_stream():
+        mock_chunk_1 = OpenAIChatCompletionChunk(
+            id="chunk-1",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="tc_1",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments="",
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_2 = OpenAIChatCompletionChunk(
+            id="chunk-2",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="tc_1",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments='{"number": 28, "power": 3}',
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_3 = OpenAIChatCompletionChunk(
+            id="chunk-3",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+            ],
+        )
+        for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+            yield chunk
+
+    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+    assert len(chunks) == 3
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "progress"
+    assert chunks[1].event.delta.type == "tool_call"
+    assert chunks[1].event.delta.parse_status.value == "succeeded"
+    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[2].event.event_type.value == "complete"
+
+
+@pytest.mark.asyncio
+async def test_multiple_tool_calls():
+    async def mock_stream():
+        mock_chunk_1 = OpenAIChatCompletionChunk(
+            id="chunk-1",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments='{"number": 28, "power": 3}',
+                                ),
+                            ),
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_2 = OpenAIChatCompletionChunk(
+            id="chunk-2",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="",
+                                index=2,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="multiple",
+                                    arguments='{"first_number": 4, "second_number": 7}',
+                                ),
+                            ),
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_3 = OpenAIChatCompletionChunk(
+            id="chunk-3",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+            ],
+        )
+        for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+            yield chunk
+
+    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+    assert len(chunks) == 4
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "progress"
+    assert chunks[1].event.delta.type == "tool_call"
+    assert chunks[1].event.delta.parse_status.value == "succeeded"
+    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[2].event.event_type.value == "progress"
+    assert chunks[2].event.delta.type == "tool_call"
+    assert chunks[2].event.delta.parse_status.value == "succeeded"
+    assert chunks[2].event.delta.tool_call.arguments_json == '{"first_number": 4, "second_number": 7}'
+    assert chunks[3].event.event_type.value == "complete"
 
 
 @pytest.mark.asyncio
@@ -231,7 +393,8 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
         yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 0
+    assert len(chunks) == 1
+    assert chunks[0].event.event_type.value == "start"
 
 
 def test_chat_completion_doesnt_block_event_loop(caplog):
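
The two behaviors the description relies on can be sketched in isolation. The snippet below is illustrative only — the `ToolCallBuf` and `accumulate` names are hypothetical and not the provider's actual code — and shows streamed deltas being aggregated per tool call `index`, with dict-typed arguments serialized via `json.dumps()` rather than `str()` so the accumulated arguments stay valid, double-quoted JSON:

```python
import json
from dataclasses import dataclass


@dataclass
class ToolCallBuf:
    # Accumulates the streamed pieces of a single tool call.
    call_id: str = ""
    tool_name: str = ""
    arguments: str = ""


def accumulate(deltas: list[dict], bufs: dict[int, ToolCallBuf] | None = None) -> dict[int, ToolCallBuf]:
    """Aggregate tool call deltas per `index` instead of into one shared buffer."""
    bufs = bufs or {}
    for delta in deltas:
        buf = bufs.setdefault(delta["index"], ToolCallBuf())
        buf.call_id += delta.get("id") or ""
        buf.tool_name += delta.get("name") or ""
        args = delta.get("arguments")
        if args is not None:
            # json.dumps keeps dict arguments as valid JSON; str(args) would yield
            # single-quoted Python repr that json.loads cannot parse.
            buf.arguments += args if isinstance(args, str) else json.dumps(args)
    return bufs


bufs = accumulate(
    [
        {"index": 1, "id": "tc_1", "name": "power", "arguments": ""},
        {"index": 1, "id": "", "name": "", "arguments": '{"number": 28, "power": 3}'},
        {"index": 2, "id": "tc_2", "name": "multiple", "arguments": {"first_number": 4, "second_number": 7}},
    ]
)
assert json.loads(bufs[1].arguments) == {"number": 28, "power": 3}
assert json.loads(bufs[2].arguments) == {"first_number": 4, "second_number": 7}
```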