fix(router.py): support retry and fallbacks for atext_completion

Krrish Dholakia 2023-12-30 11:19:13 +05:30
parent 7ecd7b3e8d
commit 38f55249e1
6 changed files with 290 additions and 69 deletions
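
For context, a minimal caller-side sketch (not part of this commit) of what the change enables: sending atext_completion through litellm.Router so it gets the same retry and fallback handling as acompletion. The deployment names, API key, and fallback mapping below are illustrative placeholders.

import asyncio
from litellm import Router

# Two deployments in one router; names and credentials are placeholders.
model_list = [
    {
        "model_name": "gpt-3.5-turbo-instruct",
        "litellm_params": {"model": "text-completion-openai/gpt-3.5-turbo-instruct"},
    },
    {
        "model_name": "azure-instruct",
        "litellm_params": {
            "model": "azure/my-instruct-deployment",  # placeholder deployment
            "api_base": "https://example-resource.openai.azure.com",
            "api_key": "azure-key-placeholder",
        },
    },
]

router = Router(
    model_list=model_list,
    num_retries=2,  # retry transient failures before giving up
    fallbacks=[{"gpt-3.5-turbo-instruct": ["azure-instruct"]}],  # then switch deployments
)

async def main():
    # With this commit, atext_completion goes through the same
    # retry/fallback path in router.py as acompletion.
    response = await router.atext_completion(
        model="gpt-3.5-turbo-instruct",
        prompt="Say this is a test.",
    )
    print(response)

asyncio.run(main())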


@@ -86,6 +86,7 @@ from fastapi import (
     Depends,
     BackgroundTasks,
     Header,
+    Response,
 )
 from fastapi.routing import APIRouter
 from fastapi.security import OAuth2PasswordBearer
@@ -1068,6 +1069,7 @@ def model_list():
 )
 async def completion(
     request: Request,
+    fastapi_response: Response,
     model: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
     background_tasks: BackgroundTasks = BackgroundTasks(),
@@ -1143,17 +1145,23 @@ async def completion(
         else:  # router is not set
             response = await litellm.atext_completion(**data)
+        model_id = response._hidden_params.get("model_id", None) or ""
         print(f"final response: {response}")
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
+            custom_headers = {"x-litellm-model-id": model_id}
             return StreamingResponse(
                 async_data_generator(
-                    user_api_key_dict=user_api_key_dict, response=response
+                    user_api_key_dict=user_api_key_dict,
+                    response=response,
                 ),
                 media_type="text/event-stream",
+                headers=custom_headers,
             )
+        fastapi_response.headers["x-litellm-model-id"] = model_id
         return response
     except Exception as e:
         print(f"EXCEPTION RAISED IN PROXY MAIN.PY")
@@ -1187,6 +1195,7 @@ async def completion(
 )  # azure compatible endpoint
 async def chat_completion(
     request: Request,
+    fastapi_response: Response,
     model: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
     background_tasks: BackgroundTasks = BackgroundTasks(),
@@ -1282,19 +1291,24 @@ async def chat_completion(
         else:  # router is not set
             response = await litellm.acompletion(**data)
         print(f"final response: {response}")
+        model_id = response._hidden_params.get("model_id", None) or ""
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
+            custom_headers = {"x-litellm-model-id": model_id}
             return StreamingResponse(
                 async_data_generator(
-                    user_api_key_dict=user_api_key_dict, response=response
+                    user_api_key_dict=user_api_key_dict,
+                    response=response,
                 ),
                 media_type="text/event-stream",
+                headers=custom_headers,
             )
+        fastapi_response.headers["x-litellm-model-id"] = model_id
         return response
     except Exception as e:
         traceback.print_exc()
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e
         )
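
The header plumbing above uses two stock FastAPI mechanisms: a Response parameter injected into the endpoint (headers set on it are merged into the regular JSON reply) and the headers= argument of StreamingResponse for the streaming path. A minimal standalone sketch of that pattern follows; the route path, model id, stream-flag handling, and generator are placeholders, not the proxy's actual code.

from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse

app = FastAPI()

async def fake_event_stream():
    # stands in for async_data_generator in the proxy
    yield "data: hello\n\n"
    yield "data: [DONE]\n\n"

@app.post("/v1/completions")
async def completion(request: Request, fastapi_response: Response, stream: bool = False):
    model_id = "example-model-id"  # placeholder for response._hidden_params["model_id"]
    if stream:
        # streaming replies attach the header directly on the StreamingResponse
        custom_headers = {"x-litellm-model-id": model_id}
        return StreamingResponse(
            fake_event_stream(),
            media_type="text/event-stream",
            headers=custom_headers,
        )
    # headers set on the injected Response are applied to the JSON reply
    fastapi_response.headers["x-litellm-model-id"] = model_id
    return {"object": "text_completion", "model": model_id}

A client can then read x-litellm-model-id from either kind of response to see which deployment served the request.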