Merge pull request #3640 from BerriAI/litellm_fix_client_side_disconnecting_reqs

[Feat] Proxy - cancel tasks when FastAPI request is cancelled

Commit aaea02dee8: 1 changed file with 50 additions and 8 deletions
@@ -351,6 +351,32 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
     return pydantic_obj.dict()
 
 
+async def check_request_disconnection(request: Request, llm_api_call_task):
+    """
+    Asynchronously checks whether the request is disconnected at regular intervals.
+    If the request is disconnected:
+    - cancel the litellm.router task
+    - raise an HTTPException with status code 499 and detail "Client disconnected the request"
+
+    Parameters:
+    - request: Request: The request object to check for disconnection.
+
+    Returns:
+    - None
+    """
+    while True:
+        await asyncio.sleep(1)
+        if await request.is_disconnected():
+
+            # cancel the LLM API call task, if one was passed in - this is passed
+            # from individual providers, e.g. OpenAI, Azure, VertexAI
+            llm_api_call_task.cancel()
+
+            raise HTTPException(
+                status_code=499,
+                detail="Client disconnected the request",
+            )
+
+
 async def user_api_key_auth(
     request: Request, api_key: str = fastapi.Security(api_key_header)
 ) -> UserAPIKeyAuth:
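Status code 499 is the Nginx convention for "client closed request". To see the watchdog pattern in isolation, here is a minimal sketch; `FakeRequest`, `slow_llm_call`, and `watchdog` are hypothetical stand-ins for `fastapi.Request`, the provider call, and `check_request_disconnection`, not code from this diff:

```python
import asyncio
import time

class FakeRequest:
    """Hypothetical stand-in for fastapi.Request that reports a
    disconnect after a fixed delay, so the pattern runs locally."""

    def __init__(self, disconnect_after: float):
        self._deadline = time.monotonic() + disconnect_after

    async def is_disconnected(self) -> bool:
        return time.monotonic() >= self._deadline

async def slow_llm_call() -> str:
    await asyncio.sleep(10)  # stands in for the provider round trip
    return "completion"

async def watchdog(request: FakeRequest, task: asyncio.Task) -> None:
    # Same polling loop as check_request_disconnection, minus the HTTPException.
    while True:
        await asyncio.sleep(1)
        if await request.is_disconnected():
            task.cancel()
            return

async def main() -> None:
    llm_task = asyncio.create_task(slow_llm_call())
    watch = asyncio.create_task(watchdog(FakeRequest(disconnect_after=2), llm_task))
    try:
        print(await llm_task)
    except asyncio.CancelledError:
        print("LLM task cancelled after client disconnect")
    finally:
        watch.cancel()  # mirrors the `finally:` blocks added further down

asyncio.run(main())
```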
@@ -3768,9 +3794,15 @@ async def chat_completion(
         )
 
         # wait for call to end
-        responses = await asyncio.gather(
+        llm_responses = asyncio.gather(
             *tasks
         )  # run the moderation check in parallel to the actual llm api call
+
+        check_request_disconnected = asyncio.create_task(
+            check_request_disconnection(request, llm_responses)
+        )
+        responses = await llm_responses
+
         response = responses[1]
 
         hidden_params = getattr(response, "_hidden_params", {}) or {}
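The key change above: `asyncio.gather(...)` is no longer awaited on the spot. `gather` returns a future as soon as it is called, so the future can be handed to the watchdog before being awaited, and cancelling it cancels the moderation and LLM child tasks along with it. A minimal sketch of that behavior (names hypothetical):

```python
import asyncio

async def child(name: str, delay: float) -> str:
    await asyncio.sleep(delay)
    return name

async def main() -> None:
    # gather() schedules both coroutines and returns a future immediately;
    # deferring the await is what lets a watchdog cancel the whole batch.
    batch = asyncio.gather(child("moderation", 5), child("llm_call", 5))
    await asyncio.sleep(0.1)
    batch.cancel()  # cancelling the gather future cancels both children
    try:
        await batch
    except asyncio.CancelledError:
        print("batch cancelled, children included")

asyncio.run(main())
```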
@@ -3845,6 +3877,8 @@ async def chat_completion(
                 param=getattr(e, "param", "None"),
                 code=getattr(e, "status_code", 500),
             )
+    finally:
+        check_request_disconnected.cancel()
 
 
 @router.post(
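The `finally: check_request_disconnected.cancel()` addition matters because the watchdog loops forever; without it, every request that finishes normally would leave a task polling `is_disconnected()` in the background. A minimal sketch of the try/finally shape (names hypothetical):

```python
import asyncio

async def forever_watchdog() -> None:
    # Loops forever, like check_request_disconnection while the client stays connected.
    while True:
        await asyncio.sleep(1)

async def handler() -> str:
    watch = asyncio.create_task(forever_watchdog())
    try:
        return "response"
    finally:
        # Without this, every completed request would leak a task
        # that keeps polling in the background.
        watch.cancel()

print(asyncio.run(handler()))
```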
@@ -3933,31 +3967,31 @@ async def completion(
         router_model_names = llm_router.model_names if llm_router is not None else []
         # skip router if user passed their key
         if "api_key" in data:
-            response = await litellm.atext_completion(**data)
+            llm_response = asyncio.create_task(litellm.atext_completion(**data))
         elif (
             llm_router is not None and data["model"] in router_model_names
         ):  # model in router model list
-            response = await llm_router.atext_completion(**data)
+            llm_response = asyncio.create_task(llm_router.atext_completion(**data))
         elif (
             llm_router is not None
             and llm_router.model_group_alias is not None
             and data["model"] in llm_router.model_group_alias
         ):  # model set in model_group_alias
-            response = await llm_router.atext_completion(**data)
+            llm_response = asyncio.create_task(llm_router.atext_completion(**data))
         elif (
             llm_router is not None and data["model"] in llm_router.deployment_names
         ):  # model in router deployments, calling a specific deployment on the router
-            response = await llm_router.atext_completion(
-                **data, specific_deployment=True
+            llm_response = asyncio.create_task(
+                llm_router.atext_completion(**data, specific_deployment=True)
             )
         elif (
             llm_router is not None
             and data["model"] not in router_model_names
             and llm_router.default_deployment is not None
         ):  # model in router deployments, calling a specific deployment on the router
-            response = await llm_router.atext_completion(**data)
+            llm_response = asyncio.create_task(llm_router.atext_completion(**data))
         elif user_model is not None:  # `litellm --model <your-model-name>`
-            response = await litellm.atext_completion(**data)
+            llm_response = asyncio.create_task(litellm.atext_completion(**data))
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
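Each routing branch in `completion` now wraps its coroutine in `asyncio.create_task`, which starts the call immediately but defers the await, so all branches converge on a single await point after the watchdog is attached. A minimal sketch of that dispatch-then-await shape (provider names hypothetical):

```python
import asyncio

async def provider_a(prompt: str) -> str:
    await asyncio.sleep(0.1)
    return f"A:{prompt}"

async def provider_b(prompt: str) -> str:
    await asyncio.sleep(0.1)
    return f"B:{prompt}"

async def completion(prompt: str, use_b: bool) -> str:
    # Each branch wraps its coroutine in create_task, which schedules
    # the call immediately; no branch awaits on its own.
    if use_b:
        llm_response = asyncio.create_task(provider_b(prompt))
    else:
        llm_response = asyncio.create_task(provider_a(prompt))
    # ... the watchdog would be attached to llm_response here ...
    return await llm_response  # single await point shared by every branch

print(asyncio.run(completion("hi", use_b=False)))
```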
@@ -3966,6 +4000,12 @@ async def completion(
                     + data.get("model", "")
                 },
             )
+        check_request_disconnected = asyncio.create_task(
+            check_request_disconnection(request, llm_response)
+        )
+
+        # Await the llm_response task
+        response = await llm_response
 
         hidden_params = getattr(response, "_hidden_params", {}) or {}
         model_id = hidden_params.get("model_id", None) or ""
@@ -4016,6 +4056,8 @@ async def completion(
                 param=getattr(e, "param", "None"),
                 code=getattr(e, "status_code", 500),
             )
+    finally:
+        check_request_disconnected.cancel()
 
 
 @router.post(
|
Loading…
Add table
Add a link
Reference in a new issue