Merge pull request #3640 from BerriAI/litellm_fix_client_side_disconnecting_reqs

[Feat] Proxy - cancel tasks when fast api request is cancelled
Ishaan Jaff 2024-05-14 20:14:42 -07:00 committed by GitHub
commit aaea02dee8

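The diff below wires a small watchdog coroutine next to each LLM API call: the call is wrapped in an `asyncio.Task`, a second task polls `request.is_disconnected()` once per second, and on disconnect the watchdog cancels the in-flight call. A minimal, self-contained sketch of the same pattern (the `/completions` route and `fake_llm_call` here are illustrative stand-ins, not litellm code):

```python
import asyncio
from fastapi import FastAPI, HTTPException, Request

app = FastAPI()

async def check_request_disconnection(request: Request, task: asyncio.Task):
    # poll the client connection; on disconnect, cancel the in-flight call
    while True:
        await asyncio.sleep(1)
        if await request.is_disconnected():
            task.cancel()
            raise HTTPException(
                status_code=499, detail="Client disconnected the request"
            )

async def fake_llm_call():
    await asyncio.sleep(30)  # stand-in for a slow provider call
    return {"choices": [{"text": "hello"}]}

@app.post("/completions")
async def completions(request: Request):
    llm_task = asyncio.create_task(fake_llm_call())
    watchdog = asyncio.create_task(check_request_disconnection(request, llm_task))
    try:
        return await llm_task
    except asyncio.CancelledError:
        # the watchdog cancelled the call because the client went away
        raise HTTPException(status_code=499, detail="Client disconnected the request")
    finally:
        watchdog.cancel()  # mirrors the `finally:` blocks added in this commit
```

Running this under uvicorn and aborting a request early (e.g. `curl -m 2 -X POST localhost:8000/completions`) should cancel `fake_llm_call` at the next poll tick instead of letting it run for the full 30 seconds.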

@@ -351,6 +351,32 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
         return pydantic_obj.dict()
 
 
+async def check_request_disconnection(request: Request, llm_api_call_task):
+    """
+    Asynchronously checks if the request is disconnected at regular intervals.
+    If the request is disconnected:
+    - cancels the llm_api_call_task
+    - raises an HTTPException with status code 499 and detail "Client disconnected the request"
+
+    Parameters:
+    - request: Request: The request object to check for disconnection.
+    - llm_api_call_task: The in-flight LLM API call task to cancel.
+
+    Returns:
+    - None
+    """
+    while True:
+        await asyncio.sleep(1)
+        if await request.is_disconnected():
+            # cancel the LLM API call task if one was passed - this is passed
+            # from individual providers, e.g. OpenAI, Azure, Vertex AI
+            llm_api_call_task.cancel()
+            raise HTTPException(
+                status_code=499,
+                detail="Client disconnected the request",
+            )
+
+
 async def user_api_key_auth(
     request: Request, api_key: str = fastapi.Security(api_key_header)
 ) -> UserAPIKeyAuth:
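For context on what `llm_api_call_task.cancel()` actually does: cancelling a task schedules an `asyncio.CancelledError` at the task's next suspension point, so an in-flight provider HTTP call awaiting network I/O is interrupted there rather than running to completion. A small standalone demo (names are illustrative):

```python
import asyncio

async def slow_provider_call():
    try:
        await asyncio.sleep(60)  # stand-in for an OpenAI/Azure/Vertex AI HTTP call
        return "completion"
    except asyncio.CancelledError:
        # cleanup (closing connections, releasing semaphores) can happen here
        print("provider call interrupted")
        raise  # re-raise so the task is actually marked cancelled

async def main():
    task = asyncio.create_task(slow_provider_call())
    await asyncio.sleep(0.1)  # let the call get in flight
    task.cancel()             # what check_request_disconnection does on disconnect
    try:
        await task
    except asyncio.CancelledError:
        print("task cancelled without waiting the full 60s")

asyncio.run(main())
```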
@@ -3768,9 +3794,15 @@ async def chat_completion(
         )
 
         # wait for call to end
-        responses = await asyncio.gather(
+        llm_responses = asyncio.gather(
             *tasks
         )  # run the moderation check in parallel to the actual llm api call
+
+        check_request_disconnected = asyncio.create_task(
+            check_request_disconnection(request, llm_responses)
+        )
+        responses = await llm_responses
+
         response = responses[1]
 
         hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -3845,6 +3877,8 @@ async def chat_completion(
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
         )
+    finally:
+        check_request_disconnected.cancel()
 
 
 @router.post(
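The key detail in the chat_completion hunk above is that `asyncio.gather()` schedules its coroutines and returns an awaitable future immediately, without suspending, so the watchdog task can be created between building the gather and awaiting it. Cancelling the gather future also cancels its unfinished children. A sketch with illustrative names:

```python
import asyncio

async def moderation_check():
    await asyncio.sleep(10)

async def llm_api_call():
    await asyncio.sleep(10)
    return "response"

async def pretend_disconnect(fut):
    # stand-in for check_request_disconnection noticing a disconnect
    await asyncio.sleep(0.5)
    fut.cancel()

async def main():
    # gather() returns a future right away; the await is deferred, which is
    # why the diff can start the watchdog before `responses = await llm_responses`
    llm_responses = asyncio.gather(moderation_check(), llm_api_call())
    asyncio.create_task(pretend_disconnect(llm_responses))
    try:
        responses = await llm_responses
        print(responses[1])  # the llm call's result, as in `responses[1]` above
    except asyncio.CancelledError:
        print("gather cancelled; both pending children were cancelled too")

asyncio.run(main())
```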
@@ -3933,31 +3967,31 @@ async def completion(
         router_model_names = llm_router.model_names if llm_router is not None else []
         # skip router if user passed their key
         if "api_key" in data:
-            response = await litellm.atext_completion(**data)
+            llm_response = asyncio.create_task(litellm.atext_completion(**data))
         elif (
             llm_router is not None and data["model"] in router_model_names
         ):  # model in router model list
-            response = await llm_router.atext_completion(**data)
+            llm_response = asyncio.create_task(llm_router.atext_completion(**data))
         elif (
             llm_router is not None
             and llm_router.model_group_alias is not None
             and data["model"] in llm_router.model_group_alias
         ):  # model set in model_group_alias
-            response = await llm_router.atext_completion(**data)
+            llm_response = asyncio.create_task(llm_router.atext_completion(**data))
         elif (
             llm_router is not None and data["model"] in llm_router.deployment_names
         ):  # model in router deployments, calling a specific deployment on the router
-            response = await llm_router.atext_completion(
-                **data, specific_deployment=True
-            )
+            llm_response = asyncio.create_task(
+                llm_router.atext_completion(**data, specific_deployment=True)
+            )
         elif (
             llm_router is not None
             and data["model"] not in router_model_names
             and llm_router.default_deployment is not None
         ):  # model in router deployments, calling a specific deployment on the router
-            response = await llm_router.atext_completion(**data)
+            llm_response = asyncio.create_task(llm_router.atext_completion(**data))
         elif user_model is not None:  # `litellm --model <your-model-name>`
-            response = await litellm.atext_completion(**data)
+            llm_response = asyncio.create_task(litellm.atext_completion(**data))
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
@@ -3966,6 +4000,12 @@ async def completion(
                     + data.get("model", "")
                 },
             )
+        check_request_disconnected = asyncio.create_task(
+            check_request_disconnection(request, llm_response)
+        )
+
+        # Await the llm_response task
+        response = await llm_response
         hidden_params = getattr(response, "_hidden_params", {}) or {}
 
         model_id = hidden_params.get("model_id", None) or ""
@@ -4016,6 +4056,8 @@ async def completion(
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
         )
+    finally:
+        check_request_disconnected.cancel()
 
 
 @router.post(
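The completion hunks all make the same move: each `response = await ...atext_completion(...)` becomes `llm_response = asyncio.create_task(...)`, because a bare `await coro` leaves nothing for the watchdog to cancel, while `create_task` returns a handle before the handler suspends. A contrast sketch (illustrative names, not litellm code):

```python
import asyncio

async def provider_call():
    await asyncio.sleep(30)
    return "text"

async def cancel_after(task: asyncio.Task, delay: float):
    await asyncio.sleep(delay)
    task.cancel()

async def handler_old():
    # before: the coroutine is awaited directly, so no other task
    # holds a handle it could use to cancel the in-flight call
    return await provider_call()

async def handler_new():
    # after: create_task returns a cancellable handle immediately; a watchdog
    # can cancel it while the handler is suspended on `await llm_response`
    llm_response = asyncio.create_task(provider_call())
    watchdog = asyncio.create_task(cancel_after(llm_response, 0.1))
    try:
        return await llm_response
    except asyncio.CancelledError:
        return "cancelled early"
    finally:
        # like the new `finally:` blocks, so the poller never outlives the request
        watchdog.cancel()

print(asyncio.run(handler_new()))  # -> "cancelled early"
```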