diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index becbcc3282..e078a1ddf2 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -4,7 +4,7 @@ from enum import Enum
 import requests, copy
 import time, uuid
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage, map_finish_reason
+from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
 import litellm
 from .prompt_templates.factory import (
     prompt_factory,
@@ -118,6 +118,7 @@ def completion(
     headers = validate_environment(api_key, headers)
     _is_function_call = False
     messages = copy.deepcopy(messages)
+    optional_params = copy.deepcopy(optional_params)
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
@@ -161,6 +162,8 @@ def completion(
         )  # add the anthropic tool calling prompt to the system prompt
         optional_params.pop("tools")
 
+    stream = optional_params.pop("stream", None)
+
     data = {
         "model": model,
         "messages": messages,
@@ -177,14 +180,18 @@ def completion(
             "headers": headers,
         },
     )
-
+    print_verbose(f"_is_function_call: {_is_function_call}")
     ## COMPLETION CALL
-    if "stream" in optional_params and optional_params["stream"] == True:
+    if (
+        stream is not None and stream == True and _is_function_call == False
+    ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
+        print_verbose(f"makes anthropic streaming POST request")
+        data["stream"] = stream
         response = requests.post(
             api_base,
             headers=headers,
             data=json.dumps(data),
-            stream=optional_params["stream"],
+            stream=stream,
         )
 
         if response.status_code != 200:
@@ -255,6 +262,51 @@ def completion(
                 completion_response["stop_reason"]
             )
 
+        print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
+        if _is_function_call == True and stream is not None and stream == True:
+            print_verbose(f"INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
+            # return an iterator
+            streaming_model_response = ModelResponse(stream=True)
+            streaming_model_response.choices[0].finish_reason = model_response.choices[
+                0
+            ].finish_reason
+            # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+            streaming_choice = litellm.utils.StreamingChoices()
+            streaming_choice.index = model_response.choices[0].index
+            _tool_calls = []
+            print_verbose(
+                f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+            )
+            print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
+            if isinstance(model_response.choices[0], litellm.Choices):
+                if getattr(
+                    model_response.choices[0].message, "tool_calls", None
+                ) is not None and isinstance(
+                    model_response.choices[0].message.tool_calls, list
+                ):
+                    for tool_call in model_response.choices[0].message.tool_calls:
+                        _tool_call = {**tool_call.dict(), "index": 0}
+                        _tool_calls.append(_tool_call)
+                delta_obj = litellm.utils.Delta(
+                    content=getattr(model_response.choices[0].message, "content", None),
+                    role=model_response.choices[0].message.role,
+                    tool_calls=_tool_calls,
+                )
+                streaming_choice.delta = delta_obj
+                streaming_model_response.choices = [streaming_choice]
+                completion_stream = model_response_iterator(
+                    model_response=streaming_model_response
+                )
+                print_verbose(
+                    f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+                )
+                return CustomStreamWrapper(
+                    completion_stream=completion_stream,
+                    model=model,
+                    custom_llm_provider="cached_response",
+                    logging_obj=logging_obj,
+                )
+
         ## CALCULATING USAGE
         prompt_tokens = completion_response["usage"]["input_tokens"]
         completion_tokens = completion_response["usage"]["output_tokens"]
@@ -271,6 +323,10 @@ def completion(
     return model_response
 
 
+def model_response_iterator(model_response):
+    yield model_response
+
+
 def embedding():
     # logic for parsing in - calling - parsing out model embedding calls
     pass
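The anthropic.py change fakes streaming for tool calls: OpenAI-format tool-call parsing needs the complete blocks, so the request is sent non-streaming, parsed into a full ModelResponse, converted into a single streaming chunk, and replayed through the one-shot model_response_iterator generator. A minimal runnable sketch of that replay pattern, using simplified stand-in types rather than litellm's real classes:

# Sketch of the "fake streaming" pattern above: the complete, already-parsed
# response is replayed through a one-shot generator so the caller can iterate
# over it like a live stream. FakeChunk is a simplified stand-in, not a
# litellm class.
from dataclasses import dataclass, field
from typing import Iterator, List, Optional


@dataclass
class FakeChunk:
    tool_calls: List[dict] = field(default_factory=list)
    finish_reason: Optional[str] = None


def model_response_iterator(model_response: FakeChunk) -> Iterator[FakeChunk]:
    # Same shape as the helper added to anthropic.py: yield exactly one chunk.
    yield model_response


complete = FakeChunk(
    tool_calls=[{"index": 0, "function": {"name": "get_current_weather",
                                          "arguments": '{"location": "Boston"}'}}],
    finish_reason="tool_calls",
)
for chunk in model_response_iterator(complete):  # consumed like a real stream
    print(chunk.tool_calls[0]["function"]["arguments"])

In the actual diff, that iterator is returned wrapped in a CustomStreamWrapper with custom_llm_provider="cached_response", which signals downstream that the "stream" is a single pre-built response.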
diff --git a/litellm/main.py b/litellm/main.py
index 0387f82e0f..3a6dde159c 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1144,7 +1144,11 @@ def completion(
                 logging_obj=logging,
                 headers=headers,
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and not isinstance(response, CustomStreamWrapper)
+            ):  # don't try to access stream object,
                 response = CustomStreamWrapper(
                     response,
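main.py now guards against double-wrapping: the Anthropic handler can itself return a CustomStreamWrapper (the cached_response path above), and wrapping that again would treat a wrapper object as a raw provider stream. A small sketch of the guard, with a stand-in class in place of litellm's CustomStreamWrapper:

# Sketch of the double-wrap guard: only wrap raw provider streams; pass an
# already-wrapped response through untouched. StreamWrapperStandIn is not a
# litellm class.
class StreamWrapperStandIn:
    def __init__(self, completion_stream):
        self.completion_stream = completion_stream

    def __iter__(self):
        return iter(self.completion_stream)


def maybe_wrap(response, stream: bool):
    if stream and not isinstance(response, StreamWrapperStandIn):
        return StreamWrapperStandIn(response)  # raw provider stream: wrap it
    return response  # already wrapped upstream: pass through untouched


already_wrapped = StreamWrapperStandIn(["chunk-1"])
assert maybe_wrap(already_wrapped, stream=True) is already_wrapped
assert list(maybe_wrap(["chunk-1"], stream=True)) == ["chunk-1"]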
diff --git a/litellm/tests/log.txt b/litellm/tests/log.txt
index 03b5c605ec..74a7259bf9 100644
--- a/litellm/tests/log.txt
+++ b/litellm/tests/log.txt
@@ -36,32 +36,32 @@ test_completion.py .                                                     [100%]
   /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:180: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
     @root_validator(pre=True)
 
-../proxy/_types.py:235
-  /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:235: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
+../proxy/_types.py:241
+  /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:241: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
     @root_validator(pre=True)
 
-../proxy/_types.py:247
-  /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:247: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
+../proxy/_types.py:253
+  /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:253: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
     @root_validator(pre=True)
 
-../proxy/_types.py:282
-  /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:282: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
+../proxy/_types.py:292
+  /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:292: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
     @root_validator(pre=True)
 
-../proxy/_types.py:308
-  /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:308: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
+../proxy/_types.py:319
+  /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:319: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
     @root_validator(pre=True)
 
-../proxy/_types.py:557
-  /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:557: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
+../proxy/_types.py:570
+  /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:570: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
     @root_validator(pre=True)
 
-../proxy/_types.py:578
-  /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:578: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
+../proxy/_types.py:591
+  /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:591: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
     @root_validator(pre=True)
 
-../utils.py:36
-  /Users/krrishdholakia/Documents/litellm/litellm/utils.py:36: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
+../utils.py:35
+  /Users/krrishdholakia/Documents/litellm/litellm/utils.py:35: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
     import pkg_resources
 
 ../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: 10 warnings
@@ -109,5 +109,11 @@ test_completion.py .                                                     [100%]
   /Users/krrishdholakia/Documents/litellm/litellm/llms/prompt_templates/factory.py:6: DeprecationWarning: 'imghdr' is deprecated and slated for removal in Python 3.13
     import imghdr, base64
 
+test_completion.py::test_completion_claude_3_stream
+../utils.py:3249
+../utils.py:3249
+  /Users/krrishdholakia/Documents/litellm/litellm/utils.py:3249: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
+    with resources.open_text(
+
 -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
-======================== 1 passed, 43 warnings in 4.47s ========================
+======================== 1 passed, 46 warnings in 3.14s ========================
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 6f8d73c317..2834b8319f 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1749,7 +1749,7 @@ class Chunk(BaseModel):
     object: str
     created: int
     model: str
-    system_fingerprint: str
+    # system_fingerprint: str
     choices: List[Choices]
 
 
@@ -1869,7 +1869,7 @@ class Chunk3(BaseModel):
     object: str
     created: int
     model: str
-    system_fingerprint: str
+    # system_fingerprint: str
    choices: List[Choices3]
 
 
@@ -2032,3 +2032,56 @@ async def test_azure_astreaming_and_function_calling():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
         raise e
+
+
+def test_completion_claude_3_function_call_with_streaming():
+    litellm.set_verbose = True
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
+    try:
+        # test without max tokens
+        response = completion(
+            model="claude-3-opus-20240229",
+            messages=messages,
+            tools=tools,
+            tool_choice="auto",
+            stream=True,
+        )
+        idx = 0
+        for chunk in response:
+            # print(f"chunk: {chunk}")
+            if idx == 0:
+                assert (
+                    chunk.choices[0].delta.tool_calls[0].function.arguments is not None
+                )
+                assert isinstance(
+                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
+                )
+                validate_first_streaming_function_calling_chunk(chunk=chunk)
+            elif idx == 1:
+                validate_second_streaming_function_calling_chunk(chunk=chunk)
+            elif chunk.choices[0].finish_reason is not None:  # last chunk
+                validate_final_streaming_function_calling_chunk(chunk=chunk)
+            idx += 1
+        # raise Exception("it worked!")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
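The new test relies on the faked stream front-loading everything: the first chunk already carries the complete tool-call arguments as one JSON string, where a real token stream would deliver them fragment by fragment. An illustrative sketch of that first-chunk shape, using stand-in objects rather than litellm's streaming types:

# Sketch of the chunk shape the test asserts on at idx == 0; the SimpleNamespace
# objects below mimic, but are not, litellm's streaming classes.
import json
from types import SimpleNamespace

first_chunk = SimpleNamespace(
    choices=[
        SimpleNamespace(
            delta=SimpleNamespace(
                tool_calls=[
                    SimpleNamespace(
                        function=SimpleNamespace(
                            arguments='{"location": "Boston, MA"}'
                        )
                    )
                ]
            ),
            finish_reason=None,
        )
    ]
)

arguments = first_chunk.choices[0].delta.tool_calls[0].function.arguments
assert isinstance(arguments, str)  # same assertion the test makes on the first chunk
print(json.loads(arguments))  # {'location': 'Boston, MA'} -- parseable in one shot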
diff --git a/litellm/utils.py b/litellm/utils.py
index a4168fd7e3..8c62a2222a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -480,12 +480,12 @@ class ModelResponse(OpenAIObject):
         object=None,
         system_fingerprint=None,
         usage=None,
-        stream=False,
+        stream=None,
         response_ms=None,
         hidden_params=None,
         **params,
     ):
-        if stream:
+        if stream is not None and stream == True:
             object = "chat.completion.chunk"
             choices = [StreamingChoices()]
         else:
@@ -9471,14 +9471,18 @@ class CustomStreamWrapper:
     def __next__(self):
         try:
             while True:
-                if isinstance(self.completion_stream, str) or isinstance(
-                    self.completion_stream, bytes
+                if (
+                    isinstance(self.completion_stream, str)
+                    or isinstance(self.completion_stream, bytes)
+                    or isinstance(self.completion_stream, ModelResponse)
                 ):
                     chunk = self.completion_stream
                 else:
                     chunk = next(self.completion_stream)
                 if chunk is not None and chunk != b"":
-                    print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
+                    print_verbose(
+                        f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}; custom_llm_provider: {self.custom_llm_provider}"
+                    )
                     response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk)
                     print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}")
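The utils.py changes close the loop: ModelResponse's stream flag becomes tri-state so ModelResponse(stream=True) still builds a chunk-shaped response, and CustomStreamWrapper.__next__ treats a ModelResponse passed as completion_stream as a literal chunk rather than an iterator to advance. A condensed sketch of that branch, with toy classes in place of litellm's ModelResponse and CustomStreamWrapper:

# Sketch of the literal-chunk branch: when the "stream" is a single pre-built
# response (the cached_response path), use it as the chunk directly instead of
# calling next() on it. ToyModelResponse/ToyStreamWrapper are stand-ins only.
class ToyModelResponse:
    def __init__(self, text: str):
        self.text = text


class ToyStreamWrapper:
    def __init__(self, completion_stream):
        self.completion_stream = completion_stream
        self.done = False

    def __iter__(self):
        return self

    def __next__(self):
        if self.done:
            raise StopIteration
        if isinstance(self.completion_stream, (str, bytes, ToyModelResponse)):
            chunk = self.completion_stream  # literal value: use it as-is
            self.done = True  # a cached response yields exactly one chunk
        else:
            chunk = next(self.completion_stream)  # real iterator: advance it
        return chunk


print(next(iter(ToyStreamWrapper(ToyModelResponse("hi")))).text)  # -> hi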