mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(pass_through_endpoints.py): support streaming requests
This commit is contained in:
parent
bc0023a409
commit
fd44cf8d26
2 changed files with 70 additions and 12 deletions
|
@ -3,7 +3,7 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
from typing import List, Optional
|
from typing import AsyncIterable, List, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
|
@ -267,6 +267,17 @@ def forward_headers_from_request(
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
def get_response_headers(headers: httpx.Headers) -> dict:
|
||||||
|
excluded_headers = {"transfer-encoding", "content-encoding"}
|
||||||
|
return_headers = {
|
||||||
|
key: value
|
||||||
|
for key, value in headers.items()
|
||||||
|
if key.lower() not in excluded_headers
|
||||||
|
}
|
||||||
|
|
||||||
|
return return_headers
|
||||||
|
|
||||||
|
|
||||||
async def pass_through_request(
|
async def pass_through_request(
|
||||||
request: Request,
|
request: Request,
|
||||||
target: str,
|
target: str,
|
||||||
|
@ -274,6 +285,7 @@ async def pass_through_request(
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
forward_headers: Optional[bool] = False,
|
forward_headers: Optional[bool] = False,
|
||||||
query_params: Optional[dict] = None,
|
query_params: Optional[dict] = None,
|
||||||
|
stream: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
import time
|
import time
|
||||||
|
@ -292,7 +304,7 @@ async def pass_through_request(
|
||||||
body_str = request_body.decode()
|
body_str = request_body.decode()
|
||||||
try:
|
try:
|
||||||
_parsed_body = ast.literal_eval(body_str)
|
_parsed_body = ast.literal_eval(body_str)
|
||||||
except:
|
except Exception:
|
||||||
_parsed_body = json.loads(body_str)
|
_parsed_body = json.loads(body_str)
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
|
@ -364,6 +376,27 @@ async def pass_through_request(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
req = async_client.build_request(
|
||||||
|
"POST",
|
||||||
|
url,
|
||||||
|
json=_parsed_body,
|
||||||
|
params=requested_query_params,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
response = await async_client.send(req, stream=stream)
|
||||||
|
|
||||||
|
# Create an async generator to yield the response content
|
||||||
|
async def stream_response() -> AsyncIterable[bytes]:
|
||||||
|
async for chunk in response.aiter_bytes():
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
stream_response(),
|
||||||
|
headers=get_response_headers(response.headers),
|
||||||
|
status_code=response.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
response = await async_client.request(
|
response = await async_client.request(
|
||||||
method=request.method,
|
method=request.method,
|
||||||
url=url,
|
url=url,
|
||||||
|
@ -372,6 +405,22 @@ async def pass_through_request(
|
||||||
json=_parsed_body,
|
json=_parsed_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
response.headers.get("content-type") is not None
|
||||||
|
and response.headers["content-type"] == "text/event-stream"
|
||||||
|
):
|
||||||
|
# streaming response
|
||||||
|
# Create an async generator to yield the response content
|
||||||
|
async def stream_response() -> AsyncIterable[bytes]:
|
||||||
|
async for chunk in response.aiter_bytes():
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
stream_response(),
|
||||||
|
headers=get_response_headers(response.headers),
|
||||||
|
status_code=response.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
if response.status_code >= 300:
|
if response.status_code >= 300:
|
||||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||||
|
|
||||||
|
@ -387,17 +436,10 @@ async def pass_through_request(
|
||||||
cache_hit=False,
|
cache_hit=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
excluded_headers = {"transfer-encoding", "content-encoding"}
|
|
||||||
headers = {
|
|
||||||
key: value
|
|
||||||
for key, value in response.headers.items()
|
|
||||||
if key.lower() not in excluded_headers
|
|
||||||
}
|
|
||||||
|
|
||||||
return Response(
|
return Response(
|
||||||
content=content,
|
content=content,
|
||||||
status_code=response.status_code,
|
status_code=response.status_code,
|
||||||
headers=headers,
|
headers=get_response_headers(response.headers),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.exception(
|
verbose_proxy_logger.exception(
|
||||||
|
@ -462,6 +504,9 @@ def create_pass_through_route(
|
||||||
fastapi_response: Response,
|
fastapi_response: Response,
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
query_params: Optional[dict] = None,
|
query_params: Optional[dict] = None,
|
||||||
|
stream: Optional[
|
||||||
|
bool
|
||||||
|
] = None, # if pass-through endpoint is a streaming request
|
||||||
):
|
):
|
||||||
return await pass_through_request( # type: ignore
|
return await pass_through_request( # type: ignore
|
||||||
request=request,
|
request=request,
|
||||||
|
@ -470,6 +515,7 @@ def create_pass_through_route(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
forward_headers=_forward_headers,
|
forward_headers=_forward_headers,
|
||||||
query_params=query_params,
|
query_params=query_params,
|
||||||
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
return endpoint_func
|
return endpoint_func
|
||||||
|
|
|
@ -75,10 +75,22 @@ async def gemini_proxy_route(
|
||||||
merged_params = dict(request.query_params)
|
merged_params = dict(request.query_params)
|
||||||
merged_params.update({"key": gemini_api_key})
|
merged_params.update({"key": gemini_api_key})
|
||||||
|
|
||||||
|
## check for streaming
|
||||||
|
is_streaming_request = False
|
||||||
|
if "stream" in str(updated_url):
|
||||||
|
is_streaming_request = True
|
||||||
|
|
||||||
|
## CREATE PASS-THROUGH
|
||||||
endpoint_func = create_pass_through_route(
|
endpoint_func = create_pass_through_route(
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
target=str(updated_url),
|
target=str(updated_url),
|
||||||
) # dynamically construct pass-through endpoint based on incoming path
|
) # dynamically construct pass-through endpoint based on incoming path
|
||||||
return await endpoint_func(
|
received_value = await endpoint_func(
|
||||||
request, fastapi_response, user_api_key_dict, query_params=merged_params
|
request,
|
||||||
|
fastapi_response,
|
||||||
|
user_api_key_dict,
|
||||||
|
query_params=merged_params,
|
||||||
|
stream=is_streaming_request,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return received_value
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue