From 7784307a5f4a008417cd364e7e58167927bbc328 Mon Sep 17 00:00:00 2001 From: ilya-kolchinsky Date: Thu, 8 May 2025 10:42:26 +0200 Subject: [PATCH] Fixed an "out of token budget" tool execution bug in the remote vLLM provider. --- .../providers/remote/inference/vllm/vllm.py | 125 ++++++++++++------ .../providers/inference/test_remote_vllm.py | 55 ++++++++ 2 files changed, 141 insertions(+), 39 deletions(-) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 8bc733fd3..96a4573fc 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -158,56 +158,92 @@ def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason: }.get(finish_reason, StopReason.end_of_turn) +def _process_vllm_chat_completion_end_of_stream( + finish_reason: str | None, + last_chunk_content: str | None, + current_event_type: ChatCompletionResponseEventType, + tool_call_buf: UnparseableToolCall, +) -> list[OpenAIChatCompletionChunk]: + chunks = [] + + args_str = tool_call_buf.arguments + args = None + try: + args = {} if not args_str else json.loads(args_str) + except Exception as e: + log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}") + + if finish_reason is not None: + actual_finish_reason = _convert_to_vllm_finish_reason(finish_reason) + else: + actual_finish_reason = StopReason.end_of_message + + if args: + chunks.append( + ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=current_event_type, + delta=ToolCallDelta( + tool_call=ToolCall( + call_id=tool_call_buf.call_id, + tool_name=tool_call_buf.tool_name, + arguments=args, + arguments_json=args_str, + ), + parse_status=ToolCallParseStatus.succeeded, + ), + ) + ) + ) + elif args_str: + chunks.append( + ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + tool_call=str(tool_call_buf), + parse_status=ToolCallParseStatus.failed, + ), + ) + ) + ) + + chunks.append( + ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta=TextDelta(text=last_chunk_content or ""), + logprobs=None, + stop_reason=actual_finish_reason, + ) + ) + ) + + return chunks + + async def _process_vllm_chat_completion_stream_response( stream: AsyncGenerator[OpenAIChatCompletionChunk, None], ) -> AsyncGenerator: event_type = ChatCompletionResponseEventType.start tool_call_buf = UnparseableToolCall() + end_of_stream_processed = False + async for chunk in stream: if not chunk.choices: log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.") - continue + return choice = chunk.choices[0] if choice.finish_reason: - args_str = tool_call_buf.arguments - args = None - try: - args = {} if not args_str else json.loads(args_str) - except Exception as e: - log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}") - if args: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=ToolCall( - call_id=tool_call_buf.call_id, - tool_name=tool_call_buf.tool_name, - arguments=args, - arguments_json=args_str, - ), - parse_status=ToolCallParseStatus.succeeded, - ), - ) - ) - elif args_str: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - 
event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=str(tool_call_buf), - parse_status=ToolCallParseStatus.failed, - ), - ) - ) - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=choice.delta.content or ""), - logprobs=None, - stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason), - ) + chunks = _process_vllm_chat_completion_end_of_stream( + finish_reason=choice.finish_reason, + last_chunk_content=choice.delta.content, + current_event_type=event_type, + tool_call_buf=tool_call_buf, ) + for c in chunks: + yield c + end_of_stream_processed = True elif choice.delta.tool_calls: tool_call = convert_tool_call(choice.delta.tool_calls[0]) tool_call_buf.tool_name += str(tool_call.tool_name) @@ -224,6 +260,17 @@ async def _process_vllm_chat_completion_stream_response( ) event_type = ChatCompletionResponseEventType.progress + if end_of_stream_processed: + return + + # the stream ended without a chunk containing finish_reason - we have to generate the + # respective completion chunks manually + chunks = _process_vllm_chat_completion_end_of_stream( + finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_buf=tool_call_buf + ) + for c in chunks: + yield c + class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): def __init__(self, config: VLLMInferenceAdapterConfig) -> None: diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index a2e3b64c2..ecf2f1b6e 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -28,6 +28,7 @@ from openai.types.model import Model as OpenAIModel from llama_stack.apis.inference import ( ChatCompletionRequest, + ChatCompletionResponseEventType, CompletionMessage, SystemMessage, ToolChoice, @@ -294,3 +295,57 @@ async def test_get_params_empty_tools(vllm_inference_adapter): ) params = await vllm_inference_adapter._get_params(request) assert "tools" not in params + + +@pytest.mark.asyncio +async def test_process_vllm_chat_completion_stream_response_no_finish_reason(): + """ + Tests the edge case where the model requests a tool call and stays idle without explicitly providing the + finish reason. + We want to make sure that this case is recognized and handled correctly, i.e., as a valid end of message. 
+ """ + + mock_tool_name = "mock_tool" + mock_tool_arguments = {"arg1": 0, "arg2": 100} + mock_tool_arguments_str = '"{\\"arg1\\": 0, \\"arg2\\": 100}"' + + async def mock_stream(): + mock_chunks = [ + OpenAIChatCompletionChunk( + id="chunk-1", + created=1, + model="foo", + object="chat.completion.chunk", + choices=[ + { + "delta": { + "content": None, + "tool_calls": [ + { + "index": 0, + "id": "mock_id", + "type": "function", + "function": { + "name": mock_tool_name, + "arguments": mock_tool_arguments_str, + }, + } + ], + }, + "finish_reason": None, + "logprobs": None, + "index": 0, + } + ], + ), + ] + for chunk in mock_chunks: + print(f"Test chunk:\n{chunk}") + yield chunk + + chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())] + assert len(chunks) == 2 + assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete + assert chunks[-2].event.delta.type == "tool_call" + assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name + assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
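
For readers outside the codebase, below is a minimal, self-contained sketch of the pattern this patch follows: buffer streamed tool-call fragments and, if the stream ends without ever delivering a finish_reason chunk (for example because the model ran out of token budget mid tool call), synthesize the terminal chunks instead of silently dropping the buffered call. All names in the sketch (StreamEvent, ToolCallBuffer, terminal_events, process_stream) are made up for illustration only; the real implementation is the _process_vllm_chat_completion_end_of_stream and _process_vllm_chat_completion_stream_response code in the diff above, which uses llama_stack's ChatCompletionResponseStreamChunk types.

import asyncio
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass, field


@dataclass
class StreamEvent:
    kind: str  # "tool_call", "failed_tool_call", or "complete"
    payload: dict = field(default_factory=dict)


@dataclass
class ToolCallBuffer:
    call_id: str = ""
    tool_name: str = ""
    arguments: str = ""  # accumulated JSON text fragments


def terminal_events(buf: ToolCallBuffer, finish_reason: str | None) -> list[StreamEvent]:
    """Build the events that must close the stream, shared by both termination paths."""
    events: list[StreamEvent] = []
    if buf.arguments:
        try:
            args = json.loads(buf.arguments)
            events.append(StreamEvent("tool_call", {"name": buf.tool_name, "arguments": args}))
        except json.JSONDecodeError:
            events.append(StreamEvent("failed_tool_call", {"raw": buf.arguments}))
    # A missing finish_reason means the model stopped early (e.g. out of token
    # budget mid tool call), so report it as end_of_message, not end_of_turn.
    events.append(StreamEvent("complete", {"stop_reason": finish_reason or "end_of_message"}))
    return events


async def process_stream(chunks: AsyncIterator[dict]) -> AsyncIterator[StreamEvent]:
    buf = ToolCallBuffer()
    finished = False
    async for chunk in chunks:
        if chunk.get("finish_reason"):
            # Normal path: the server told us why generation stopped.
            for event in terminal_events(buf, chunk["finish_reason"]):
                yield event
            finished = True
        elif chunk.get("tool_call_delta"):
            delta = chunk["tool_call_delta"]
            buf.tool_name += delta.get("name", "")
            buf.arguments += delta.get("arguments", "")
    if not finished:
        # Fallback path: the stream ended without a finish_reason chunk, so
        # synthesize the terminal events instead of dropping the buffered call.
        for event in terminal_events(buf, None):
            yield event


async def _demo() -> None:
    async def truncated_stream() -> AsyncIterator[dict]:
        yield {"tool_call_delta": {"name": "mock_tool", "arguments": '{"arg1": 0'}}
        yield {"tool_call_delta": {"arguments": ', "arg2": 100}'}}
        # No chunk with finish_reason: the model went silent mid tool call.

    async for event in process_stream(truncated_stream()):
        print(event.kind, event.payload)


if __name__ == "__main__":
    asyncio.run(_demo())

Factoring the terminal-chunk construction into a single helper is what keeps the normal path (a chunk arrives carrying finish_reason) and the fallback path (the stream simply stops) from drifting apart, and it is the same structure the patch gives vllm.py.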