From 3be85c717fbd42f96f63955367c605fc60f175d8 Mon Sep 17 00:00:00 2001
From: ilya-kolchinsky
Date: Fri, 9 May 2025 12:38:59 +0200
Subject: [PATCH] Added proper support for calling tools without parameters.

---
 .../providers/remote/inference/vllm/vllm.py   | 67 ++++++++++---------
 .../providers/inference/test_remote_vllm.py   | 49 ++++++++++++++
 2 files changed, 85 insertions(+), 31 deletions(-)

diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 96a4573fc..30fcf1674 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -166,47 +166,52 @@ def _process_vllm_chat_completion_end_of_stream(
 ) -> list[OpenAIChatCompletionChunk]:
     chunks = []
 
-    args_str = tool_call_buf.arguments
-    args = None
-    try:
-        args = {} if not args_str else json.loads(args_str)
-    except Exception as e:
-        log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
-
     if finish_reason is not None:
         actual_finish_reason = _convert_to_vllm_finish_reason(finish_reason)
     else:
         actual_finish_reason = StopReason.end_of_message
 
-    if args:
-        chunks.append(
-            ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=current_event_type,
-                    delta=ToolCallDelta(
-                        tool_call=ToolCall(
-                            call_id=tool_call_buf.call_id,
-                            tool_name=tool_call_buf.tool_name,
-                            arguments=args,
-                            arguments_json=args_str,
+    if tool_call_buf.tool_name:
+        # at least one tool call request was received
+
+        args_str = tool_call_buf.arguments or "{}"
+        args = {}
+        args_parsed_successfully = True
+        try:
+            args = json.loads(args_str)
+        except Exception as e:
+            args_parsed_successfully = False
+            log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
+
+        if args_parsed_successfully:
+            chunks.append(
+                ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=current_event_type,
+                        delta=ToolCallDelta(
+                            tool_call=ToolCall(
+                                call_id=tool_call_buf.call_id,
+                                tool_name=tool_call_buf.tool_name,
+                                arguments=args,
+                                arguments_json=args_str,
+                            ),
+                            parse_status=ToolCallParseStatus.succeeded,
                         ),
-                        parse_status=ToolCallParseStatus.succeeded,
-                    ),
+                    )
                 )
             )
-        )
-    elif args_str:
-        chunks.append(
-            ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(
-                        tool_call=str(tool_call_buf),
-                        parse_status=ToolCallParseStatus.failed,
-                    ),
+        else:
+            chunks.append(
+                ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            tool_call=str(tool_call_buf),
+                            parse_status=ToolCallParseStatus.failed,
+                        ),
+                    )
                 )
             )
-        )
 
     chunks.append(
         ChatCompletionResponseStreamChunk(
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 8eeccfd56..5c9bda74a 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -348,3 +348,52 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
     assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
+
+
+@pytest.mark.asyncio
+async def test_process_vllm_chat_completion_stream_response_tool_without_args():
+    """
+    Tests the edge case where no arguments are provided for the tool call.
+    Tool calls with no arguments should be treated as regular tool calls; previously they were silently dropped.
+    """
+    mock_tool_name = "mock_tool"
+
+    async def mock_stream():
+        mock_chunks = [
+            OpenAIChatCompletionChunk(
+                id="chunk-1",
+                created=1,
+                model="foo",
+                object="chat.completion.chunk",
+                choices=[
+                    {
+                        "delta": {
+                            "content": None,
+                            "tool_calls": [
+                                {
+                                    "index": 0,
+                                    "id": "mock_id",
+                                    "type": "function",
+                                    "function": {
+                                        "name": mock_tool_name,
+                                        "arguments": "",
+                                    },
+                                }
+                            ],
+                        },
+                        "finish_reason": None,
+                        "logprobs": None,
+                        "index": 0,
+                    }
+                ],
+            ),
+        ]
+        for chunk in mock_chunks:
+            yield chunk
+
+    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+    assert len(chunks) == 2
+    assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
+    assert chunks[-2].event.delta.type == "tool_call"
+    assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
+    assert chunks[-2].event.delta.tool_call.arguments == {}
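
Note (editor's sketch, not part of the patch): the crux of the fix is that an
empty arguments string is not valid JSON. In the removed code, a no-argument
tool call left `args` falsy and `args_str` empty, so neither the `if args:`
nor the `elif args_str:` branch fired and the tool call was silently dropped.
A minimal standalone sketch of the new parsing behavior follows; the helper
name `parse_tool_args` is illustrative and does not appear in the patch:

    import json

    def parse_tool_args(args_str: str) -> dict:
        # Substitute "{}" for an empty buffer so a tool call without
        # parameters parses to an empty dict instead of being dropped.
        return json.loads(args_str or "{}")

    assert parse_tool_args("") == {}  # the previously broken edge case
    assert parse_tool_args('{"city": "Paris"}') == {"city": "Paris"}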