feat(pass_through_endpoints.py): support streaming requests

Krrish Dholakia 2024-08-17 12:46:57 -07:00
parent 29bedae79f
commit b2ffa564d1
2 changed files with 70 additions and 12 deletions


@@ -3,7 +3,7 @@ import asyncio
import json
import traceback
from base64 import b64encode
from typing import List, Optional
from typing import AsyncIterable, List, Optional
import httpx
from fastapi import (
@@ -267,6 +267,17 @@ def forward_headers_from_request(
    return headers


def get_response_headers(headers: httpx.Headers) -> dict:
    excluded_headers = {"transfer-encoding", "content-encoding"}
    return_headers = {
        key: value
        for key, value in headers.items()
        if key.lower() not in excluded_headers
    }

    return return_headers


async def pass_through_request(
    request: Request,
    target: str,
@@ -274,6 +285,7 @@ async def pass_through_request(
    user_api_key_dict: UserAPIKeyAuth,
    forward_headers: Optional[bool] = False,
    query_params: Optional[dict] = None,
    stream: Optional[bool] = None,
):
    try:
        import time
@@ -292,7 +304,7 @@ async def pass_through_request(
        body_str = request_body.decode()
        try:
            _parsed_body = ast.literal_eval(body_str)
        except:
        except Exception:
            _parsed_body = json.loads(body_str)

        verbose_proxy_logger.debug(
@@ -364,6 +376,27 @@ async def pass_through_request(
            },
        )

        if stream:
            req = async_client.build_request(
                "POST",
                url,
                json=_parsed_body,
                params=requested_query_params,
                headers=headers,
            )

            response = await async_client.send(req, stream=stream)

            # Create an async generator to yield the response content
            async def stream_response() -> AsyncIterable[bytes]:
                async for chunk in response.aiter_bytes():
                    yield chunk

            return StreamingResponse(
                stream_response(),
                headers=get_response_headers(response.headers),
                status_code=response.status_code,
            )

        response = await async_client.request(
            method=request.method,
            url=url,
@@ -372,6 +405,22 @@ async def pass_through_request(
            json=_parsed_body,
        )

        if (
            response.headers.get("content-type") is not None
            and response.headers["content-type"] == "text/event-stream"
        ):
            # streaming response
            # Create an async generator to yield the response content
            async def stream_response() -> AsyncIterable[bytes]:
                async for chunk in response.aiter_bytes():
                    yield chunk

            return StreamingResponse(
                stream_response(),
                headers=get_response_headers(response.headers),
                status_code=response.status_code,
            )

        if response.status_code >= 300:
            raise HTTPException(status_code=response.status_code, detail=response.text)
@@ -387,17 +436,10 @@ async def pass_through_request(
            cache_hit=False,
        )

        excluded_headers = {"transfer-encoding", "content-encoding"}
        headers = {
            key: value
            for key, value in response.headers.items()
            if key.lower() not in excluded_headers
        }

        return Response(
            content=content,
            status_code=response.status_code,
            headers=headers,
            headers=get_response_headers(response.headers),
        )
    except Exception as e:
        verbose_proxy_logger.exception(
@@ -462,6 +504,9 @@ def create_pass_through_route(
        fastapi_response: Response,
        user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
        query_params: Optional[dict] = None,
        stream: Optional[
            bool
        ] = None,  # if pass-through endpoint is a streaming request
    ):
        return await pass_through_request(  # type: ignore
            request=request,
@@ -470,6 +515,7 @@ def create_pass_through_route(
            user_api_key_dict=user_api_key_dict,
            forward_headers=_forward_headers,
            query_params=query_params,
            stream=stream,
        )

    return endpoint_func
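
For reference, below is a standalone, minimal sketch of the httpx streaming pattern this commit applies (client.build_request plus client.send(stream=True), relayed through FastAPI's StreamingResponse). It is not part of the commit: the upstream URL, the /relay route, and the helper names are illustrative assumptions, and the header filtering only mirrors the get_response_headers helper added in the diff.

# Minimal illustration only -- not litellm code. TARGET_URL and /relay are
# placeholder assumptions.
from typing import AsyncIterable

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()
TARGET_URL = "https://example.com/upstream"  # hypothetical upstream endpoint


@app.post("/relay")
async def relay(request: Request):
    client = httpx.AsyncClient()
    # Build the upstream request from the incoming body, then send it with
    # stream=True so the response body is not read eagerly.
    req = client.build_request(
        "POST",
        TARGET_URL,
        content=await request.body(),
        headers={"content-type": request.headers.get("content-type", "application/json")},
    )
    upstream = await client.send(req, stream=True)

    async def stream_body() -> AsyncIterable[bytes]:
        # Forward bytes as they arrive, then release the connection.
        async for chunk in upstream.aiter_bytes():
            yield chunk
        await upstream.aclose()
        await client.aclose()

    # Drop hop-by-hop headers that no longer describe the re-chunked body,
    # as get_response_headers does in the diff.
    headers = {
        k: v
        for k, v in upstream.headers.items()
        if k.lower() not in {"transfer-encoding", "content-encoding"}
    }
    return StreamingResponse(
        stream_body(), headers=headers, status_code=upstream.status_code
    )

With the sketch served by uvicorn on its default port, something like `curl -N http://localhost:8000/relay -d '{}'` would print chunks as the upstream emits them instead of waiting for the full body.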