feat(pass_through_endpoints.py): support streaming requests

This commit is contained in:
Krrish Dholakia 2024-08-17 12:46:57 -07:00
parent bc0023a409
commit fd44cf8d26
2 changed files with 70 additions and 12 deletions

View file

@ -3,7 +3,7 @@ import asyncio
import json import json
import traceback import traceback
from base64 import b64encode from base64 import b64encode
from typing import List, Optional from typing import AsyncIterable, List, Optional
import httpx import httpx
from fastapi import ( from fastapi import (
@ -267,6 +267,17 @@ def forward_headers_from_request(
return headers return headers
def get_response_headers(headers: httpx.Headers) -> dict:
    """Build the header dict to send back to the client from upstream headers.

    ``transfer-encoding`` and ``content-encoding`` are always removed.
    ``content-length`` is additionally removed when the upstream response
    declared a ``content-encoding``: httpx hands back the *decoded* body, so
    the upstream length (which describes the encoded payload) would no longer
    match the bytes actually sent and the server would reject the response.
    Without the header the outgoing response layer is free to compute the
    length itself or fall back to chunked transfer.

    Args:
        headers: Upstream response headers (any mapping exposing ``items()``).

    Returns:
        A plain ``dict`` of the headers that are safe to forward.
    """
    excluded_headers = {"transfer-encoding", "content-encoding"}
    header_items = list(headers.items())

    # A stale content-length is only a problem when the body was re-encoded,
    # so leave it untouched for responses that carried no content-encoding.
    if any(key.lower() == "content-encoding" for key, _ in header_items):
        excluded_headers.add("content-length")

    return {
        key: value
        for key, value in header_items
        if key.lower() not in excluded_headers
    }
async def pass_through_request( async def pass_through_request(
request: Request, request: Request,
target: str, target: str,
@ -274,6 +285,7 @@ async def pass_through_request(
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
forward_headers: Optional[bool] = False, forward_headers: Optional[bool] = False,
query_params: Optional[dict] = None, query_params: Optional[dict] = None,
stream: Optional[bool] = None,
): ):
try: try:
import time import time
@ -292,7 +304,7 @@ async def pass_through_request(
body_str = request_body.decode() body_str = request_body.decode()
try: try:
_parsed_body = ast.literal_eval(body_str) _parsed_body = ast.literal_eval(body_str)
except: except Exception:
_parsed_body = json.loads(body_str) _parsed_body = json.loads(body_str)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
@ -364,6 +376,27 @@ async def pass_through_request(
}, },
) )
if stream:
req = async_client.build_request(
"POST",
url,
json=_parsed_body,
params=requested_query_params,
headers=headers,
)
response = await async_client.send(req, stream=stream)
# Create an async generator to yield the response content
async def stream_response() -> AsyncIterable[bytes]:
    """Yield the upstream body chunk by chunk, then release the connection.

    The enclosing code obtains ``response`` with ``send(..., stream=...)``,
    so nothing else closes it. The ``finally`` clause closes it even when
    the consumer stops iterating early (for example on client disconnect).
    """
    try:
        async for chunk in response.aiter_bytes():
            yield chunk
    finally:
        # Safe to call after normal exhaustion as well; closing twice is a no-op.
        await response.aclose()
return StreamingResponse(
stream_response(),
headers=get_response_headers(response.headers),
status_code=response.status_code,
)
response = await async_client.request( response = await async_client.request(
method=request.method, method=request.method,
url=url, url=url,
@ -372,6 +405,22 @@ async def pass_through_request(
json=_parsed_body, json=_parsed_body,
) )
if (
response.headers.get("content-type") is not None
and response.headers["content-type"] == "text/event-stream"
):
# streaming response
# Create an async generator to yield the response content
async def stream_response() -> AsyncIterable[bytes]:
    """Relay the bytes of the enclosing scope's ``response`` one chunk at a time."""
    byte_chunks = response.aiter_bytes()
    async for piece in byte_chunks:
        yield piece
return StreamingResponse(
stream_response(),
headers=get_response_headers(response.headers),
status_code=response.status_code,
)
if response.status_code >= 300: if response.status_code >= 300:
raise HTTPException(status_code=response.status_code, detail=response.text) raise HTTPException(status_code=response.status_code, detail=response.text)
@ -387,17 +436,10 @@ async def pass_through_request(
cache_hit=False, cache_hit=False,
) )
excluded_headers = {"transfer-encoding", "content-encoding"}
headers = {
key: value
for key, value in response.headers.items()
if key.lower() not in excluded_headers
}
return Response( return Response(
content=content, content=content,
status_code=response.status_code, status_code=response.status_code,
headers=headers, headers=get_response_headers(response.headers),
) )
except Exception as e: except Exception as e:
verbose_proxy_logger.exception( verbose_proxy_logger.exception(
@ -462,6 +504,9 @@ def create_pass_through_route(
fastapi_response: Response, fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
query_params: Optional[dict] = None, query_params: Optional[dict] = None,
stream: Optional[
bool
] = None, # if pass-through endpoint is a streaming request
): ):
return await pass_through_request( # type: ignore return await pass_through_request( # type: ignore
request=request, request=request,
@ -470,6 +515,7 @@ def create_pass_through_route(
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
forward_headers=_forward_headers, forward_headers=_forward_headers,
query_params=query_params, query_params=query_params,
stream=stream,
) )
return endpoint_func return endpoint_func

View file

@ -75,10 +75,22 @@ async def gemini_proxy_route(
merged_params = dict(request.query_params) merged_params = dict(request.query_params)
merged_params.update({"key": gemini_api_key}) merged_params.update({"key": gemini_api_key})
## check for streaming
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route( endpoint_func = create_pass_through_route(
endpoint=endpoint, endpoint=endpoint,
target=str(updated_url), target=str(updated_url),
) # dynamically construct pass-through endpoint based on incoming path ) # dynamically construct pass-through endpoint based on incoming path
return await endpoint_func( received_value = await endpoint_func(
request, fastapi_response, user_api_key_dict, query_params=merged_params request,
fastapi_response,
user_api_key_dict,
query_params=merged_params,
stream=is_streaming_request,
) )
return received_value