diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 720033419..ccf8e32e1 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -351,6 +351,40 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
         return pydantic_obj.dict()
 
 
+async def check_request_disconnection(request: Request, llm_api_call_task):
+    """
+    Poll the request once per second and cancel the LLM call if the client is gone.
+
+    If the request is disconnected
+    - cancel the task passed as `llm_api_call_task`
+    - raises an HTTPException with status code 499 and detail "Client disconnected the request".
+
+    Parameters:
+    - request: Request: The request object to check for disconnection.
+    - llm_api_call_task: The in-flight LLM API call (an asyncio Task, or the future
+      returned by asyncio.gather); only its `.cancel()` method is used.
+
+    Returns:
+    - None: the loop only exits by raising or by this coroutine being cancelled.
+
+    NOTE(review): the endpoints in this patch start this coroutine with
+    asyncio.create_task() and never await it, so the HTTPException is stored on
+    that watcher task rather than propagated into the endpoint - confirm intended.
+    """
+    while True:
+        await asyncio.sleep(1)
+        if await request.is_disconnected():
+            # cancel the LLM API Call task if any passed - this is passed from individual providers
+            # Example OpenAI, Azure, VertexAI etc
+            if llm_api_call_task is not None:
+                llm_api_call_task.cancel()
+
+            raise HTTPException(
+                status_code=499,
+                detail="Client disconnected the request",
+            )
+
+
 async def user_api_key_auth(
     request: Request, api_key: str = fastapi.Security(api_key_header)
 ) -> UserAPIKeyAuth:
@@ -3768,9 +3802,19 @@ async def chat_completion(
             )
 
         # wait for call to end
-        responses = await asyncio.gather(
+        llm_responses = asyncio.gather(
             *tasks
         )  # run the moderation check in parallel to the actual llm api call
+
+        check_request_disconnected = asyncio.create_task(
+            check_request_disconnection(request, llm_responses)
+        )
+        try:
+            responses = await llm_responses
+        finally:
+            # Stop the watcher here, where it is guaranteed to exist; an outer
+            # `finally` would hit an unbound name if an earlier step raised.
+            check_request_disconnected.cancel()
+
         response = responses[1]
 
         hidden_params = getattr(response, "_hidden_params", {}) or {}
@@ -3933,31 +3977,31 @@ async def completion(
         router_model_names = llm_router.model_names if llm_router is not None else []
         # skip router if user passed their key
         if "api_key" in data:
-            response = await litellm.atext_completion(**data)
+            llm_response = asyncio.create_task(litellm.atext_completion(**data))
         elif (
             llm_router is not None and data["model"] in router_model_names
         ):  # model in router model list
-            response = await llm_router.atext_completion(**data)
+            llm_response = asyncio.create_task(llm_router.atext_completion(**data))
         elif (
             llm_router is not None
             and llm_router.model_group_alias is not None
             and data["model"] in llm_router.model_group_alias
         ):  # model set in model_group_alias
-            response = await llm_router.atext_completion(**data)
+            llm_response = asyncio.create_task(llm_router.atext_completion(**data))
         elif (
             llm_router is not None and data["model"] in llm_router.deployment_names
         ):  # model in router deployments, calling a specific deployment on the router
-            response = await llm_router.atext_completion(
-                **data, specific_deployment=True
+            llm_response = asyncio.create_task(
+                llm_router.atext_completion(**data, specific_deployment=True)
             )
         elif (
             llm_router is not None
             and data["model"] not in router_model_names
             and llm_router.default_deployment is not None
         ):  # model in router deployments, calling a specific deployment on the router
-            response = await llm_router.atext_completion(**data)
+            llm_response = asyncio.create_task(llm_router.atext_completion(**data))
         elif user_model is not None:  # `litellm --model `
-            response = await litellm.atext_completion(**data)
+            llm_response = asyncio.create_task(litellm.atext_completion(**data))
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
@@ -3966,6 +4010,16 @@ async def completion(
                     + data.get("model", "")
                 },
             )
+        check_request_disconnected = asyncio.create_task(
+            check_request_disconnection(request, llm_response)
+        )
+
+        try:
+            # Await the llm_response task
+            response = await llm_response
+        finally:
+            # Stop the watcher here, where it is guaranteed to exist; an outer
+            # `finally` would hit an unbound name when the 400 branch above raises.
+            check_request_disconnected.cancel()
 
         hidden_params = getattr(response, "_hidden_params", {}) or {}
         model_id = hidden_params.get("model_id", None) or ""