diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 070d94df8..d00218dd5 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -162,7 +162,7 @@ def _process_vllm_chat_completion_end_of_stream(
     finish_reason: str | None,
     last_chunk_content: str | None,
     current_event_type: ChatCompletionResponseEventType,
-    tool_call_buf: UnparseableToolCall,
+    tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
 ) -> list[OpenAIChatCompletionChunk]:
     chunks = []
 
@@ -171,9 +171,8 @@ def _process_vllm_chat_completion_end_of_stream(
     else:
         stop_reason = StopReason.end_of_message
 
-    if tool_call_buf.tool_name:
-        # at least one tool call request is received
-
+    tool_call_bufs = tool_call_bufs or {}
+    for _index, tool_call_buf in sorted(tool_call_bufs.items()):
         args_str = tool_call_buf.arguments or "{}"
         try:
             args = json.loads(args_str)
@@ -225,8 +224,14 @@
 async def _process_vllm_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
 ) -> AsyncGenerator:
-    event_type = ChatCompletionResponseEventType.start
-    tool_call_buf = UnparseableToolCall()
+    yield ChatCompletionResponseStreamChunk(
+        event=ChatCompletionResponseEvent(
+            event_type=ChatCompletionResponseEventType.start,
+            delta=TextDelta(text=""),
+        )
+    )
+    event_type = ChatCompletionResponseEventType.progress
+    tool_call_bufs: dict[str, UnparseableToolCall] = {}
     end_of_stream_processed = False
 
     async for chunk in stream:
@@ -235,17 +240,22 @@ async def _process_vllm_chat_completion_stream_response(
             return
         choice = chunk.choices[0]
         if choice.delta.tool_calls:
-            tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += str(tool_call.tool_name)
-            tool_call_buf.call_id += tool_call.call_id
-            # TODO: remove str() when dict type for 'arguments' is no longer allowed
-            tool_call_buf.arguments += str(tool_call.arguments)
+            for delta_tool_call in choice.delta.tool_calls:
+                tool_call = convert_tool_call(delta_tool_call)
+                if delta_tool_call.index not in tool_call_bufs:
+                    tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
+                tool_call_buf = tool_call_bufs[delta_tool_call.index]
+                tool_call_buf.tool_name += str(tool_call.tool_name)
+                tool_call_buf.call_id += tool_call.call_id
+                tool_call_buf.arguments += (
+                    tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
+                )
         if choice.finish_reason:
             chunks = _process_vllm_chat_completion_end_of_stream(
                 finish_reason=choice.finish_reason,
                 last_chunk_content=choice.delta.content,
                 current_event_type=event_type,
-                tool_call_buf=tool_call_buf,
+                tool_call_bufs=tool_call_bufs,
             )
             for c in chunks:
                 yield c
@@ -266,7 +276,7 @@ async def _process_vllm_chat_completion_stream_response(
         # the stream ended without a chunk containing finish_reason - we have to generate the
         # respective completion chunks manually
         chunks = _process_vllm_chat_completion_end_of_stream(
-            finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_buf=tool_call_buf
+            finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
         )
         for c in chunks:
             yield c
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index e2314d44f..cc0000528 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -531,13 +531,19 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
             tool_name = tc.tool_name
             if isinstance(tool_name, BuiltinTool):
                 tool_name = tool_name.value
+
+            # arguments_json can be None, so attempt it first and fall back to arguments
+            if hasattr(tc, "arguments_json") and tc.arguments_json:
+                arguments = tc.arguments_json
+            else:
+                arguments = json.dumps(tc.arguments)
             result["tool_calls"].append(
                 {
                     "id": tc.call_id,
                     "type": "function",
                     "function": {
                         "name": tool_name,
-                        "arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
+                        "arguments": arguments,
                     },
                 }
             )
diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py
index 63fd74f53..66c9ab829 100644
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -266,6 +266,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config):
     assert found_tool_execution
 
 
+@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
 def test_builtin_tool_code_execution(llama_stack_client, agent_config):
     agent_config = {
         **agent_config,
@@ -346,7 +347,7 @@ def test_custom_tool(llama_stack_client, agent_config):
         messages=[
             {
                 "role": "user",
-                "content": "What is the boiling point of polyjuice?",
+                "content": "What is the boiling point of the liquid polyjuice in celsius?",
             },
         ],
         session_id=session_id,
@@ -420,7 +421,7 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
         messages=[
             {
                 "role": "user",
-                "content": "What is the boiling point of polyjuice?",
+                "content": "What is the boiling point of the liquid polyjuice in celsius?",
             },
         ],
         session_id=session_id,
@@ -674,8 +675,8 @@ def test_create_turn_response(llama_stack_client, agent_config, client_tools):
 
 
 def test_multi_tool_calls(llama_stack_client, agent_config):
-    if "gpt" not in agent_config["model"]:
-        pytest.xfail("Only tested on GPT models")
+    if "gpt" not in agent_config["model"] and "llama-4" not in agent_config["model"].lower():
+        pytest.xfail("Only tested on GPT and Llama 4 models")
 
     agent_config = {
         **agent_config,
@@ -689,23 +690,34 @@ def test_multi_tool_calls(llama_stack_client, agent_config):
         messages=[
             {
                 "role": "user",
-                "content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?",
+                "content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?.\nUse the tool responses to answer the question.",
             },
         ],
         session_id=session_id,
         stream=False,
     )
 
     steps = response.steps
-    assert len(steps) == 7
-    assert steps[0].step_type == "shield_call"
-    assert steps[1].step_type == "inference"
-    assert steps[2].step_type == "shield_call"
-    assert steps[3].step_type == "tool_execution"
-    assert steps[4].step_type == "shield_call"
-    assert steps[5].step_type == "inference"
-    assert steps[6].step_type == "shield_call"
-    tool_execution_step = steps[3]
+    has_input_shield = agent_config.get("input_shields")
+    has_output_shield = agent_config.get("output_shields")
+    assert len(steps) == 3 + (2 if has_input_shield else 0) + (2 if has_output_shield else 0)
+    if has_input_shield:
+        assert steps[0].step_type == "shield_call"
+        steps.pop(0)
+    assert steps[0].step_type == "inference"
+    if has_output_shield:
+        assert steps[1].step_type == "shield_call"
+        steps.pop(1)
+    assert steps[1].step_type == "tool_execution"
+    tool_execution_step = steps[1]
+    if has_input_shield:
+        assert steps[2].step_type == "shield_call"
+        steps.pop(2)
+    assert steps[2].step_type == "inference"
+    if has_output_shield:
+        assert steps[3].step_type == "shield_call"
+        steps.pop(3)
+
     assert len(tool_execution_step.tool_calls) == 2
     assert tool_execution_step.tool_calls[0].tool_name.startswith("get_boiling_point")
     assert tool_execution_step.tool_calls[1].tool_name.startswith("get_boiling_point")
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 6e1623131..f9eaee7d6 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -24,6 +24,12 @@ from openai.types.chat.chat_completion_chunk import (
 from openai.types.chat.chat_completion_chunk import (
     ChoiceDelta as OpenAIChoiceDelta,
 )
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
+)
 from openai.types.model import Model as OpenAIModel
 
 from llama_stack.apis.inference import (
@@ -206,8 +212,164 @@ async def test_tool_call_delta_empty_tool_call_buf():
             yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 1
-    assert chunks[0].event.stop_reason == StopReason.end_of_turn
+    assert len(chunks) == 2
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "complete"
+    assert chunks[1].event.stop_reason == StopReason.end_of_turn
+
+
+@pytest.mark.asyncio
+async def test_tool_call_delta_streaming_arguments_dict():
+    async def mock_stream():
+        mock_chunk_1 = OpenAIChatCompletionChunk(
+            id="chunk-1",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="tc_1",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments="",
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_2 = OpenAIChatCompletionChunk(
+            id="chunk-2",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="tc_1",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments='{"number": 28, "power": 3}',
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_3 = OpenAIChatCompletionChunk(
+            id="chunk-3",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+            ],
+        )
+        for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+            yield chunk
+
+    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+    assert len(chunks) == 3
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "progress"
+    assert chunks[1].event.delta.type == "tool_call"
+    assert chunks[1].event.delta.parse_status.value == "succeeded"
+    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[2].event.event_type.value == "complete"
+
+
+@pytest.mark.asyncio
+async def test_multiple_tool_calls():
+    async def mock_stream():
+        mock_chunk_1 = OpenAIChatCompletionChunk(
+            id="chunk-1",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments='{"number": 28, "power": 3}',
+                                ),
+                            ),
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_2 = OpenAIChatCompletionChunk(
+            id="chunk-2",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="",
+                                index=2,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="multiple",
+                                    arguments='{"first_number": 4, "second_number": 7}',
+                                ),
+                            ),
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_3 = OpenAIChatCompletionChunk(
+            id="chunk-3",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+            ],
+        )
+        for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+            yield chunk
+
+    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+    assert len(chunks) == 4
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "progress"
+    assert chunks[1].event.delta.type == "tool_call"
+    assert chunks[1].event.delta.parse_status.value == "succeeded"
+    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[2].event.event_type.value == "progress"
+    assert chunks[2].event.delta.type == "tool_call"
+    assert chunks[2].event.delta.parse_status.value == "succeeded"
+    assert chunks[2].event.delta.tool_call.arguments_json == '{"first_number": 4, "second_number": 7}'
+    assert chunks[3].event.event_type.value == "complete"
 
 
 @pytest.mark.asyncio
@@ -231,7 +393,8 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
             yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 0
+    assert len(chunks) == 1
+    assert chunks[0].event.event_type.value == "start"
 
 
 def test_chat_completion_doesnt_block_event_loop(caplog):
@@ -369,7 +532,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
             yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 2
+    assert len(chunks) == 3
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
@@ -422,7 +585,7 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
            yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 2
+    assert len(chunks) == 3
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
@@ -471,7 +634,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
            yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 2
+    assert len(chunks) == 3
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name