From ac6c39c2837f80477455562a89d08a4f92a00b5d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 3 Aug 2024 20:16:19 -0700 Subject: [PATCH] feat(anthropic_adapter.py): support streaming requests for `/v1/messages` endpoint Fixes https://github.com/BerriAI/litellm/issues/5011 --- litellm/adapters/anthropic_adapter.py | 152 ++++++++++++++++++++- litellm/integrations/custom_logger.py | 6 +- litellm/llms/anthropic.py | 77 ++++++++++- litellm/main.py | 59 +++++--- litellm/proxy/_new_secret_config.yaml | 4 +- litellm/proxy/proxy_server.py | 75 +++++++++- litellm/tests/test_anthropic_completion.py | 50 ++++++- litellm/types/llms/anthropic.py | 2 +- litellm/types/utils.py | 35 ++++- 9 files changed, 425 insertions(+), 35 deletions(-) diff --git a/litellm/adapters/anthropic_adapter.py b/litellm/adapters/anthropic_adapter.py index 7d9d799b6..dead0642d 100644 --- a/litellm/adapters/anthropic_adapter.py +++ b/litellm/adapters/anthropic_adapter.py @@ -4,7 +4,7 @@ import json import os import traceback import uuid -from typing import Literal, Optional +from typing import Any, Literal, Optional import dotenv import httpx @@ -13,7 +13,12 @@ from pydantic import BaseModel import litellm from litellm import ChatCompletionRequest, verbose_logger from litellm.integrations.custom_logger import CustomLogger -from litellm.types.llms.anthropic import AnthropicMessagesRequest, AnthropicResponse +from litellm.types.llms.anthropic import ( + AnthropicMessagesRequest, + AnthropicResponse, + ContentBlockDelta, +) +from litellm.types.utils import AdapterCompletionStreamWrapper class AnthropicAdapter(CustomLogger): @@ -43,8 +48,147 @@ class AnthropicAdapter(CustomLogger): response=response ) - def translate_completion_output_params_streaming(self) -> Optional[BaseModel]: - return super().translate_completion_output_params_streaming() + def translate_completion_output_params_streaming( + self, completion_stream: Any + ) -> AdapterCompletionStreamWrapper | None: + return AnthropicStreamWrapper(completion_stream=completion_stream) anthropic_adapter = AnthropicAdapter() + + +class AnthropicStreamWrapper(AdapterCompletionStreamWrapper): + """ + - first chunk return 'message_start' + - content block must be started and stopped + - finish_reason must map exactly to anthropic reason, else anthropic client won't be able to parse it. 
+ """ + + sent_first_chunk: bool = False + sent_content_block_start: bool = False + sent_content_block_finish: bool = False + sent_last_message: bool = False + holding_chunk: Optional[Any] = None + + def __next__(self): + try: + if self.sent_first_chunk is False: + self.sent_first_chunk = True + return { + "type": "message_start", + "message": { + "id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-3-5-sonnet-20240620", + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 25, "output_tokens": 1}, + }, + } + if self.sent_content_block_start is False: + self.sent_content_block_start = True + return { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""}, + } + + for chunk in self.completion_stream: + if chunk == "None" or chunk is None: + raise Exception + + processed_chunk = litellm.AnthropicConfig().translate_streaming_openai_response_to_anthropic( + response=chunk + ) + if ( + processed_chunk["type"] == "message_delta" + and self.sent_content_block_finish is False + ): + self.holding_chunk = processed_chunk + self.sent_content_block_finish = True + return { + "type": "content_block_stop", + "index": 0, + } + elif self.holding_chunk is not None: + return_chunk = self.holding_chunk + self.holding_chunk = processed_chunk + return return_chunk + else: + return processed_chunk + + if self.sent_last_message is False: + self.sent_last_message = True + return {"type": "message_stop"} + raise StopIteration + except StopIteration: + if self.sent_last_message is False: + self.sent_last_message = True + return {"type": "message_stop"} + raise StopIteration + except Exception as e: + verbose_logger.error( + "Anthropic Adapter - {}\n{}".format(e, traceback.format_exc()) + ) + + async def __anext__(self): + try: + if self.sent_first_chunk is False: + self.sent_first_chunk = True + return { + "type": "message_start", + "message": { + "id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-3-5-sonnet-20240620", + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 25, "output_tokens": 1}, + }, + } + if self.sent_content_block_start is False: + self.sent_content_block_start = True + return { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""}, + } + async for chunk in self.completion_stream: + if chunk == "None" or chunk is None: + raise Exception + processed_chunk = litellm.AnthropicConfig().translate_streaming_openai_response_to_anthropic( + response=chunk + ) + if ( + processed_chunk["type"] == "message_delta" + and self.sent_content_block_finish is False + ): + self.holding_chunk = processed_chunk + self.sent_content_block_finish = True + return { + "type": "content_block_stop", + "index": 0, + } + elif self.holding_chunk is not None: + return_chunk = self.holding_chunk + self.holding_chunk = processed_chunk + return return_chunk + else: + return processed_chunk + if self.holding_chunk is not None: + return_chunk = self.holding_chunk + self.holding_chunk = None + return return_chunk + if self.sent_last_message is False: + self.sent_last_message = True + return {"type": "message_stop"} + raise StopIteration + except StopIteration: + if self.sent_last_message is False: + self.sent_last_message = True + return {"type": "message_stop"} + raise StopAsyncIteration diff --git a/litellm/integrations/custom_logger.py 
b/litellm/integrations/custom_logger.py index 5139723ca..bf089c364 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -10,7 +10,7 @@ from pydantic import BaseModel from litellm.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.types.llms.openai import ChatCompletionRequest -from litellm.types.utils import ModelResponse +from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class @@ -76,7 +76,9 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac """ pass - def translate_completion_output_params_streaming(self) -> Optional[BaseModel]: + def translate_completion_output_params_streaming( + self, completion_stream: Any + ) -> Optional[AdapterCompletionStreamWrapper]: """ Translates the streaming chunk, from the OpenAI format to the custom format. """ diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index ab47c9537..929375ef0 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -5,13 +5,16 @@ import time import types from enum import Enum from functools import partial -from typing import Callable, List, Optional, Union +from typing import Callable, List, Literal, Optional, Tuple, Union import httpx # type: ignore import requests # type: ignore +from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice import litellm import litellm.litellm_core_utils +import litellm.types +import litellm.types.utils from litellm import verbose_logger from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import ( @@ -33,8 +36,12 @@ from litellm.types.llms.anthropic import ( AnthropicResponseUsageBlock, ContentBlockDelta, ContentBlockStart, + ContentJsonBlockDelta, + ContentTextBlockDelta, MessageBlockDelta, + MessageDelta, MessageStartBlock, + UsageDelta, ) from litellm.types.llms.openai import ( AllMessageValues, @@ -480,6 +487,74 @@ class AnthropicConfig: return translated_obj + def _translate_streaming_openai_chunk_to_anthropic( + self, choices: List[OpenAIStreamingChoice] + ) -> Tuple[ + Literal["text_delta", "input_json_delta"], + Union[ContentTextBlockDelta, ContentJsonBlockDelta], + ]: + text: str = "" + partial_json: Optional[str] = None + for choice in choices: + if choice.delta.content is not None: + text += choice.delta.content + elif choice.delta.tool_calls is not None: + partial_json = "" + for tool in choice.delta.tool_calls: + if ( + tool.function is not None + and tool.function.arguments is not None + ): + partial_json += tool.function.arguments + + if partial_json is not None: + return "input_json_delta", ContentJsonBlockDelta( + type="input_json_delta", partial_json=partial_json + ) + else: + return "text_delta", ContentTextBlockDelta(type="text_delta", text=text) + + def translate_streaming_openai_response_to_anthropic( + self, response: litellm.ModelResponse + ) -> Union[ContentBlockDelta, MessageBlockDelta]: + ## base case - final chunk w/ finish reason + if response.choices[0].finish_reason is not None: + delta = MessageDelta( + stop_reason=self._translate_openai_finish_reason_to_anthropic( + response.choices[0].finish_reason + ), + ) + if getattr(response, "usage", None) is not None: + litellm_usage_chunk: Optional[litellm.Usage] = response.usage # type: ignore + elif ( + hasattr(response, "_hidden_params") + and "usage" in 
response._hidden_params + ): + litellm_usage_chunk = response._hidden_params["usage"] + else: + litellm_usage_chunk = None + if litellm_usage_chunk is not None: + usage_delta = UsageDelta( + input_tokens=litellm_usage_chunk.prompt_tokens or 0, + output_tokens=litellm_usage_chunk.completion_tokens or 0, + ) + else: + usage_delta = UsageDelta(input_tokens=0, output_tokens=0) + return MessageBlockDelta( + type="message_delta", delta=delta, usage=usage_delta + ) + ( + type_of_content, + content_block_delta, + ) = self._translate_streaming_openai_chunk_to_anthropic( + choices=response.choices # type: ignore + ) + return ContentBlockDelta( + type="content_block_delta", + index=response.choices[0].index, + delta=content_block_delta, + ) + # makes headers for API call def validate_environment(api_key, user_headers, model): diff --git a/litellm/main.py b/litellm/main.py index 67b935a55..f3e006feb 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -125,7 +125,7 @@ from .llms.vertex_ai_partner import VertexAIPartnerModels from .llms.vertex_httpx import VertexLLM from .llms.watsonx import IBMWatsonXAI from .types.llms.openai import HttpxBinaryResponseContent -from .types.utils import ChatCompletionMessageToolCall +from .types.utils import AdapterCompletionStreamWrapper, ChatCompletionMessageToolCall encoding = tiktoken.get_encoding("cl100k_base") from litellm.utils import ( @@ -515,7 +515,7 @@ def mock_completion( model_response = ModelResponse(stream=stream) if stream is True: # don't try to access stream object, - if kwargs.get("acompletion", False) == True: + if kwargs.get("acompletion", False) is True: return CustomStreamWrapper( completion_stream=async_mock_completion_streaming_obj( model_response, mock_response=mock_response, model=model, n=n @@ -524,13 +524,14 @@ def mock_completion( custom_llm_provider="openai", logging_obj=logging, ) - response = mock_completion_streaming_obj( - model_response, - mock_response=mock_response, + return CustomStreamWrapper( + completion_stream=mock_completion_streaming_obj( + model_response, mock_response=mock_response, model=model, n=n + ), model=model, - n=n, + custom_llm_provider="openai", + logging_obj=logging, ) - return response if n is None: model_response.choices[0].message.content = mock_response # type: ignore else: @@ -4037,7 +4038,9 @@ def text_completion( ###### Adapter Completion ################ -async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]: +async def aadapter_completion( + *, adapter_id: str, **kwargs +) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]: """ Implemented to handle async calls for adapter_completion() """ @@ -4056,18 +4059,29 @@ async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseMode new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs) - response: ModelResponse = await acompletion(**new_kwargs) # type: ignore - - translated_response = translation_obj.translate_completion_output_params( - response=response - ) + response: Union[ModelResponse, CustomStreamWrapper] = await acompletion(**new_kwargs) # type: ignore + translated_response: Optional[ + Union[BaseModel, AdapterCompletionStreamWrapper] + ] = None + if isinstance(response, ModelResponse): + translated_response = translation_obj.translate_completion_output_params( + response=response + ) + if isinstance(response, CustomStreamWrapper): + translated_response = ( + translation_obj.translate_completion_output_params_streaming( + completion_stream=response + ) + ) return 
translated_response except Exception as e: raise e -def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]: +def adapter_completion( + *, adapter_id: str, **kwargs +) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]: translation_obj: Optional[CustomLogger] = None for item in litellm.adapters: if item["id"] == adapter_id: @@ -4082,11 +4096,20 @@ def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]: new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs) - response: ModelResponse = completion(**new_kwargs) # type: ignore - - translated_response = translation_obj.translate_completion_output_params( - response=response + response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore + translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = ( + None ) + if isinstance(response, ModelResponse): + translated_response = translation_obj.translate_completion_output_params( + response=response + ) + elif isinstance(response, CustomStreamWrapper) or inspect.isgenerator(response): + translated_response = ( + translation_obj.translate_completion_output_params_streaming( + completion_stream=response + ) + ) return translated_response diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 238fe7136..5f7d933a9 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,7 +1,7 @@ model_list: - - model_name: "*" + - model_name: "claude-3-5-sonnet-20240620" litellm_params: - model: "*" + model: "claude-3-5-sonnet-20240620" # litellm_settings: # failure_callback: ["langfuse"] diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0f57a5fd1..538feac49 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2396,7 +2396,9 @@ async def async_data_generator( user_api_key_dict=user_api_key_dict, response=chunk ) - chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True) + if isinstance(chunk, BaseModel): + chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True) + try: yield f"data: {chunk}\n\n" except Exception as e: @@ -2437,6 +2439,59 @@ async def async_data_generator( yield f"data: {error_returned}\n\n" +async def async_data_generator_anthropic( + response, user_api_key_dict: UserAPIKeyAuth, request_data: dict +): + verbose_proxy_logger.debug("inside generator") + try: + start_time = time.time() + async for chunk in response: + verbose_proxy_logger.debug( + "async_data_generator: received streaming chunk - {}".format(chunk) + ) + ### CALL HOOKS ### - modify outgoing data + chunk = await proxy_logging_obj.async_post_call_streaming_hook( + user_api_key_dict=user_api_key_dict, response=chunk + ) + + event_type = chunk.get("type") + + try: + yield f"event: {event_type}\ndata:{json.dumps(chunk)}\n\n" + except Exception as e: + yield f"event: {event_type}\ndata:{str(e)}\n\n" + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}\n{}".format( + str(e), traceback.format_exc() + ) + ) + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, + original_exception=e, + request_data=request_data, + ) + verbose_proxy_logger.debug( + f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. 
`litellm --model gpt-3.5-turbo --debug`" + ) + router_model_names = llm_router.model_names if llm_router is not None else [] + + if isinstance(e, HTTPException): + raise e + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}\n\n{error_traceback}" + + proxy_exception = ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + error_returned = json.dumps({"error": proxy_exception.to_dict()}) + yield f"data: {error_returned}\n\n" + + def select_data_generator( response, user_api_key_dict: UserAPIKeyAuth, request_data: dict ): @@ -5379,6 +5434,19 @@ async def anthropic_response( ) ) + if ( + "stream" in data and data["stream"] is True + ): # use generate_responses to stream responses + selected_data_generator = async_data_generator_anthropic( + response=response, + user_api_key_dict=user_api_key_dict, + request_data=data, + ) + return StreamingResponse( + selected_data_generator, + media_type="text/event-stream", + ) + verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response)) return response except RejectedRequestError as e: @@ -5425,11 +5493,10 @@ async def anthropic_response( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) verbose_proxy_logger.error( - "litellm.proxy.proxy_server.completion(): Exception occured - {}".format( - str(e) + "litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}\n{}".format( + str(e), traceback.format_exc() ) ) - verbose_proxy_logger.debug(traceback.format_exc()) error_msg = f"{str(e)}" raise ProxyException( message=getattr(e, "message", error_msg), diff --git a/litellm/tests/test_anthropic_completion.py b/litellm/tests/test_anthropic_completion.py index 15d150a56..3611a44ca 100644 --- a/litellm/tests/test_anthropic_completion.py +++ b/litellm/tests/test_anthropic_completion.py @@ -8,6 +8,9 @@ import traceback from dotenv import load_dotenv +import litellm.types +import litellm.types.utils + load_dotenv() import io import os @@ -15,6 +18,7 @@ import os sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +from typing import Optional from unittest.mock import MagicMock, patch import pytest @@ -84,7 +88,22 @@ def test_anthropic_completion_input_translation_with_metadata(): assert translated_input["metadata"] == data["litellm_metadata"] -def test_anthropic_completion_e2e(): +def streaming_format_tests(chunk: dict, idx: int): + """ + 1st chunk - chunk.get("type") == "message_start" + 2nd chunk - chunk.get("type") == "content_block_start" + 3rd chunk - chunk.get("type") == "content_block_delta" + """ + if idx == 0: + assert chunk.get("type") == "message_start" + elif idx == 1: + assert chunk.get("type") == "content_block_start" + elif idx == 2: + assert chunk.get("type") == "content_block_delta" + + +@pytest.mark.parametrize("stream", [True]) # False +def test_anthropic_completion_e2e(stream): litellm.set_verbose = True litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}] @@ -95,13 +114,40 @@ def test_anthropic_completion_e2e(): messages=messages, adapter_id="anthropic", mock_response="This is a fake call", + stream=stream, ) print("Response: {}".format(response)) assert response is not None - assert isinstance(response, AnthropicResponse) + if stream is False: + assert isinstance(response, AnthropicResponse) + else: + """ + - ensure finish reason is returned + - assert content block is started and stopped 
+ - ensure last chunk is 'message_stop' + """ + assert isinstance(response, litellm.types.utils.AdapterCompletionStreamWrapper) + finish_reason: Optional[str] = None + message_stop_received = False + content_block_started = False + content_block_finished = False + for idx, chunk in enumerate(response): + print(chunk) + streaming_format_tests(chunk=chunk, idx=idx) + if chunk.get("delta", {}).get("stop_reason") is not None: + finish_reason = chunk.get("delta", {}).get("stop_reason") + if chunk.get("type") == "message_stop": + message_stop_received = True + if chunk.get("type") == "content_block_stop": + content_block_finished = True + if chunk.get("type") == "content_block_start": + content_block_started = True + assert content_block_started and content_block_finished + assert finish_reason is not None + assert message_stop_received is True @pytest.mark.asyncio diff --git a/litellm/types/llms/anthropic.py b/litellm/types/llms/anthropic.py index b41980afd..60784e913 100644 --- a/litellm/types/llms/anthropic.py +++ b/litellm/types/llms/anthropic.py @@ -136,7 +136,7 @@ class ContentJsonBlockDelta(TypedDict): class ContentBlockDelta(TypedDict): - type: str + type: Literal["content_block_delta"] index: int delta: Union[ContentTextBlockDelta, ContentJsonBlockDelta] diff --git a/litellm/types/utils.py b/litellm/types/utils.py index be77821ad..481f762ee 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -6,7 +6,7 @@ from typing import Dict, List, Literal, Optional, Tuple, Union from openai._models import BaseModel as OpenAIObject from pydantic import ConfigDict, Field, PrivateAttr -from typing_extensions import Dict, Required, TypedDict, override +from typing_extensions import Callable, Dict, Required, TypedDict, override from ..litellm_core_utils.core_helpers import map_finish_reason from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock @@ -1069,3 +1069,36 @@ class LoggedLiteLLMParams(TypedDict, total=False): output_cost_per_token: Optional[float] output_cost_per_second: Optional[float] cooldown_time: Optional[float] + + +class AdapterCompletionStreamWrapper: + def __init__(self, completion_stream): + self.completion_stream = completion_stream + + def __iter__(self): + return self + + def __aiter__(self): + return self + + def __next__(self): + try: + for chunk in self.completion_stream: + if chunk == "None" or chunk is None: + raise Exception + return chunk + raise StopIteration + except StopIteration: + raise StopIteration + except Exception as e: + print(f"AdapterCompletionStreamWrapper - {e}") # noqa + + async def __anext__(self): + try: + async for chunk in self.completion_stream: + if chunk == "None" or chunk is None: + raise Exception + return chunk + raise StopIteration + except StopIteration: + raise StopAsyncIteration
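
For reference, below is a minimal usage sketch of the new streaming path, modeled on the `test_anthropic_completion_e2e` test added in this patch. It is not part of the commit. The top-level `adapter_completion` import, the placeholder model name, and the message content are assumptions (the test's full imports and its `model` argument sit outside the diff context); `mock_response` keeps the call provider-free.

# Sketch: exercise the Anthropic adapter's new streaming support end-to-end.
import litellm
from litellm import adapter_completion  # assumed top-level export, as the test uses it
from litellm.adapters.anthropic_adapter import anthropic_adapter

# Register the adapter under the id used throughout this patch.
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]

messages = [{"role": "user", "content": "Hey, how's it going?"}]

response = adapter_completion(
    model="gpt-3.5-turbo",               # placeholder - any litellm-routable chat model
    messages=messages,
    adapter_id="anthropic",
    mock_response="This is a fake call",  # drop this to hit a real provider
    stream=True,
)

# With stream=True the adapter returns an AnthropicStreamWrapper, which yields
# Anthropic-format event dicts in order: message_start, content_block_start,
# content_block_delta(s), content_block_stop, message_delta (stop_reason/usage),
# and finally message_stop.
for chunk in response:
    print(chunk.get("type"), chunk)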