From 97ecedf997f03fbd6187959679ade4daeac760a9 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 20 Nov 2024 15:49:33 -0800 Subject: [PATCH] use 1 helper to return stream_response on passthrough --- .../pass_through_endpoints.py | 54 +++++++++++-------- litellm/proxy/pass_through_endpoints/types.py | 1 + 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 8be241458..6c9a93849 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -303,9 +303,29 @@ def get_response_headers(headers: httpx.Headers) -> dict: def get_endpoint_type(url: str) -> EndpointType: if ("generateContent") in url or ("streamGenerateContent") in url: return EndpointType.VERTEX_AI + elif ("api.anthropic.com") in url: + return EndpointType.ANTHROPIC return EndpointType.GENERIC +async def stream_response( + response: httpx.Response, + logging_obj: LiteLLMLoggingObj, + endpoint_type: EndpointType, + start_time: datetime, + url: str, +) -> AsyncIterable[bytes]: + async for chunk in chunk_processor( + response.aiter_bytes(), + litellm_logging_obj=logging_obj, + endpoint_type=endpoint_type, + start_time=start_time, + passthrough_success_handler_obj=pass_through_endpoint_logging, + url_route=str(url), + ): + yield chunk + + async def pass_through_request( # noqa: PLR0915 request: Request, target: str, @@ -445,19 +465,14 @@ async def pass_through_request( # noqa: PLR0915 status_code=e.response.status_code, detail=await e.response.aread() ) - async def stream_response() -> AsyncIterable[bytes]: - async for chunk in chunk_processor( - response.aiter_bytes(), - litellm_logging_obj=logging_obj, + return StreamingResponse( + stream_response( + response=response, + logging_obj=logging_obj, endpoint_type=endpoint_type, start_time=start_time, - passthrough_success_handler_obj=pass_through_endpoint_logging, - url_route=str(url), - ): - yield chunk - - return StreamingResponse( - stream_response(), + url=str(url), + ), headers=get_response_headers(response.headers), status_code=response.status_code, ) @@ -488,19 +503,14 @@ async def pass_through_request( # noqa: PLR0915 status_code=e.response.status_code, detail=await e.response.aread() ) - async def stream_response() -> AsyncIterable[bytes]: - async for chunk in chunk_processor( - response.aiter_bytes(), - litellm_logging_obj=logging_obj, + return StreamingResponse( + stream_response( + response=response, + logging_obj=logging_obj, endpoint_type=endpoint_type, start_time=start_time, - passthrough_success_handler_obj=pass_through_endpoint_logging, - url_route=str(url), - ): - yield chunk - - return StreamingResponse( - stream_response(), + url=str(url), + ), headers=get_response_headers(response.headers), status_code=response.status_code, ) diff --git a/litellm/proxy/pass_through_endpoints/types.py b/litellm/proxy/pass_through_endpoints/types.py index b3aa4418d..59047a630 100644 --- a/litellm/proxy/pass_through_endpoints/types.py +++ b/litellm/proxy/pass_through_endpoints/types.py @@ -4,6 +4,7 @@ from typing import Optional, TypedDict class EndpointType(str, Enum): VERTEX_AI = "vertex-ai" + ANTHROPIC = "anthropic" GENERIC = "generic"