diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 070d94df8..d00218dd5 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -162,7 +162,7 @@ def _process_vllm_chat_completion_end_of_stream(
     finish_reason: str | None,
     last_chunk_content: str | None,
     current_event_type: ChatCompletionResponseEventType,
-    tool_call_buf: UnparseableToolCall,
+    tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
 ) -> list[OpenAIChatCompletionChunk]:
     chunks = []
 
@@ -171,9 +171,8 @@
     else:
         stop_reason = StopReason.end_of_message
 
-    if tool_call_buf.tool_name:
-        # at least one tool call request is received
-
+    tool_call_bufs = tool_call_bufs or {}
+    for _index, tool_call_buf in sorted(tool_call_bufs.items()):
         args_str = tool_call_buf.arguments or "{}"
         try:
             args = json.loads(args_str)
@@ -225,8 +224,14 @@
 async def _process_vllm_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
 ) -> AsyncGenerator:
-    event_type = ChatCompletionResponseEventType.start
-    tool_call_buf = UnparseableToolCall()
+    yield ChatCompletionResponseStreamChunk(
+        event=ChatCompletionResponseEvent(
+            event_type=ChatCompletionResponseEventType.start,
+            delta=TextDelta(text=""),
+        )
+    )
+    event_type = ChatCompletionResponseEventType.progress
+    tool_call_bufs: dict[str, UnparseableToolCall] = {}
     end_of_stream_processed = False
 
     async for chunk in stream:
@@ -235,17 +240,22 @@
         if len(chunk.choices) == 0:
             return
         choice = chunk.choices[0]
         if choice.delta.tool_calls:
-            tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += str(tool_call.tool_name)
-            tool_call_buf.call_id += tool_call.call_id
-            # TODO: remove str() when dict type for 'arguments' is no longer allowed
-            tool_call_buf.arguments += str(tool_call.arguments)
+            for delta_tool_call in choice.delta.tool_calls:
+                tool_call = convert_tool_call(delta_tool_call)
+                if delta_tool_call.index not in tool_call_bufs:
+                    tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
+                tool_call_buf = tool_call_bufs[delta_tool_call.index]
+                tool_call_buf.tool_name += str(tool_call.tool_name)
+                tool_call_buf.call_id += tool_call.call_id
+                tool_call_buf.arguments += (
+                    tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
+                )
         if choice.finish_reason:
             chunks = _process_vllm_chat_completion_end_of_stream(
                 finish_reason=choice.finish_reason,
                 last_chunk_content=choice.delta.content,
                 current_event_type=event_type,
-                tool_call_buf=tool_call_buf,
+                tool_call_bufs=tool_call_bufs,
             )
             for c in chunks:
                 yield c
@@ -266,7 +276,7 @@
         # the stream ended without a chunk containing finish_reason - we have to generate the
         # respective completion chunks manually
         chunks = _process_vllm_chat_completion_end_of_stream(
-            finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_buf=tool_call_buf
+            finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
         )
         for c in chunks:
             yield c
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index e2314d44f..cc0000528 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -531,13 +531,19 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
             tool_name = tc.tool_name
             if isinstance(tool_name, BuiltinTool):
                 tool_name = tool_name.value
+
+            # arguments_json can be None, so attempt it first and fall back to arguments
+            if hasattr(tc, "arguments_json") and tc.arguments_json:
+                arguments = tc.arguments_json
+            else:
+                arguments = json.dumps(tc.arguments)
             result["tool_calls"].append(
                 {
                     "id": tc.call_id,
                     "type": "function",
                     "function": {
                         "name": tool_name,
-                        "arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
+                        "arguments": arguments,
                     },
                 }
             )
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 6e1623131..f452d9fd9 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -24,6 +24,12 @@ from openai.types.chat.chat_completion_chunk import (
 from openai.types.chat.chat_completion_chunk import (
     ChoiceDelta as OpenAIChoiceDelta,
 )
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
+)
 from openai.types.model import Model as OpenAIModel
 
 from llama_stack.apis.inference import (
@@ -206,8 +212,164 @@ async def test_tool_call_delta_empty_tool_call_buf():
             yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 1
-    assert chunks[0].event.stop_reason == StopReason.end_of_turn
+    assert len(chunks) == 2
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "complete"
+    assert chunks[1].event.stop_reason == StopReason.end_of_turn
+
+
+@pytest.mark.asyncio
+async def test_tool_call_delta_streaming_arguments_dict():
+    async def mock_stream():
+        mock_chunk_1 = OpenAIChatCompletionChunk(
+            id="chunk-1",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="tc_1",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments="",
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_2 = OpenAIChatCompletionChunk(
+            id="chunk-2",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="tc_1",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments='{"number": 28, "power": 3}',
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_3 = OpenAIChatCompletionChunk(
+            id="chunk-3",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+            ],
+        )
+        for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+            yield chunk
+
+    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+    assert len(chunks) == 3
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "progress"
+    assert chunks[1].event.delta.type == "tool_call"
+    assert chunks[1].event.delta.parse_status.value == "succeeded"
+    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[2].event.event_type.value == "complete"
+
+
+@pytest.mark.asyncio
+async def test_multiple_tool_calls():
+    async def mock_stream():
+        mock_chunk_1 = OpenAIChatCompletionChunk(
+            id="chunk-1",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments='{"number": 28, "power": 3}',
+                                ),
+                            ),
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_2 = OpenAIChatCompletionChunk(
+            id="chunk-2",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="",
+                                index=2,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="multiple",
+                                    arguments='{"first_number": 4, "second_number": 7}',
+                                ),
+                            ),
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_3 = OpenAIChatCompletionChunk(
+            id="chunk-3",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+            ],
+        )
+        for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+            yield chunk
+
+    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+    assert len(chunks) == 4
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "progress"
+    assert chunks[1].event.delta.type == "tool_call"
+    assert chunks[1].event.delta.parse_status.value == "succeeded"
+    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[2].event.event_type.value == "progress"
+    assert chunks[2].event.delta.type == "tool_call"
+    assert chunks[2].event.delta.parse_status.value == "succeeded"
+    assert chunks[2].event.delta.tool_call.arguments_json == '{"first_number": 4, "second_number": 7}'
+    assert chunks[3].event.event_type.value == "complete"
 
 
 @pytest.mark.asyncio
@@ -231,7 +393,8 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
             yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 0
+    assert len(chunks) == 1
+    assert chunks[0].event.event_type.value == "start"
 
 
 def test_chat_completion_doesnt_block_event_loop(caplog):