use 1 helper to return stream_response on passthrough

This commit is contained in:
Ishaan Jaff 2024-11-20 15:49:33 -08:00
parent acf350a2fb
commit 97ecedf997
2 changed files with 33 additions and 22 deletions

View file

@ -303,9 +303,29 @@ def get_response_headers(headers: httpx.Headers) -> dict:
def get_endpoint_type(url: str) -> EndpointType:
if ("generateContent") in url or ("streamGenerateContent") in url:
return EndpointType.VERTEX_AI
elif ("api.anthropic.com") in url:
return EndpointType.ANTHROPIC
return EndpointType.GENERIC
async def stream_response(
response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
endpoint_type: EndpointType,
start_time: datetime,
url: str,
) -> AsyncIterable[bytes]:
async for chunk in chunk_processor(
response.aiter_bytes(),
litellm_logging_obj=logging_obj,
endpoint_type=endpoint_type,
start_time=start_time,
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
):
yield chunk
async def pass_through_request( # noqa: PLR0915
request: Request,
target: str,
@ -445,19 +465,14 @@ async def pass_through_request( # noqa: PLR0915
status_code=e.response.status_code, detail=await e.response.aread()
)
async def stream_response() -> AsyncIterable[bytes]:
async for chunk in chunk_processor(
response.aiter_bytes(),
litellm_logging_obj=logging_obj,
return StreamingResponse(
stream_response(
response=response,
logging_obj=logging_obj,
endpoint_type=endpoint_type,
start_time=start_time,
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
):
yield chunk
return StreamingResponse(
stream_response(),
url=str(url),
),
headers=get_response_headers(response.headers),
status_code=response.status_code,
)
@ -488,19 +503,14 @@ async def pass_through_request( # noqa: PLR0915
status_code=e.response.status_code, detail=await e.response.aread()
)
async def stream_response() -> AsyncIterable[bytes]:
async for chunk in chunk_processor(
response.aiter_bytes(),
litellm_logging_obj=logging_obj,
return StreamingResponse(
stream_response(
response=response,
logging_obj=logging_obj,
endpoint_type=endpoint_type,
start_time=start_time,
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
):
yield chunk
return StreamingResponse(
stream_response(),
url=str(url),
),
headers=get_response_headers(response.headers),
status_code=response.status_code,
)

View file

@ -4,6 +4,7 @@ from typing import Optional, TypedDict
class EndpointType(str, Enum):
VERTEX_AI = "vertex-ai"
ANTHROPIC = "anthropic"
GENERIC = "generic"