diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
index 548d07689..6c9a93849 100644
--- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -45,11 +45,11 @@ router = APIRouter()
 pass_through_endpoint_logging = PassThroughEndpointLogging()


-def get_response_body(response: httpx.Response):
+def get_response_body(response: httpx.Response) -> Optional[dict]:
     try:
         return response.json()
     except Exception:
-        return response.text
+        return None


 async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]:
@@ -303,9 +303,29 @@ def get_response_headers(headers: httpx.Headers) -> dict:
 def get_endpoint_type(url: str) -> EndpointType:
     if ("generateContent") in url or ("streamGenerateContent") in url:
         return EndpointType.VERTEX_AI
+    elif ("api.anthropic.com") in url:
+        return EndpointType.ANTHROPIC
     return EndpointType.GENERIC


+async def stream_response(
+    response: httpx.Response,
+    logging_obj: LiteLLMLoggingObj,
+    endpoint_type: EndpointType,
+    start_time: datetime,
+    url: str,
+) -> AsyncIterable[bytes]:
+    async for chunk in chunk_processor(
+        response.aiter_bytes(),
+        litellm_logging_obj=logging_obj,
+        endpoint_type=endpoint_type,
+        start_time=start_time,
+        passthrough_success_handler_obj=pass_through_endpoint_logging,
+        url_route=str(url),
+    ):
+        yield chunk
+
+
 async def pass_through_request(  # noqa: PLR0915
     request: Request,
     target: str,
@@ -445,19 +465,14 @@ async def pass_through_request(  # noqa: PLR0915
                     status_code=e.response.status_code, detail=await e.response.aread()
                 )

-            async def stream_response() -> AsyncIterable[bytes]:
-                async for chunk in chunk_processor(
-                    response.aiter_bytes(),
-                    litellm_logging_obj=logging_obj,
+            return StreamingResponse(
+                stream_response(
+                    response=response,
+                    logging_obj=logging_obj,
                     endpoint_type=endpoint_type,
                     start_time=start_time,
-                    passthrough_success_handler_obj=pass_through_endpoint_logging,
-                    url_route=str(url),
-                ):
-                    yield chunk
-
-            return StreamingResponse(
-                stream_response(),
+                    url=str(url),
+                ),
                 headers=get_response_headers(response.headers),
                 status_code=response.status_code,
             )
@@ -478,10 +493,9 @@ async def pass_through_request(  # noqa: PLR0915
             json=_parsed_body,
         )

-        if (
-            response.headers.get("content-type") is not None
-            and response.headers["content-type"] == "text/event-stream"
-        ):
+        verbose_proxy_logger.debug("response.headers= %s", response.headers)
+
+        if _is_streaming_response(response) is True:
             try:
                 response.raise_for_status()
             except httpx.HTTPStatusError as e:
@@ -489,19 +503,14 @@ async def pass_through_request(  # noqa: PLR0915
                     status_code=e.response.status_code, detail=await e.response.aread()
                 )

-            async def stream_response() -> AsyncIterable[bytes]:
-                async for chunk in chunk_processor(
-                    response.aiter_bytes(),
-                    litellm_logging_obj=logging_obj,
+            return StreamingResponse(
+                stream_response(
+                    response=response,
+                    logging_obj=logging_obj,
                     endpoint_type=endpoint_type,
                     start_time=start_time,
-                    passthrough_success_handler_obj=pass_through_endpoint_logging,
-                    url_route=str(url),
-                ):
-                    yield chunk
-
-            return StreamingResponse(
-                stream_response(),
+                    url=str(url),
+                ),
                 headers=get_response_headers(response.headers),
                 status_code=response.status_code,
             )
@@ -519,10 +528,12 @@ async def pass_through_request(  # noqa: PLR0915
         content = await response.aread()

         ## LOG SUCCESS
-        passthrough_logging_payload["response_body"] = get_response_body(response)
+        response_body: Optional[dict] = get_response_body(response)
+        passthrough_logging_payload["response_body"] = response_body
         end_time = datetime.now()
         await pass_through_endpoint_logging.pass_through_async_success_handler(
             httpx_response=response,
+            response_body=response_body,
             url_route=str(url),
             result="",
             start_time=start_time,
@@ -619,6 +630,13 @@ def create_pass_through_route(
     return endpoint_func


+def _is_streaming_response(response: httpx.Response) -> bool:
+    _content_type = response.headers.get("content-type")
+    if _content_type is not None and "text/event-stream" in _content_type:
+        return True
+    return False
+
+
 async def initialize_pass_through_endpoints(pass_through_endpoints: list):
     verbose_proxy_logger.debug("initializing pass through endpoints")

diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py
index 0a7ae541d..05ba53fa0 100644
--- a/litellm/proxy/pass_through_endpoints/success_handler.py
+++ b/litellm/proxy/pass_through_endpoints/success_handler.py
@@ -2,12 +2,17 @@ import json
 import re
 import threading
 from datetime import datetime
-from typing import Union
+from typing import Optional, Union

 import httpx

 import litellm
+from litellm._logging import verbose_proxy_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.litellm_logging import (
+    get_standard_logging_object_payload,
+)
+from litellm.llms.anthropic.chat.transformation import AnthropicConfig
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
@@ -23,9 +28,13 @@ class PassThroughEndpointLogging:
             "predict",
         ]

+        # Anthropic
+        self.TRACKED_ANTHROPIC_ROUTES = ["/messages"]
+
     async def pass_through_async_success_handler(
         self,
         httpx_response: httpx.Response,
+        response_body: Optional[dict],
         logging_obj: LiteLLMLoggingObj,
         url_route: str,
         result: str,
@@ -45,6 +54,18 @@
                 cache_hit=cache_hit,
                 **kwargs,
             )
+        elif self.is_anthropic_route(url_route):
+            await self.anthropic_passthrough_handler(
+                httpx_response=httpx_response,
+                response_body=response_body or {},
+                logging_obj=logging_obj,
+                url_route=url_route,
+                result=result,
+                start_time=start_time,
+                end_time=end_time,
+                cache_hit=cache_hit,
+                **kwargs,
+            )
         else:
             standard_logging_response_object = StandardPassThroughResponseObject(
                 response=httpx_response.text
@@ -76,6 +97,12 @@
                 return True
         return False

+    def is_anthropic_route(self, url_route: str):
+        for route in self.TRACKED_ANTHROPIC_ROUTES:
+            if route in url_route:
+                return True
+        return False
+
     def extract_model_from_url(self, url: str) -> str:
         pattern = r"/models/([^:]+)"
         match = re.search(pattern, url)
@@ -83,6 +110,70 @@
             return match.group(1)
         return "unknown"

+    async def anthropic_passthrough_handler(
+        self,
+        httpx_response: httpx.Response,
+        response_body: dict,
+        logging_obj: LiteLLMLoggingObj,
+        url_route: str,
+        result: str,
+        start_time: datetime,
+        end_time: datetime,
+        cache_hit: bool,
+        **kwargs,
+    ):
+        """
+        Transforms the Anthropic response into an OpenAI-format response and builds a standard logging object so downstream logging can be handled.
+        """
+        model = response_body.get("model", "")
+        litellm_model_response: litellm.ModelResponse = (
+            AnthropicConfig._process_response(
+                response=httpx_response,
+                model_response=litellm.ModelResponse(),
+                model=model,
+                stream=False,
+                messages=[],
+                logging_obj=logging_obj,
+                optional_params={},
+                api_key="",
+                data={},
+                print_verbose=litellm.print_verbose,
+                encoding=None,
+                json_mode=False,
+            )
+        )
+
+        response_cost = litellm.completion_cost(
+            completion_response=litellm_model_response,
+            model=model,
+        )
+        kwargs["response_cost"] = response_cost
+        kwargs["model"] = model
+
+        # Make standard logging object for Anthropic
+        standard_logging_object = get_standard_logging_object_payload(
+            kwargs=kwargs,
+            init_response_obj=litellm_model_response,
+            start_time=start_time,
+            end_time=end_time,
+            logging_obj=logging_obj,
+            status="success",
+        )
+
+        # pretty print standard logging object
+        verbose_proxy_logger.debug(
+            "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
+        )
+        kwargs["standard_logging_object"] = standard_logging_object
+
+        await logging_obj.async_success_handler(
+            result=litellm_model_response,
+            start_time=start_time,
+            end_time=end_time,
+            cache_hit=cache_hit,
+            **kwargs,
+        )
+
     async def vertex_passthrough_handler(
         self,
         httpx_response: httpx.Response,
diff --git a/litellm/proxy/pass_through_endpoints/types.py b/litellm/proxy/pass_through_endpoints/types.py
index b3aa4418d..59047a630 100644
--- a/litellm/proxy/pass_through_endpoints/types.py
+++ b/litellm/proxy/pass_through_endpoints/types.py
@@ -4,6 +4,7 @@ from typing import Optional, TypedDict

 class EndpointType(str, Enum):
     VERTEX_AI = "vertex-ai"
+    ANTHROPIC = "anthropic"
     GENERIC = "generic"
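Note that `_is_streaming_response` switches from strict equality to a substring check, so content types that carry parameters (e.g. `text/event-stream; charset=utf-8`) are still treated as streaming. As a quick sanity check, the two new pure helpers can be exercised directly against this branch; a minimal sketch (assertions only, test scaffolding omitted; imports assume the module paths touched in this PR):

import httpx

from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
    _is_streaming_response,
    get_endpoint_type,
)
from litellm.proxy.pass_through_endpoints.types import EndpointType

# Anthropic URLs are now routed to the new ANTHROPIC endpoint type
assert get_endpoint_type("https://api.anthropic.com/v1/messages") == EndpointType.ANTHROPIC
assert get_endpoint_type("https://example.com/v1/chat") == EndpointType.GENERIC

# the substring match also accepts content types that carry parameters
streaming = httpx.Response(200, headers={"content-type": "text/event-stream; charset=utf-8"})
non_streaming = httpx.Response(200, headers={"content-type": "application/json"})
assert _is_streaming_response(streaming) is True
assert _is_streaming_response(non_streaming) is False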