diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 5ee6cc6d6b..511ce1ee9f 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -3,7 +3,7 @@ import asyncio import json import traceback from base64 import b64encode -from typing import List, Optional +from typing import AsyncIterable, List, Optional import httpx from fastapi import ( @@ -267,6 +267,17 @@ def forward_headers_from_request( return headers +def get_response_headers(headers: httpx.Headers) -> dict: + excluded_headers = {"transfer-encoding", "content-encoding"} + return_headers = { + key: value + for key, value in headers.items() + if key.lower() not in excluded_headers + } + + return return_headers + + async def pass_through_request( request: Request, target: str, @@ -274,6 +285,7 @@ async def pass_through_request( user_api_key_dict: UserAPIKeyAuth, forward_headers: Optional[bool] = False, query_params: Optional[dict] = None, + stream: Optional[bool] = None, ): try: import time @@ -292,7 +304,7 @@ async def pass_through_request( body_str = request_body.decode() try: _parsed_body = ast.literal_eval(body_str) - except: + except Exception: _parsed_body = json.loads(body_str) verbose_proxy_logger.debug( @@ -364,6 +376,27 @@ async def pass_through_request( }, ) + if stream: + req = async_client.build_request( + "POST", + url, + json=_parsed_body, + params=requested_query_params, + headers=headers, + ) + response = await async_client.send(req, stream=stream) + + # Create an async generator to yield the response content + async def stream_response() -> AsyncIterable[bytes]: + async for chunk in response.aiter_bytes(): + yield chunk + + return StreamingResponse( + stream_response(), + headers=get_response_headers(response.headers), + status_code=response.status_code, + ) + response = await async_client.request( method=request.method, url=url, @@ -372,6 +405,22 @@ async def pass_through_request( json=_parsed_body, ) + if ( + response.headers.get("content-type") is not None + and response.headers["content-type"] == "text/event-stream" + ): + # streaming response + # Create an async generator to yield the response content + async def stream_response() -> AsyncIterable[bytes]: + async for chunk in response.aiter_bytes(): + yield chunk + + return StreamingResponse( + stream_response(), + headers=get_response_headers(response.headers), + status_code=response.status_code, + ) + if response.status_code >= 300: raise HTTPException(status_code=response.status_code, detail=response.text) @@ -387,17 +436,10 @@ async def pass_through_request( cache_hit=False, ) - excluded_headers = {"transfer-encoding", "content-encoding"} - headers = { - key: value - for key, value in response.headers.items() - if key.lower() not in excluded_headers - } - return Response( content=content, status_code=response.status_code, - headers=headers, + headers=get_response_headers(response.headers), ) except Exception as e: verbose_proxy_logger.exception( @@ -462,6 +504,9 @@ def create_pass_through_route( fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), query_params: Optional[dict] = None, + stream: Optional[ + bool + ] = None, # if pass-through endpoint is a streaming request ): return await pass_through_request( # type: ignore request=request, @@ -470,6 +515,7 @@ def create_pass_through_route( user_api_key_dict=user_api_key_dict, forward_headers=_forward_headers, query_params=query_params, + stream=stream, ) return endpoint_func diff --git a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py index 3b6105b447..9eb0f9bd1f 100644 --- a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py @@ -75,10 +75,22 @@ async def gemini_proxy_route( merged_params = dict(request.query_params) merged_params.update({"key": gemini_api_key}) + ## check for streaming + is_streaming_request = False + if "stream" in str(updated_url): + is_streaming_request = True + + ## CREATE PASS-THROUGH endpoint_func = create_pass_through_route( endpoint=endpoint, target=str(updated_url), ) # dynamically construct pass-through endpoint based on incoming path - return await endpoint_func( - request, fastapi_response, user_api_key_dict, query_params=merged_params + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + query_params=merged_params, + stream=is_streaming_request, ) + + return received_value