From 5966079770f889d3d3c9ebba1029c939b70d8ec0 Mon Sep 17 00:00:00 2001
From: Yuan Tang
Date: Thu, 20 Feb 2025 01:27:02 -0500
Subject: [PATCH] fix: More robust handling of the arguments in tool call
 response in remote::vllm (#1169)

# What does this PR do?

This fixes the following issue on the server side when the tool call response contains empty args. This happens when running `examples.agents.e2e_loop_with_client_tools` and `get_ticker_data` returns `[]`:

```
Traceback (most recent call last):
  File "/home/yutang/repos/llama-stack/llama_stack/distribution/server/server.py", line 208, in sse_generator
    async for item in event_gen:
  File "/home/yutang/repos/llama-stack/llama_stack/providers/inline/agents/meta_reference/agents.py", line 169, in _create_agent_turn_streaming
    async for event in agent.create_and_execute_turn(request):
  File "/home/yutang/repos/llama-stack/llama_stack/providers/inline/agents/meta_reference/agent_instance.py", line 189, in create_and_execute_turn
    async for chunk in self.run(
  File "/home/yutang/repos/llama-stack/llama_stack/providers/inline/agents/meta_reference/agent_instance.py", line 258, in run
    async for res in self._run(
  File "/home/yutang/repos/llama-stack/llama_stack/providers/inline/agents/meta_reference/agent_instance.py", line 499, in _run
    async for chunk in await self.inference_api.chat_completion(
  File "/home/yutang/repos/llama-stack/llama_stack/distribution/routers/routers.py", line 182, in <genexpr>
    return (chunk async for chunk in await provider.chat_completion(**params))
  File "/home/yutang/repos/llama-stack/llama_stack/providers/remote/inference/vllm/vllm.py", line 296, in _stream_chat_completion
    async for chunk in res:
  File "/home/yutang/repos/llama-stack/llama_stack/providers/remote/inference/vllm/vllm.py", line 162, in _process_vllm_chat_completion_stream_response
    arguments=json.loads(tool_call_buf.arguments),
  File "/home/yutang/.conda/envs/distribution-myenv/lib/python3.10/json/__init__.py", line 346, in loads
    return _default_decoder.decode(s)
  File "/home/yutang/.conda/envs/distribution-myenv/lib/python3.10/json/decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
  File "/home/yutang/.conda/envs/distribution-myenv/lib/python3.10/json/decoder.py", line 355, in raw_decode
    raise JSONDecodeError("Expecting value", s, err.value) from None
json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)
```

## Test Plan

All existing tests in `tests/client-sdk/inference/test_text_inference.py` passed.
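For reviewers, a minimal standalone sketch of the guarded parsing behavior this patch adds; it is not part of the change itself, and the `parse_tool_call_arguments` helper (and the example argument names) are hypothetical, only mirroring the logic added to `_process_vllm_chat_completion_stream_response`:

```
import json
import logging

log = logging.getLogger(__name__)


def parse_tool_call_arguments(args_str: str):
    """Hypothetical helper mirroring the guarded parsing in this patch.

    Returns {} for an empty buffer, the decoded value for valid JSON, and
    None when the buffer cannot be parsed, in which case the caller emits a
    ToolCallParseStatus.failed delta instead of crashing the SSE stream.
    """
    try:
        return {} if not args_str else json.loads(args_str)
    except Exception as e:
        log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
        return None


# json.loads("") raises JSONDecodeError, which previously aborted the stream.
assert parse_tool_call_arguments("") == {}
# Hypothetical example arguments for a tool call:
assert parse_tool_call_arguments('{"ticker_symbol": "GOOG"}') == {"ticker_symbol": "GOOG"}
assert parse_tool_call_arguments("not json") is None
```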
[//]: # (## Documentation)

---------

Signed-off-by: Yuan Tang
---
 .../providers/remote/inference/vllm/vllm.py   | 39 +++++++++++++------
 1 file changed, 28 insertions(+), 11 deletions(-)

diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index e073b98c6..220bf4bde 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -148,19 +148,36 @@ async def _process_vllm_chat_completion_stream_response(
     async for chunk in stream:
         choice = chunk.choices[0]
         if choice.finish_reason:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=event_type,
-                    delta=ToolCallDelta(
-                        tool_call=ToolCall(
-                            call_id=tool_call_buf.call_id,
-                            tool_name=tool_call_buf.tool_name,
-                            arguments=json.loads(tool_call_buf.arguments),
+            args_str = tool_call_buf.arguments
+            args = None
+            try:
+                args = {} if not args_str else json.loads(args_str)
+            except Exception as e:
+                log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
+            if args is not None:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=event_type,
+                        delta=ToolCallDelta(
+                            tool_call=ToolCall(
+                                call_id=tool_call_buf.call_id,
+                                tool_name=tool_call_buf.tool_name,
+                                arguments=args,
+                            ),
+                            parse_status=ToolCallParseStatus.succeeded,
                         ),
-                        parse_status=ToolCallParseStatus.succeeded,
-                    ),
+                    )
+                )
+            else:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            tool_call=str(tool_call_buf),
+                            parse_status=ToolCallParseStatus.failed,
+                        ),
+                    )
                 )
-            )
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.complete,