From 42a7588b049ba18c0432c15dee9f30e5ee3c40aa Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 22 Mar 2024 19:56:47 -0700
Subject: [PATCH] fix(anthropic.py): support async claude 3 tool calling +
 streaming

https://github.com/BerriAI/litellm/issues/2644
---
 litellm/llms/anthropic.py       | 28 +++++++++++++++--
 litellm/tests/test_streaming.py | 55 +++++++++++++++++++++++++++++++++
 2 files changed, 80 insertions(+), 3 deletions(-)

diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index 5c8c85997..b6200a1a4 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -301,7 +301,7 @@ def completion(
                     )
                     streaming_choice.delta = delta_obj
                     streaming_model_response.choices = [streaming_choice]
-                    completion_stream = model_response_iterator(
+                    completion_stream = ModelResponseIterator(
                         model_response=streaming_model_response
                     )
                     print_verbose(
@@ -330,8 +330,30 @@ def completion(
     return model_response
 
 
-def model_response_iterator(model_response):
-    yield model_response
+class ModelResponseIterator:
+    def __init__(self, model_response):
+        self.model_response = model_response
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self.model_response
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self.model_response
 
 
 def embedding():
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 983d50533..d854177aa 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -2089,3 +2089,58 @@ def test_completion_claude_3_function_call_with_streaming():
         # raise Exception("it worked!")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_acompletion_claude_3_function_call_with_streaming():
+    litellm.set_verbose = True
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
+    try:
+        # test without max tokens
+        response = await acompletion(
+            model="claude-3-opus-20240229",
+            messages=messages,
+            tools=tools,
+            tool_choice="auto",
+            stream=True,
+        )
+        idx = 0
+        print(f"response: {response}")
+        async for chunk in response:
+            # print(f"chunk: {chunk}")
+            if idx == 0:
+                assert (
+                    chunk.choices[0].delta.tool_calls[0].function.arguments is not None
+                )
+                assert isinstance(
+                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
+                )
+                validate_first_streaming_function_calling_chunk(chunk=chunk)
+            elif idx == 1:
+                validate_second_streaming_function_calling_chunk(chunk=chunk)
+            elif chunk.choices[0].finish_reason is not None:  # last chunk
+                validate_final_streaming_function_calling_chunk(chunk=chunk)
+            idx += 1
+        # raise Exception("it worked!")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
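
Note (context for the change above, not part of the patch): the removed
model_response_iterator was a plain generator, which implements only the
synchronous iterator protocol. When the async streaming path
(acompletion(..., stream=True), per the linked issue #2644) iterates the
faked first tool-call chunk with "async for", a plain generator raises
TypeError. The replacement class implements both __iter__/__next__ and
__aiter__/__anext__, so the same one-shot response can feed either path.
Below is a minimal standalone sketch of that behavior; the class body is
copied from the patch, while old_iterator, main, and the "chunk" payload
are illustrative names, not litellm code.

    import asyncio


    def old_iterator(model_response):
        # Pre-patch helper: a plain generator, usable only from sync code.
        yield model_response


    class ModelResponseIterator:
        # Post-patch replacement: yields its single response to sync or
        # async consumers, then signals exhaustion via is_done.
        def __init__(self, model_response):
            self.model_response = model_response
            self.is_done = False

        def __iter__(self):
            return self

        def __next__(self):
            if self.is_done:
                raise StopIteration
            self.is_done = True
            return self.model_response

        def __aiter__(self):
            return self

        async def __anext__(self):
            if self.is_done:
                raise StopAsyncIteration
            self.is_done = True
            return self.model_response


    async def main():
        try:
            # The pre-patch failure mode: a generator has no __aiter__.
            async for _ in old_iterator("chunk"):
                pass
        except TypeError as err:
            print(err)  # 'async for' requires an object with __aiter__ method ...

        # Post-patch: the class supports async iteration.
        async for chunk in ModelResponseIterator("chunk"):
            print("async:", chunk)


    # Sync consumption is unchanged by the patch.
    for chunk in ModelResponseIterator("chunk"):
        print("sync:", chunk)

    asyncio.run(main())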