From 076d2f349da3d7b227f4fae4e32c0e69d58f4e1d Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 27 Feb 2025 18:00:27 -0800 Subject: [PATCH] fix: litellm tool call parsing event type to in_progress (#1312) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? - Test with script: https://gist.github.com/yanxi0830/64699f3604766ac2319421b750c5bf9c - Agent with tool calls does not get correctly parsed with LiteLLM provider b/c we skip processing `ChatCompletionResponseEventType.complete`. - However, LiteLLM spits out event_type="complete" with ToolCallDelta https://github.com/meta-llama/llama-stack/blob/2f7683bc5fc33192fe34533d47d47328ff522fee/llama_stack/providers/inline/agents/meta_reference/agent_instance.py#L570-L577 - Llama Model ``` ChatCompletionResponseStreamChunk( │ event=Event( │ │ delta=ToolCallDelta( │ │ │ parse_status='succeeded', │ │ │ tool_call=ToolCall( │ │ │ │ arguments={'kind': 'pod', 'namespace': 'openshift-lightspeed'}, │ │ │ │ call_id='call_tIjWTUdsQXhQ2XHC5ke4EQY5', │ │ │ │ tool_name='get_object_namespace_list' │ │ │ ), │ │ │ type='tool_call' │ │ ), │ │ event_type='progress', │ │ logprobs=None, │ │ stop_reason='end_of_turn' │ ), │ metrics=None ) ChatCompletionResponseStreamChunk( │ event=Event( │ │ delta=TextDelta(text='', type='text'), │ │ event_type='complete', │ │ logprobs=None, │ │ stop_reason='end_of_turn' │ ), │ metrics=None ) ``` - LiteLLM model ``` ChatCompletionResponseStreamChunk( │ event=Event( │ │ delta=ToolCallDelta( │ │ │ parse_status='succeeded', │ │ │ tool_call=ToolCall( │ │ │ │ arguments={'kind': 'pod', 'namespace': 'openshift-lightspeed'}, │ │ │ │ call_id='call_tIjWTUdsQXhQ2XHC5ke4EQY5', │ │ │ │ tool_name='get_object_namespace_list' │ │ │ ), │ │ │ type='tool_call' │ │ ), │ │ event_type='complete', │ │ logprobs=None, │ │ stop_reason='end_of_turn' │ ), │ metrics=None ) ChatCompletionResponseStreamChunk( │ event=Event( │ │ delta=TextDelta(text='', type='text'), │ │ event_type='complete', │ │ logprobs=None, │ │ stop_reason='end_of_turn' │ ), │ metrics=None ) ``` [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan - Test with script: https://gist.github.com/yanxi0830/64699f3604766ac2319421b750c5bf9c [//]: # (## Documentation) --- .../providers/utils/inference/openai_compat.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 1309e72a6..eaf5ad2e1 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -27,7 +27,9 @@ from openai.types.chat import ( from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessage, ) -from openai.types.chat import ChatCompletionMessageToolCall +from openai.types.chat import ( + ChatCompletionMessageToolCall, +) from openai.types.chat import ( ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, ) @@ -199,7 +201,9 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Unio return None -def process_completion_response(response: OpenAICompatCompletionResponse) -> CompletionResponse: +def process_completion_response( + response: OpenAICompatCompletionResponse, +) -> CompletionResponse: choice = response.choices[0] # drop suffix if present and return stop reason as end of turn if choice.text.endswith("<|eot_id|>"): @@ -492,7 +496,9 @@ class UnparseableToolCall(BaseModel): arguments: str = "" -async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIChatCompletionMessage: +async def convert_message_to_openai_dict_new( + message: Message | Dict, +) -> OpenAIChatCompletionMessage: """ Convert a Message to an OpenAI API-compatible dictionary. """ @@ -942,7 +948,7 @@ async def convert_openai_chat_completion_stream( ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, + event_type=ChatCompletionResponseEventType.progress, delta=ToolCallDelta( tool_call=tool_call, parse_status=ToolCallParseStatus.succeeded,