mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(pass_through_endpoints.py): support streaming requests
This commit is contained in:
parent
bc0023a409
commit
fd44cf8d26
2 changed files with 70 additions and 12 deletions
|
@ -3,7 +3,7 @@ import asyncio
|
|||
import json
|
||||
import traceback
|
||||
from base64 import b64encode
|
||||
from typing import List, Optional
|
||||
from typing import AsyncIterable, List, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import (
|
||||
|
@ -267,6 +267,17 @@ def forward_headers_from_request(
|
|||
return headers
|
||||
|
||||
|
||||
def get_response_headers(headers: httpx.Headers) -> dict:
|
||||
excluded_headers = {"transfer-encoding", "content-encoding"}
|
||||
return_headers = {
|
||||
key: value
|
||||
for key, value in headers.items()
|
||||
if key.lower() not in excluded_headers
|
||||
}
|
||||
|
||||
return return_headers
|
||||
|
||||
|
||||
async def pass_through_request(
|
||||
request: Request,
|
||||
target: str,
|
||||
|
@ -274,6 +285,7 @@ async def pass_through_request(
|
|||
user_api_key_dict: UserAPIKeyAuth,
|
||||
forward_headers: Optional[bool] = False,
|
||||
query_params: Optional[dict] = None,
|
||||
stream: Optional[bool] = None,
|
||||
):
|
||||
try:
|
||||
import time
|
||||
|
@ -292,7 +304,7 @@ async def pass_through_request(
|
|||
body_str = request_body.decode()
|
||||
try:
|
||||
_parsed_body = ast.literal_eval(body_str)
|
||||
except:
|
||||
except Exception:
|
||||
_parsed_body = json.loads(body_str)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
|
@ -364,6 +376,27 @@ async def pass_through_request(
|
|||
},
|
||||
)
|
||||
|
||||
if stream:
|
||||
req = async_client.build_request(
|
||||
"POST",
|
||||
url,
|
||||
json=_parsed_body,
|
||||
params=requested_query_params,
|
||||
headers=headers,
|
||||
)
|
||||
response = await async_client.send(req, stream=stream)
|
||||
|
||||
# Create an async generator to yield the response content
|
||||
async def stream_response() -> AsyncIterable[bytes]:
|
||||
async for chunk in response.aiter_bytes():
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(
|
||||
stream_response(),
|
||||
headers=get_response_headers(response.headers),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
|
@ -372,6 +405,22 @@ async def pass_through_request(
|
|||
json=_parsed_body,
|
||||
)
|
||||
|
||||
if (
|
||||
response.headers.get("content-type") is not None
|
||||
and response.headers["content-type"] == "text/event-stream"
|
||||
):
|
||||
# streaming response
|
||||
# Create an async generator to yield the response content
|
||||
async def stream_response() -> AsyncIterable[bytes]:
|
||||
async for chunk in response.aiter_bytes():
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(
|
||||
stream_response(),
|
||||
headers=get_response_headers(response.headers),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
||||
if response.status_code >= 300:
|
||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
|
||||
|
@ -387,17 +436,10 @@ async def pass_through_request(
|
|||
cache_hit=False,
|
||||
)
|
||||
|
||||
excluded_headers = {"transfer-encoding", "content-encoding"}
|
||||
headers = {
|
||||
key: value
|
||||
for key, value in response.headers.items()
|
||||
if key.lower() not in excluded_headers
|
||||
}
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
status_code=response.status_code,
|
||||
headers=headers,
|
||||
headers=get_response_headers(response.headers),
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
|
@ -462,6 +504,9 @@ def create_pass_through_route(
|
|||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
query_params: Optional[dict] = None,
|
||||
stream: Optional[
|
||||
bool
|
||||
] = None, # if pass-through endpoint is a streaming request
|
||||
):
|
||||
return await pass_through_request( # type: ignore
|
||||
request=request,
|
||||
|
@ -470,6 +515,7 @@ def create_pass_through_route(
|
|||
user_api_key_dict=user_api_key_dict,
|
||||
forward_headers=_forward_headers,
|
||||
query_params=query_params,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return endpoint_func
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue