From f121b8f63035dbf8c0fcd042e29633be54d24bc0 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 20 Nov 2024 12:02:15 -0800 Subject: [PATCH 1/7] move _process_response in transformation --- litellm/llms/anthropic/chat/handler.py | 167 +--------------- litellm/llms/anthropic/chat/transformation.py | 183 +++++++++++++++++- 2 files changed, 184 insertions(+), 166 deletions(-) diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index d565a16a0..86b1117ab 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -45,9 +45,7 @@ from litellm.types.llms.openai import ( ChatCompletionUsageBlock, ) from litellm.types.utils import GenericStreamingChunk -from litellm.types.utils import Message as LitellmMessage -from litellm.types.utils import PromptTokensDetailsWrapper -from litellm.utils import CustomStreamWrapper, ModelResponse, Usage +from litellm.utils import CustomStreamWrapper, ModelResponse from ...base import BaseLLM from ..common_utils import AnthropicError, process_anthropic_headers @@ -201,163 +199,6 @@ class AnthropicChatCompletion(BaseLLM): def __init__(self) -> None: super().__init__() - def _process_response( - self, - model: str, - response: Union[requests.Response, httpx.Response], - model_response: ModelResponse, - stream: bool, - logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, # type: ignore - optional_params: dict, - api_key: str, - data: Union[dict, str], - messages: List, - print_verbose, - encoding, - json_mode: bool, - ) -> ModelResponse: - _hidden_params: Dict = {} - _hidden_params["additional_headers"] = process_anthropic_headers( - dict(response.headers) - ) - ## LOGGING - logging_obj.post_call( - input=messages, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) - print_verbose(f"raw model_response: {response.text}") - ## RESPONSE OBJECT - try: - completion_response = response.json() - except Exception as e: - response_headers = getattr(response, "headers", None) - raise AnthropicError( - message="Unable to get json response - {}, Original Response: {}".format( - str(e), response.text - ), - status_code=response.status_code, - headers=response_headers, - ) - if "error" in completion_response: - response_headers = getattr(response, "headers", None) - raise AnthropicError( - message=str(completion_response["error"]), - status_code=response.status_code, - headers=response_headers, - ) - else: - text_content = "" - tool_calls: List[ChatCompletionToolCallChunk] = [] - for idx, content in enumerate(completion_response["content"]): - if content["type"] == "text": - text_content += content["text"] - ## TOOL CALLING - elif content["type"] == "tool_use": - tool_calls.append( - ChatCompletionToolCallChunk( - id=content["id"], - type="function", - function=ChatCompletionToolCallFunctionChunk( - name=content["name"], - arguments=json.dumps(content["input"]), - ), - index=idx, - ) - ) - - _message = litellm.Message( - tool_calls=tool_calls, - content=text_content or None, - ) - - ## HANDLE JSON MODE - anthropic returns single function call - if json_mode and len(tool_calls) == 1: - json_mode_content_str: Optional[str] = tool_calls[0]["function"].get( - "arguments" - ) - if json_mode_content_str is not None: - _converted_message = self._convert_tool_response_to_message( - tool_calls=tool_calls, - ) - if _converted_message is not None: - completion_response["stop_reason"] = "stop" - _message = _converted_message - 
model_response.choices[0].message = _message # type: ignore - model_response._hidden_params["original_response"] = completion_response[ - "content" - ] # allow user to access raw anthropic tool calling response - - model_response.choices[0].finish_reason = map_finish_reason( - completion_response["stop_reason"] - ) - - ## CALCULATING USAGE - prompt_tokens = completion_response["usage"]["input_tokens"] - completion_tokens = completion_response["usage"]["output_tokens"] - _usage = completion_response["usage"] - cache_creation_input_tokens: int = 0 - cache_read_input_tokens: int = 0 - - model_response.created = int(time.time()) - model_response.model = model - if "cache_creation_input_tokens" in _usage: - cache_creation_input_tokens = _usage["cache_creation_input_tokens"] - prompt_tokens += cache_creation_input_tokens - if "cache_read_input_tokens" in _usage: - cache_read_input_tokens = _usage["cache_read_input_tokens"] - prompt_tokens += cache_read_input_tokens - - prompt_tokens_details = PromptTokensDetailsWrapper( - cached_tokens=cache_read_input_tokens - ) - total_tokens = prompt_tokens + completion_tokens - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - prompt_tokens_details=prompt_tokens_details, - cache_creation_input_tokens=cache_creation_input_tokens, - cache_read_input_tokens=cache_read_input_tokens, - ) - - setattr(model_response, "usage", usage) # type: ignore - - model_response._hidden_params = _hidden_params - return model_response - - @staticmethod - def _convert_tool_response_to_message( - tool_calls: List[ChatCompletionToolCallChunk], - ) -> Optional[LitellmMessage]: - """ - In JSON mode, Anthropic API returns JSON schema as a tool call, we need to convert it to a message to follow the OpenAI format - - """ - ## HANDLE JSON MODE - anthropic returns single function call - json_mode_content_str: Optional[str] = tool_calls[0]["function"].get( - "arguments" - ) - try: - if json_mode_content_str is not None: - args = json.loads(json_mode_content_str) - if ( - isinstance(args, dict) - and (values := args.get("values")) is not None - ): - _message = litellm.Message(content=json.dumps(values)) - return _message - else: - # a lot of the times the `values` key is not present in the tool response - # relevant issue: https://github.com/BerriAI/litellm/issues/6741 - _message = litellm.Message(content=json.dumps(args)) - return _message - except json.JSONDecodeError: - # json decode error does occur, return the original tool response str - return litellm.Message(content=json_mode_content_str) - return None - async def acompletion_stream_function( self, model: str, @@ -454,7 +295,7 @@ class AnthropicChatCompletion(BaseLLM): headers=error_headers, ) - return self._process_response( + return AnthropicConfig._process_response( model=model, response=response, model_response=model_response, @@ -630,7 +471,7 @@ class AnthropicChatCompletion(BaseLLM): headers=error_headers, ) - return self._process_response( + return AnthropicConfig._process_response( model=model, response=response, model_response=model_response, @@ -855,7 +696,7 @@ class ModelResponseIterator: tool_use: The ChatCompletionToolCallChunk to use in the chunk response """ if self.json_mode is True and tool_use is not None: - message = AnthropicChatCompletion._convert_tool_response_to_message( + message = AnthropicConfig._convert_tool_response_to_message( tool_calls=[tool_use] ) if message is not None: diff --git a/litellm/llms/anthropic/chat/transformation.py 
b/litellm/llms/anthropic/chat/transformation.py index 1419d7ef2..860aacb44 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -1,7 +1,14 @@ +import json +import time import types -from typing import List, Literal, Optional, Tuple, Union +from re import A +from typing import Dict, List, Literal, Optional, Tuple, Union + +import httpx +import requests import litellm +from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.prompt_templates.factory import anthropic_messages_pt from litellm.types.llms.anthropic import ( AllAnthropicToolsValues, @@ -18,12 +25,23 @@ from litellm.types.llms.openai import ( AllMessageValues, ChatCompletionCachedContent, ChatCompletionSystemMessage, + ChatCompletionToolCallChunk, + ChatCompletionToolCallFunctionChunk, ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk, + ChatCompletionUsageBlock, +) +from litellm.types.utils import Message as LitellmMessage +from litellm.types.utils import PromptTokensDetailsWrapper +from litellm.utils import ( + CustomStreamWrapper, + ModelResponse, + Usage, + add_dummy_tool, + has_tool_call_blocks, ) -from litellm.utils import add_dummy_tool, has_tool_call_blocks -from ..common_utils import AnthropicError +from ..common_utils import AnthropicError, process_anthropic_headers class AnthropicConfig: @@ -534,3 +552,162 @@ class AnthropicConfig: if not is_vertex_request: data["model"] = model return data + + @staticmethod + def _process_response( + model: str, + response: Union[requests.Response, httpx.Response], + model_response: ModelResponse, + stream: bool, + logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, # type: ignore + optional_params: dict, + api_key: str, + data: Union[dict, str], + messages: List, + print_verbose, + encoding, + json_mode: bool, + ) -> ModelResponse: + _hidden_params: Dict = {} + _hidden_params["additional_headers"] = process_anthropic_headers( + dict(response.headers) + ) + ## LOGGING + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + ## RESPONSE OBJECT + try: + completion_response = response.json() + except Exception as e: + response_headers = getattr(response, "headers", None) + raise AnthropicError( + message="Unable to get json response - {}, Original Response: {}".format( + str(e), response.text + ), + status_code=response.status_code, + headers=response_headers, + ) + if "error" in completion_response: + response_headers = getattr(response, "headers", None) + raise AnthropicError( + message=str(completion_response["error"]), + status_code=response.status_code, + headers=response_headers, + ) + else: + text_content = "" + tool_calls: List[ChatCompletionToolCallChunk] = [] + for idx, content in enumerate(completion_response["content"]): + if content["type"] == "text": + text_content += content["text"] + ## TOOL CALLING + elif content["type"] == "tool_use": + tool_calls.append( + ChatCompletionToolCallChunk( + id=content["id"], + type="function", + function=ChatCompletionToolCallFunctionChunk( + name=content["name"], + arguments=json.dumps(content["input"]), + ), + index=idx, + ) + ) + + _message = litellm.Message( + tool_calls=tool_calls, + content=text_content or None, + ) + + ## HANDLE JSON MODE - anthropic returns single function call + if json_mode and len(tool_calls) == 1: + json_mode_content_str: Optional[str] = 
tool_calls[0]["function"].get( + "arguments" + ) + if json_mode_content_str is not None: + _converted_message = ( + AnthropicConfig._convert_tool_response_to_message( + tool_calls=tool_calls, + ) + ) + if _converted_message is not None: + completion_response["stop_reason"] = "stop" + _message = _converted_message + model_response.choices[0].message = _message # type: ignore + model_response._hidden_params["original_response"] = completion_response[ + "content" + ] # allow user to access raw anthropic tool calling response + + model_response.choices[0].finish_reason = map_finish_reason( + completion_response["stop_reason"] + ) + + ## CALCULATING USAGE + prompt_tokens = completion_response["usage"]["input_tokens"] + completion_tokens = completion_response["usage"]["output_tokens"] + _usage = completion_response["usage"] + cache_creation_input_tokens: int = 0 + cache_read_input_tokens: int = 0 + + model_response.created = int(time.time()) + model_response.model = model + if "cache_creation_input_tokens" in _usage: + cache_creation_input_tokens = _usage["cache_creation_input_tokens"] + prompt_tokens += cache_creation_input_tokens + if "cache_read_input_tokens" in _usage: + cache_read_input_tokens = _usage["cache_read_input_tokens"] + prompt_tokens += cache_read_input_tokens + + prompt_tokens_details = PromptTokensDetailsWrapper( + cached_tokens=cache_read_input_tokens + ) + total_tokens = prompt_tokens + completion_tokens + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + prompt_tokens_details=prompt_tokens_details, + cache_creation_input_tokens=cache_creation_input_tokens, + cache_read_input_tokens=cache_read_input_tokens, + ) + + setattr(model_response, "usage", usage) # type: ignore + + model_response._hidden_params = _hidden_params + return model_response + + @staticmethod + def _convert_tool_response_to_message( + tool_calls: List[ChatCompletionToolCallChunk], + ) -> Optional[LitellmMessage]: + """ + In JSON mode, Anthropic API returns JSON schema as a tool call, we need to convert it to a message to follow the OpenAI format + + """ + ## HANDLE JSON MODE - anthropic returns single function call + json_mode_content_str: Optional[str] = tool_calls[0]["function"].get( + "arguments" + ) + try: + if json_mode_content_str is not None: + args = json.loads(json_mode_content_str) + if ( + isinstance(args, dict) + and (values := args.get("values")) is not None + ): + _message = litellm.Message(content=json.dumps(values)) + return _message + else: + # a lot of the times the `values` key is not present in the tool response + # relevant issue: https://github.com/BerriAI/litellm/issues/6741 + _message = litellm.Message(content=json.dumps(args)) + return _message + except json.JSONDecodeError: + # json decode error does occur, return the original tool response str + return litellm.Message(content=json_mode_content_str) + return None From b3b1ff6882e81fb811f5bb8ea0f9912d06a01d32 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 20 Nov 2024 12:07:28 -0800 Subject: [PATCH 2/7] fix AnthropicConfig test --- .../llm_translation/test_anthropic_completion.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/llm_translation/test_anthropic_completion.py b/tests/llm_translation/test_anthropic_completion.py index 8a788e0fb..9781297fe 100644 --- a/tests/llm_translation/test_anthropic_completion.py +++ b/tests/llm_translation/test_anthropic_completion.py @@ -712,9 +712,7 @@ def 
test_convert_tool_response_to_message_with_values(): ) ] - message = AnthropicChatCompletion._convert_tool_response_to_message( - tool_calls=tool_calls - ) + message = AnthropicConfig._convert_tool_response_to_message(tool_calls=tool_calls) assert message is not None assert message.content == '{"name": "John", "age": 30}' @@ -739,9 +737,7 @@ def test_convert_tool_response_to_message_without_values(): ) ] - message = AnthropicChatCompletion._convert_tool_response_to_message( - tool_calls=tool_calls - ) + message = AnthropicConfig._convert_tool_response_to_message(tool_calls=tool_calls) assert message is not None assert message.content == '{"name": "John", "age": 30}' @@ -760,9 +756,7 @@ def test_convert_tool_response_to_message_invalid_json(): ) ] - message = AnthropicChatCompletion._convert_tool_response_to_message( - tool_calls=tool_calls - ) + message = AnthropicConfig._convert_tool_response_to_message(tool_calls=tool_calls) assert message is not None assert message.content == "invalid json" @@ -779,8 +773,6 @@ def test_convert_tool_response_to_message_no_arguments(): ) ] - message = AnthropicChatCompletion._convert_tool_response_to_message( - tool_calls=tool_calls - ) + message = AnthropicConfig._convert_tool_response_to_message(tool_calls=tool_calls) assert message is None From 83a722a34be6e5a4674559b80ac55a19f19695b9 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 20 Nov 2024 12:09:32 -0800 Subject: [PATCH 3/7] add AnthropicConfig --- .../pass_through_endpoints/success_handler.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index 0a7ae541d..15a4b3f82 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -8,6 +8,7 @@ import httpx import litellm from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.anthropic.chat.transformation import AnthropicConfig from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) @@ -23,6 +24,9 @@ class PassThroughEndpointLogging: "predict", ] + # Anthropic + self.TRACKED_ANTHROPIC_ROUTES = ["/messages"] + async def pass_through_async_success_handler( self, httpx_response: httpx.Response, @@ -45,6 +49,17 @@ class PassThroughEndpointLogging: cache_hit=cache_hit, **kwargs, ) + elif self.is_anthropic_route(url_route): + await self.anthropic_passthrough_handler( + httpx_response=httpx_response, + logging_obj=logging_obj, + url_route=url_route, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) else: standard_logging_response_object = StandardPassThroughResponseObject( response=httpx_response.text @@ -76,6 +91,12 @@ class PassThroughEndpointLogging: return True return False + def is_anthropic_route(self, url_route: str): + for route in self.TRACKED_ANTHROPIC_ROUTES: + if route in url_route: + return True + return False + def extract_model_from_url(self, url: str) -> str: pattern = r"/models/([^:]+)" match = re.search(pattern, url) @@ -83,6 +104,22 @@ class PassThroughEndpointLogging: return match.group(1) return "unknown" + async def anthropic_passthrough_handler( + self, + httpx_response: httpx.Response, + logging_obj: LiteLLMLoggingObj, + url_route: str, + result: str, + start_time: datetime, + end_time: datetime, + cache_hit: bool, + **kwargs, + ): + """ + Transforms Anthropic response 
to OpenAI response, generates a standard logging object so downstream logging can be handled + """ + pass + async def vertex_passthrough_handler( self, httpx_response: httpx.Response, From bb7fe53bc5981d08c4ddaa2b11e69712102dd8ca Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 20 Nov 2024 13:18:35 -0800 Subject: [PATCH 4/7] fix anthropic_passthrough_handler --- .../pass_through_endpoints/success_handler.py | 58 ++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index 15a4b3f82..05ba53fa0 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -2,12 +2,16 @@ import json import re import threading from datetime import datetime -from typing import Union +from typing import Optional, Union import httpx import litellm +from litellm._logging import verbose_proxy_logger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.litellm_logging import ( + get_standard_logging_object_payload, +) from litellm.llms.anthropic.chat.transformation import AnthropicConfig from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, @@ -30,6 +34,7 @@ class PassThroughEndpointLogging: async def pass_through_async_success_handler( self, httpx_response: httpx.Response, + response_body: Optional[dict], logging_obj: LiteLLMLoggingObj, url_route: str, result: str, @@ -52,6 +57,7 @@ class PassThroughEndpointLogging: elif self.is_anthropic_route(url_route): await self.anthropic_passthrough_handler( httpx_response=httpx_response, + response_body=response_body or {}, logging_obj=logging_obj, url_route=url_route, result=result, @@ -107,6 +113,7 @@ class PassThroughEndpointLogging: async def anthropic_passthrough_handler( self, httpx_response: httpx.Response, + response_body: dict, logging_obj: LiteLLMLoggingObj, url_route: str, result: str, @@ -118,6 +125,55 @@ class PassThroughEndpointLogging: """ Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled """ + model = response_body.get("model", "") + litellm_model_response: litellm.ModelResponse = ( + AnthropicConfig._process_response( + response=httpx_response, + model_response=litellm.ModelResponse(), + model=model, + stream=False, + messages=[], + logging_obj=logging_obj, + optional_params={}, + api_key="", + data={}, + print_verbose=litellm.print_verbose, + encoding=None, + json_mode=False, + ) + ) + + response_cost = litellm.completion_cost( + completion_response=litellm_model_response, + model=model, + ) + kwargs["response_cost"] = response_cost + kwargs["model"] = model + + # Make standard logging object for Vertex AI + standard_logging_object = get_standard_logging_object_payload( + kwargs=kwargs, + init_response_obj=litellm_model_response, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + status="success", + ) + + # pretty print standard logging object + verbose_proxy_logger.debug( + "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4) + ) + kwargs["standard_logging_object"] = standard_logging_object + + await logging_obj.async_success_handler( + result=litellm_model_response, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) + pass async def vertex_passthrough_handler( From 
9f916636e156ab31e0707ce95285b0ca7f0dc4a3 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 20 Nov 2024 13:18:51 -0800 Subject: [PATCH 5/7] fix get_response_body --- .../pass_through_endpoints/pass_through_endpoints.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 548d07689..7b9e17842 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -45,11 +45,11 @@ router = APIRouter() pass_through_endpoint_logging = PassThroughEndpointLogging() -def get_response_body(response: httpx.Response): +def get_response_body(response: httpx.Response) -> Optional[dict]: try: return response.json() except Exception: - return response.text + return None async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]: @@ -519,10 +519,12 @@ async def pass_through_request( # noqa: PLR0915 content = await response.aread() ## LOG SUCCESS - passthrough_logging_payload["response_body"] = get_response_body(response) + response_body: Optional[dict] = get_response_body(response) + passthrough_logging_payload["response_body"] = response_body end_time = datetime.now() await pass_through_endpoint_logging.pass_through_async_success_handler( httpx_response=response, + response_body=response_body, url_route=str(url), result="", start_time=start_time, From acf350a2fb39ad9593028081d73f2dd54edbaba3 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 20 Nov 2024 15:15:21 -0800 Subject: [PATCH 6/7] fix check for streaming response --- .../pass_through_endpoints.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 7b9e17842..8be241458 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -478,10 +478,9 @@ async def pass_through_request( # noqa: PLR0915 json=_parsed_body, ) - if ( - response.headers.get("content-type") is not None - and response.headers["content-type"] == "text/event-stream" - ): + verbose_proxy_logger.debug("response.headers= %s", response.headers) + + if _is_streaming_response(response) is True: try: response.raise_for_status() except httpx.HTTPStatusError as e: @@ -621,6 +620,13 @@ def create_pass_through_route( return endpoint_func +def _is_streaming_response(response: httpx.Response) -> bool: + _content_type = response.headers.get("content-type") + if _content_type is not None and "text/event-stream" in _content_type: + return True + return False + + async def initialize_pass_through_endpoints(pass_through_endpoints: list): verbose_proxy_logger.debug("initializing pass through endpoints") From 97ecedf997f03fbd6187959679ade4daeac760a9 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 20 Nov 2024 15:49:33 -0800 Subject: [PATCH 7/7] use 1 helper to return stream_response on passthrough --- .../pass_through_endpoints.py | 54 +++++++++++-------- litellm/proxy/pass_through_endpoints/types.py | 1 + 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 8be241458..6c9a93849 100644 --- 
a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -303,9 +303,29 @@ def get_response_headers(headers: httpx.Headers) -> dict: def get_endpoint_type(url: str) -> EndpointType: if ("generateContent") in url or ("streamGenerateContent") in url: return EndpointType.VERTEX_AI + elif ("api.anthropic.com") in url: + return EndpointType.ANTHROPIC return EndpointType.GENERIC +async def stream_response( + response: httpx.Response, + logging_obj: LiteLLMLoggingObj, + endpoint_type: EndpointType, + start_time: datetime, + url: str, +) -> AsyncIterable[bytes]: + async for chunk in chunk_processor( + response.aiter_bytes(), + litellm_logging_obj=logging_obj, + endpoint_type=endpoint_type, + start_time=start_time, + passthrough_success_handler_obj=pass_through_endpoint_logging, + url_route=str(url), + ): + yield chunk + + async def pass_through_request( # noqa: PLR0915 request: Request, target: str, @@ -445,19 +465,14 @@ async def pass_through_request( # noqa: PLR0915 status_code=e.response.status_code, detail=await e.response.aread() ) - async def stream_response() -> AsyncIterable[bytes]: - async for chunk in chunk_processor( - response.aiter_bytes(), - litellm_logging_obj=logging_obj, + return StreamingResponse( + stream_response( + response=response, + logging_obj=logging_obj, endpoint_type=endpoint_type, start_time=start_time, - passthrough_success_handler_obj=pass_through_endpoint_logging, - url_route=str(url), - ): - yield chunk - - return StreamingResponse( - stream_response(), + url=str(url), + ), headers=get_response_headers(response.headers), status_code=response.status_code, ) @@ -488,19 +503,14 @@ async def pass_through_request( # noqa: PLR0915 status_code=e.response.status_code, detail=await e.response.aread() ) - async def stream_response() -> AsyncIterable[bytes]: - async for chunk in chunk_processor( - response.aiter_bytes(), - litellm_logging_obj=logging_obj, + return StreamingResponse( + stream_response( + response=response, + logging_obj=logging_obj, endpoint_type=endpoint_type, start_time=start_time, - passthrough_success_handler_obj=pass_through_endpoint_logging, - url_route=str(url), - ): - yield chunk - - return StreamingResponse( - stream_response(), + url=str(url), + ), headers=get_response_headers(response.headers), status_code=response.status_code, ) diff --git a/litellm/proxy/pass_through_endpoints/types.py b/litellm/proxy/pass_through_endpoints/types.py index b3aa4418d..59047a630 100644 --- a/litellm/proxy/pass_through_endpoints/types.py +++ b/litellm/proxy/pass_through_endpoints/types.py @@ -4,6 +4,7 @@ from typing import Optional, TypedDict class EndpointType(str, Enum): VERTEX_AI = "vertex-ai" + ANTHROPIC = "anthropic" GENERIC = "generic"
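Editor's note: the following is a minimal, standalone sketch (not part of the patch series) that reproduces the small pass-through helpers introduced in patches 5/7–7/7 — endpoint routing, streaming detection via the content-type header, and the Optional[dict] body parse — so they can be exercised in isolation. The EndpointType enum is copied here only to keep the file self-contained; the `__main__` checks and their example URLs/headers are illustrative assumptions, and the only dependency is `httpx`.

# sketch_passthrough_helpers.py -- assumes httpx is installed
from enum import Enum
from typing import Optional

import httpx


class EndpointType(str, Enum):
    VERTEX_AI = "vertex-ai"
    ANTHROPIC = "anthropic"
    GENERIC = "generic"


def get_endpoint_type(url: str) -> EndpointType:
    # Vertex AI passthrough routes are recognized by their generateContent paths,
    # Anthropic by the api.anthropic.com host; everything else is generic.
    if "generateContent" in url or "streamGenerateContent" in url:
        return EndpointType.VERTEX_AI
    elif "api.anthropic.com" in url:
        return EndpointType.ANTHROPIC
    return EndpointType.GENERIC


def _is_streaming_response(response: httpx.Response) -> bool:
    # Patch 6/7: substring match rather than equality, since providers may send
    # e.g. "text/event-stream; charset=utf-8" instead of a bare media type.
    _content_type = response.headers.get("content-type")
    if _content_type is not None and "text/event-stream" in _content_type:
        return True
    return False


def get_response_body(response: httpx.Response) -> Optional[dict]:
    # Patch 5/7: return None (not the raw text) when the body is not JSON,
    # so callers can tell "no parsed body" apart from a dict payload.
    try:
        return response.json()
    except Exception:
        return None


if __name__ == "__main__":
    # Hypothetical responses constructed in-memory, just to exercise the helpers.
    sse = httpx.Response(
        200, headers={"content-type": "text/event-stream; charset=utf-8"}
    )
    json_resp = httpx.Response(200, json={"model": "claude-3-5-sonnet-20241022"})

    assert get_endpoint_type("https://api.anthropic.com/v1/messages") == EndpointType.ANTHROPIC
    assert _is_streaming_response(sse) is True
    assert _is_streaming_response(json_resp) is False
    assert get_response_body(json_resp) == {"model": "claude-3-5-sonnet-20241022"}
    print("passthrough helper sketch OK")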