From 86ed0aaba8a63f4613240ecc7d841fa1c5e45f45 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 12 Mar 2024 09:52:11 -0700
Subject: [PATCH] fix(anthropic.py): support streaming with function calling

---
 litellm/llms/anthropic.py       | 47 +++++++++++++++++++++++++--
 litellm/main.py                 |  6 +++-
 litellm/tests/test_streaming.py | 57 +++++++++++++++++++++++++++++++--
 litellm/utils.py                |  6 ++--
 4 files changed, 109 insertions(+), 7 deletions(-)

diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index becbcc328..b9d1903df 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -4,7 +4,7 @@ from enum import Enum
 import requests, copy
 import time, uuid
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage, map_finish_reason
+from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
 import litellm
 from .prompt_templates.factory import (
     prompt_factory,
@@ -118,6 +118,7 @@ def completion(
     headers = validate_environment(api_key, headers)
     _is_function_call = False
     messages = copy.deepcopy(messages)
+    optional_params = copy.deepcopy(optional_params)
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
@@ -161,6 +162,8 @@ def completion(
         )  # add the anthropic tool calling prompt to the system prompt
         optional_params.pop("tools")
 
+    stream = optional_params.pop("stream", None)
+
     data = {
         "model": model,
         "messages": messages,
@@ -179,7 +182,10 @@ def completion(
     )
 
     ## COMPLETION CALL
-    if "stream" in optional_params and optional_params["stream"] == True:
+    if (
+        stream is not None and stream == True and _is_function_call == False
+    ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
+        data["stream"] = stream
         response = requests.post(
             api_base,
             headers=headers,
@@ -255,6 +261,39 @@ def completion(
                 completion_response["stop_reason"]
             )
 
+        if _is_function_call == True and stream is not None and stream == True:
+            # return an iterator
+            streaming_model_response = ModelResponse(stream=True)
+            streaming_model_response.choices[0].finish_reason = model_response.choices[
+                0
+            ].finish_reason
+            streaming_model_response.choices[0].index = model_response.choices[0].index
+            _tool_calls = []
+            if isinstance(model_response.choices[0], litellm.Choices):
+                if getattr(
+                    model_response.choices[0].message, "tool_calls", None
+                ) is not None and isinstance(
+                    model_response.choices[0].message.tool_calls, list
+                ):
+                    for tool_call in model_response.choices[0].message.tool_calls:
+                        _tool_call = {**tool_call.dict(), "index": 0}
+                        _tool_calls.append(_tool_call)
+            delta_obj = litellm.utils.Delta(
+                content=getattr(model_response.choices[0].message, "content", None),
+                role=model_response.choices[0].message.role,
+                tool_calls=_tool_calls,
+            )
+            streaming_model_response.choices[0].delta = delta_obj
+            completion_stream = model_response_iterator(
+                model_response=streaming_model_response
+            )
+            return CustomStreamWrapper(
+                completion_stream=completion_stream,
+                model=model,
+                custom_llm_provider="cached_response",
+                logging_obj=logging_obj,
+            )
+
         ## CALCULATING USAGE
         prompt_tokens = completion_response["usage"]["input_tokens"]
         completion_tokens = completion_response["usage"]["output_tokens"]
@@ -271,6 +310,10 @@ def completion(
         return model_response
 
 
+def model_response_iterator(model_response):
+    yield model_response
+
+
 def embedding():
     # logic for parsing in - calling - parsing out model embedding calls
     pass
diff --git a/litellm/main.py b/litellm/main.py
index 114b46948..15e175493 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1073,7 +1073,11 @@ def completion(
                 logging_obj=logging,
                 headers=headers,
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and not isinstance(response, CustomStreamWrapper)
+            ):
                 # don't try to access stream object,
                 response = CustomStreamWrapper(
                     response,
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 6f8d73c31..4f647486b 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1749,7 +1749,7 @@ class Chunk(BaseModel):
     object: str
     created: int
     model: str
-    system_fingerprint: str
+    # system_fingerprint: str
     choices: List[Choices]
 
 
@@ -1869,7 +1869,7 @@ class Chunk3(BaseModel):
     object: str
     created: int
     model: str
-    system_fingerprint: str
+    # system_fingerprint: str
     choices: List[Choices3]
 
 
@@ -2032,3 +2032,56 @@ async def test_azure_astreaming_and_function_calling():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
         raise e
+
+
+def test_completion_claude_3_function_call_with_streaming():
+    # litellm.set_verbose = True
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
+    try:
+        # test without max tokens
+        response = completion(
+            model="claude-3-opus-20240229",
+            messages=messages,
+            tools=tools,
+            tool_choice="auto",
+            stream=True,
+        )
+        idx = 0
+        for chunk in response:
+            # print(f"chunk: {chunk}")
+            if idx == 0:
+                assert (
+                    chunk.choices[0].delta.tool_calls[0].function.arguments is not None
+                )
+                assert isinstance(
+                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
+                )
+                validate_first_streaming_function_calling_chunk(chunk=chunk)
+            elif idx == 1:
+                validate_second_streaming_function_calling_chunk(chunk=chunk)
+            elif chunk.choices[0].finish_reason is not None:  # last chunk
+                validate_final_streaming_function_calling_chunk(chunk=chunk)
+            idx += 1
+        # raise Exception("it worked!")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
diff --git a/litellm/utils.py b/litellm/utils.py
index 3b6169770..0369e8b57 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -9364,8 +9364,10 @@ class CustomStreamWrapper:
     def __next__(self):
         try:
             while True:
-                if isinstance(self.completion_stream, str) or isinstance(
-                    self.completion_stream, bytes
+                if (
+                    isinstance(self.completion_stream, str)
+                    or isinstance(self.completion_stream, bytes)
+                    or isinstance(self.completion_stream, ModelResponse)
                 ):
                     chunk = self.completion_stream
                 else:
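
Caller-side sketch (illustrative, not part of the patch): the snippet below mirrors the new
test_completion_claude_3_function_call_with_streaming test and shows how the "faked" stream
added by this change is consumed. With tools and stream=True on an Anthropic model, the
provider now returns a CustomStreamWrapper whose first chunk carries the complete, parsed
tool call in delta.tool_calls, followed by a finish-reason chunk. The model name and prompt
are taken from the test; the loop body is a minimal assumption of typical litellm usage.

import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="claude-3-opus-20240229",
    messages=[{"role": "user", "content": "What's the weather like in Boston today?"}],
    tools=tools,
    tool_choice="auto",
    stream=True,
)

for chunk in response:
    delta = chunk.choices[0].delta
    if delta.tool_calls:
        # the tool-call arguments arrive as one complete JSON string, not token by
        # token, because the Anthropic call itself was not actually streamed
        print(delta.tool_calls[0].function.arguments)
    if chunk.choices[0].finish_reason is not None:
        break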