From 5052c3cbf33bacf1625a988201ae1b7b7e601d9a Mon Sep 17 00:00:00 2001
From: Ilya Kolchinsky <58424190+ilya-kolchinsky@users.noreply.github.com>
Date: Wed, 14 May 2025 22:11:02 +0200
Subject: [PATCH] fix: Fixed an "out of token budget" error when attempting a tool call via remote vLLM provider (#2114)

# What does this PR do?
Closes #2113.
Closes #1783.

Fixes a bug in handling the end of a tool execution request stream where no `finish_reason` is provided by the model.

## Test Plan
1. Ran existing unit tests
2. Added a dedicated test verifying correct behavior in this edge case
3. Ran the code snapshot from #2113

[//]: # (## Documentation)
---
 .../providers/remote/inference/vllm/vllm.py   | 117 ++++++++++++------
 .../providers/inference/test_remote_vllm.py   | 102 +++++++++++++++
 2 files changed, 184 insertions(+), 35 deletions(-)

diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 3fb28ee08..070d94df8 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -158,33 +158,29 @@ def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
     }.get(finish_reason, StopReason.end_of_turn)
 
 
-async def _process_vllm_chat_completion_stream_response(
-    stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
-) -> AsyncGenerator:
-    event_type = ChatCompletionResponseEventType.start
-    tool_call_buf = UnparseableToolCall()
-    async for chunk in stream:
-        if not chunk.choices:
-            log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
-            continue
-        choice = chunk.choices[0]
-        if choice.delta.tool_calls:
-            tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += str(tool_call.tool_name)
-            tool_call_buf.call_id += tool_call.call_id
-            # TODO: remove str() when dict type for 'arguments' is no longer allowed
-            tool_call_buf.arguments += str(tool_call.arguments)
-        if choice.finish_reason:
-            args_str = tool_call_buf.arguments
-            args = None
-            try:
-                args = {} if not args_str else json.loads(args_str)
-            except Exception as e:
-                log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
-            if args:
-                yield ChatCompletionResponseStreamChunk(
+def _process_vllm_chat_completion_end_of_stream(
+    finish_reason: str | None,
+    last_chunk_content: str | None,
+    current_event_type: ChatCompletionResponseEventType,
+    tool_call_buf: UnparseableToolCall,
+) -> list[OpenAIChatCompletionChunk]:
+    chunks = []
+
+    if finish_reason is not None:
+        stop_reason = _convert_to_vllm_finish_reason(finish_reason)
+    else:
+        stop_reason = StopReason.end_of_message
+
+    if tool_call_buf.tool_name:
+        # at least one tool call request is received
+
+        args_str = tool_call_buf.arguments or "{}"
+        try:
+            args = json.loads(args_str)
+            chunks.append(
+                ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
-                        event_type=event_type,
+                        event_type=current_event_type,
                         delta=ToolCallDelta(
                             tool_call=ToolCall(
                                 call_id=tool_call_buf.call_id,
@@ -196,8 +192,12 @@ async def _process_vllm_chat_completion_stream_response(
                         ),
                     )
                 )
-            elif args_str:
-                yield ChatCompletionResponseStreamChunk(
+            )
+        except Exception as e:
+            log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
+
+            chunks.append(
+                ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
                         delta=ToolCallDelta(
@@ -206,14 +206,50 @@ async def _process_vllm_chat_completion_stream_response(
                         ),
                     )
                 )
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta=TextDelta(text=choice.delta.content or ""),
-                    logprobs=None,
-                    stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
-                )
-            )
+            )
+
+    chunks.append(
+        ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.complete,
+                delta=TextDelta(text=last_chunk_content or ""),
+                logprobs=None,
+                stop_reason=stop_reason,
+            )
+        )
+    )
+
+    return chunks
+
+
+async def _process_vllm_chat_completion_stream_response(
+    stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
+) -> AsyncGenerator:
+    event_type = ChatCompletionResponseEventType.start
+    tool_call_buf = UnparseableToolCall()
+    end_of_stream_processed = False
+
+    async for chunk in stream:
+        if not chunk.choices:
+            log.warning("vLLM failed to generate any completions - check the vLLM server logs for an error.")
+            return
+        choice = chunk.choices[0]
+        if choice.delta.tool_calls:
+            tool_call = convert_tool_call(choice.delta.tool_calls[0])
+            tool_call_buf.tool_name += str(tool_call.tool_name)
+            tool_call_buf.call_id += tool_call.call_id
+            # TODO: remove str() when dict type for 'arguments' is no longer allowed
+            tool_call_buf.arguments += str(tool_call.arguments)
+        if choice.finish_reason:
+            chunks = _process_vllm_chat_completion_end_of_stream(
+                finish_reason=choice.finish_reason,
+                last_chunk_content=choice.delta.content,
+                current_event_type=event_type,
+                tool_call_buf=tool_call_buf,
+            )
+            for c in chunks:
+                yield c
+            end_of_stream_processed = True
         elif not choice.delta.tool_calls:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
@@ -224,6 +260,17 @@ async def _process_vllm_chat_completion_stream_response(
             )
             event_type = ChatCompletionResponseEventType.progress
 
+    if end_of_stream_processed:
+        return
+
+    # the stream ended without a chunk containing finish_reason - we have to generate the
+    # respective completion chunks manually
+    chunks = _process_vllm_chat_completion_end_of_stream(
+        finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_buf=tool_call_buf
+    )
+    for c in chunks:
+        yield c
+
 
 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index a8c4e07a0..6e1623131 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -374,3 +374,105 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
     assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
+
+
+@pytest.mark.asyncio
+async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
+    """
+    Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
+    finish reason.
+    We want to make sure that this case is recognized and handled correctly, i.e., as a valid end of message.
+ """ + + mock_tool_name = "mock_tool" + mock_tool_arguments = {"arg1": 0, "arg2": 100} + mock_tool_arguments_str = '"{\\"arg1\\": 0, \\"arg2\\": 100}"' + + async def mock_stream(): + mock_chunks = [ + OpenAIChatCompletionChunk( + id="chunk-1", + created=1, + model="foo", + object="chat.completion.chunk", + choices=[ + { + "delta": { + "content": None, + "tool_calls": [ + { + "index": 0, + "id": "mock_id", + "type": "function", + "function": { + "name": mock_tool_name, + "arguments": mock_tool_arguments_str, + }, + } + ], + }, + "finish_reason": None, + "logprobs": None, + "index": 0, + } + ], + ), + ] + for chunk in mock_chunks: + yield chunk + + chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())] + assert len(chunks) == 2 + assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete + assert chunks[-2].event.delta.type == "tool_call" + assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name + assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments + + +@pytest.mark.asyncio +async def test_process_vllm_chat_completion_stream_response_tool_without_args(): + """ + Tests the edge case where no arguments are provided for the tool call. + Tool calls with no arguments should be treated as regular tool calls, which was not the case until now. + """ + mock_tool_name = "mock_tool" + + async def mock_stream(): + mock_chunks = [ + OpenAIChatCompletionChunk( + id="chunk-1", + created=1, + model="foo", + object="chat.completion.chunk", + choices=[ + { + "delta": { + "content": None, + "tool_calls": [ + { + "index": 0, + "id": "mock_id", + "type": "function", + "function": { + "name": mock_tool_name, + "arguments": "", + }, + } + ], + }, + "finish_reason": None, + "logprobs": None, + "index": 0, + } + ], + ), + ] + for chunk in mock_chunks: + yield chunk + + chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())] + assert len(chunks) == 2 + assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete + assert chunks[-2].event.delta.type == "tool_call" + assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name + assert chunks[-2].event.delta.tool_call.arguments == {}