Merge pull request #3640 from BerriAI/litellm_fix_client_side_disconnecting_reqs

[Feat] Proxy - cancel tasks when FastAPI request is cancelled
Ishaan Jaff 2024-05-14 20:14:42 -07:00 committed by GitHub
commit aaea02dee8

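The change adds a watchdog coroutine that polls FastAPI's request.is_disconnected() and cancels the in-flight LLM task when the client goes away. A minimal, self-contained sketch of that pattern, before reading the diff itself - the endpoint, fake_llm_call, and watch_for_disconnect names below are illustrative, not part of this PR:

import asyncio

from fastapi import FastAPI, HTTPException, Request

app = FastAPI()


async def fake_llm_call() -> str:
    # stand-in for a slow provider call (OpenAI, Azure, Vertex AI, ...)
    await asyncio.sleep(30)
    return "response"


async def watch_for_disconnect(request: Request, task: asyncio.Task):
    # poll the client connection once a second; if it drops, cancel the work
    while True:
        await asyncio.sleep(1)
        if await request.is_disconnected():
            task.cancel()
            raise HTTPException(
                status_code=499, detail="Client disconnected the request"
            )


@app.post("/demo")
async def demo(request: Request):
    llm_task = asyncio.create_task(fake_llm_call())
    watchdog = asyncio.create_task(watch_for_disconnect(request, llm_task))
    try:
        return {"result": await llm_task}
    finally:
        # always stop the watchdog, whether the call succeeded or failed
        watchdog.cancel()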

@@ -351,6 +351,32 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
     return pydantic_obj.dict()
 
 
+async def check_request_disconnection(request: Request, llm_api_call_task):
+    """
+    Asynchronously checks, at one-second intervals, whether the client has
+    disconnected the request.
+
+    If the request is disconnected:
+    - cancels llm_api_call_task (the in-flight LLM API call)
+    - raises an HTTPException with status code 499 and detail
+      "Client disconnected the request"
+
+    Parameters:
+    - request: Request: The request object to check for disconnection.
+    - llm_api_call_task: The asyncio task wrapping the LLM API call.
+
+    Returns:
+    - None
+    """
+    while True:
+        await asyncio.sleep(1)
+        if await request.is_disconnected():
+            # cancel the LLM API call task if one was passed - individual
+            # providers (OpenAI, Azure, Vertex AI, ...) pass this in
+            llm_api_call_task.cancel()
+            raise HTTPException(
+                status_code=499,
+                detail="Client disconnected the request",
+            )
+
+
 async def user_api_key_auth(
     request: Request, api_key: str = fastapi.Security(api_key_header)
 ) -> UserAPIKeyAuth:
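A note on the status code: 499 is nginx's non-standard "client closed request" code, conventionally used when the client disconnects before the server finishes responding. By the time this exception is raised the client is usually gone, so it mainly matters for logging and metrics rather than for the response body.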
@@ -3768,9 +3794,15 @@ async def chat_completion(
         )
 
         # wait for call to end
-        responses = await asyncio.gather(
+        llm_responses = asyncio.gather(
             *tasks
         )  # run the moderation check in parallel to the actual llm api call
+
+        check_request_disconnected = asyncio.create_task(
+            check_request_disconnection(request, llm_responses)
+        )
+        responses = await llm_responses
+
         response = responses[1]
 
         hidden_params = getattr(response, "_hidden_params", {}) or {}
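Handing the watchdog llm_responses - the future returned by asyncio.gather() - works because cancelling a gather future cancels all of its still-pending children, here both the moderation check and the actual LLM call. A quick standalone illustration (the child names are hypothetical):

import asyncio


async def child(name: str):
    try:
        await asyncio.sleep(10)
    except asyncio.CancelledError:
        print(f"{name} cancelled")  # both children reach this line
        raise


async def main():
    gathered = asyncio.gather(child("moderation"), child("llm_call"))
    await asyncio.sleep(0)  # let both children start running
    gathered.cancel()       # one cancel fans out to every pending child
    try:
        await gathered
    except asyncio.CancelledError:
        pass


asyncio.run(main())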
@@ -3845,6 +3877,8 @@ async def chat_completion(
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
         )
+    finally:
+        check_request_disconnected.cancel()
 @router.post(
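The finally: block is load-bearing: check_request_disconnection loops forever, so without it every request that completes normally would leave a watchdog task behind, polling is_disconnected() once a second. Cancelling it on the way out - on success and on error alike - keeps the proxy from leaking one task per request.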
@@ -3933,31 +3967,31 @@ async def completion(
         router_model_names = llm_router.model_names if llm_router is not None else []
 
         # skip router if user passed their key
         if "api_key" in data:
-            response = await litellm.atext_completion(**data)
+            llm_response = asyncio.create_task(litellm.atext_completion(**data))
         elif (
             llm_router is not None and data["model"] in router_model_names
         ):  # model in router model list
-            response = await llm_router.atext_completion(**data)
+            llm_response = asyncio.create_task(llm_router.atext_completion(**data))
         elif (
             llm_router is not None
             and llm_router.model_group_alias is not None
             and data["model"] in llm_router.model_group_alias
         ):  # model set in model_group_alias
-            response = await llm_router.atext_completion(**data)
+            llm_response = asyncio.create_task(llm_router.atext_completion(**data))
         elif (
             llm_router is not None and data["model"] in llm_router.deployment_names
         ):  # model in router deployments, calling a specific deployment on the router
-            response = await llm_router.atext_completion(
-                **data, specific_deployment=True
-            )
+            llm_response = asyncio.create_task(
+                llm_router.atext_completion(**data, specific_deployment=True)
+            )
         elif (
             llm_router is not None
             and data["model"] not in router_model_names
             and llm_router.default_deployment is not None
         ):  # model in router deployments, calling a specific deployment on the router
-            response = await llm_router.atext_completion(**data)
+            llm_response = asyncio.create_task(llm_router.atext_completion(**data))
         elif user_model is not None:  # `litellm --model <your-model-name>`
-            response = await litellm.atext_completion(**data)
+            llm_response = asyncio.create_task(litellm.atext_completion(**data))
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
@@ -3966,6 +4000,12 @@ async def completion(
                     + data.get("model", "")
                 },
             )
+
+        check_request_disconnected = asyncio.create_task(
+            check_request_disconnection(request, llm_response)
+        )
+        # Await the llm_response task
+        response = await llm_response
+
         hidden_params = getattr(response, "_hidden_params", {}) or {}
         model_id = hidden_params.get("model_id", None) or ""
@@ -4016,6 +4056,8 @@ async def completion(
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
         )
+    finally:
+        check_request_disconnected.cancel()
 @router.post(
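One way to sanity-check the watchdog in isolation - a hypothetical harness, not part of the PR, with FakeRequest standing in for fastapi.Request and the watchdog body mirroring the function added above:

import asyncio

from fastapi import HTTPException


class FakeRequest:
    """Stand-in for fastapi.Request whose disconnect flag we control."""

    def __init__(self):
        self.disconnected = False

    async def is_disconnected(self) -> bool:
        return self.disconnected


async def check_request_disconnection(request, llm_api_call_task):
    # same shape as the function this PR adds to the proxy
    while True:
        await asyncio.sleep(1)
        if await request.is_disconnected():
            llm_api_call_task.cancel()
            raise HTTPException(
                status_code=499, detail="Client disconnected the request"
            )


async def main():
    request = FakeRequest()
    llm_task = asyncio.create_task(asyncio.sleep(60))  # stand-in LLM call
    watchdog = asyncio.create_task(
        check_request_disconnection(request, llm_task)
    )

    await asyncio.sleep(1.5)
    request.disconnected = True  # simulate the client dropping

    try:
        await watchdog
    except HTTPException as e:
        assert e.status_code == 499

    try:
        await llm_task
    except asyncio.CancelledError:
        print("in-flight call cancelled, as expected")


asyncio.run(main())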