From cd98d256b5569b56477864f2704a63b805ca5248 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 3 Jan 2024 11:18:21 +0530
Subject: [PATCH] fix(proxy_server.py): add alerting for responses taking too long

https://github.com/BerriAI/litellm/issues/1298
---
 litellm/proxy/proxy_server.py | 55 ++++++++++++++++++++++++++++++++++-
 litellm/proxy/utils.py        | 50 +++++++++++++++++++++++++------
 2 files changed, 95 insertions(+), 10 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 98de2cd1a0..d3862973ef 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -648,7 +648,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
         load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
         ### ALERTING ###
-        proxy_logging_obj.update_values(alerting=general_settings.get("alerting", None))
+        proxy_logging_obj.update_values(
+            alerting=general_settings.get("alerting", None),
+            alerting_threshold=general_settings.get("alerting_threshold", 600),
+        )
         ### CONNECT TO DATABASE ###
         database_url = general_settings.get("database_url", None)
         if database_url and database_url.startswith("os.environ/"):
@@ -927,6 +930,7 @@ def data_generator(response):
 async def async_data_generator(response, user_api_key_dict):
     print_verbose("inside generator")
     try:
+        start_time = time.time()
         async for chunk in response:
             print_verbose(f"returned chunk: {chunk}")
             try:
@@ -934,6 +938,14 @@ async def async_data_generator(response, user_api_key_dict):
             except Exception as e:
                 yield f"data: {str(e)}\n\n"

+        ### ALERTING ###
+        end_time = time.time()
+        asyncio.create_task(
+            proxy_logging_obj.response_taking_too_long(
+                start_time=start_time, end_time=end_time, type="slow_response"
+            )
+        )
+
         # Streaming is done, yield the [DONE] chunk
         done_message = "[DONE]"
         yield f"data: {done_message}\n\n"
@@ -1103,6 +1115,8 @@ async def completion(
             user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
         )

+        start_time = time.time()
+
         ### ROUTE THE REQUESTs ###
         router_model_names = (
             [m["model_name"] for m in llm_model_list]
@@ -1150,6 +1164,14 @@ async def completion(
                 headers=custom_headers,
             )

+        ### ALERTING ###
+        end_time = time.time()
+        asyncio.create_task(
+            proxy_logging_obj.response_taking_too_long(
+                start_time=start_time, end_time=end_time, type="slow_response"
+            )
+        )
+
         fastapi_response.headers["x-litellm-model-id"] = model_id
         return response
     except Exception as e:
@@ -1254,6 +1276,8 @@ async def chat_completion(
             user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
         )

+        start_time = time.time()
+
         ### ROUTE THE REQUEST ###
         router_model_names = (
             [m["model_name"] for m in llm_model_list]
@@ -1289,6 +1313,7 @@ async def chat_completion(
             model_id = response._hidden_params.get("model_id", None) or ""
         else:
             model_id = ""
+
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
@@ -1302,6 +1327,14 @@ async def chat_completion(
                 headers=custom_headers,
             )

+        ### ALERTING ###
+        end_time = time.time()
+        asyncio.create_task(
+            proxy_logging_obj.response_taking_too_long(
+                start_time=start_time, end_time=end_time, type="slow_response"
+            )
+        )
+
         fastapi_response.headers["x-litellm-model-id"] = model_id
         return response
     except Exception as e:
@@ -1428,6 +1461,8 @@ async def embeddings(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )

+        start_time = time.time()
+
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:
@@ -1456,6 +1491,14 @@ async def embeddings(
         else:
             response = await litellm.aembedding(**data)

+        ### ALERTING ###
+        end_time = time.time()
+        asyncio.create_task(
+            proxy_logging_obj.response_taking_too_long(
+                start_time=start_time, end_time=end_time, type="slow_response"
+            )
+        )
+
         return response
     except Exception as e:
         await proxy_logging_obj.post_call_failure_hook(
@@ -1535,6 +1578,8 @@ async def image_generation(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )

+        start_time = time.time()
+
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:
@@ -1560,6 +1605,14 @@ async def image_generation(
         else:
             response = await litellm.aimage_generation(**data)

+        ### ALERTING ###
+        end_time = time.time()
+        asyncio.create_task(
+            proxy_logging_obj.response_taking_too_long(
+                start_time=start_time, end_time=end_time, type="slow_response"
+            )
+        )
+
         return response
     except Exception as e:
         await proxy_logging_obj.post_call_failure_hook(
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 0a52b129e3..fcc58920d5 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -34,10 +34,15 @@ class ProxyLogging:
         self.max_parallel_request_limiter = MaxParallelRequestsHandler()
         self.max_budget_limiter = MaxBudgetLimiter()
         self.alerting: Optional[List] = None
+        self.alerting_threshold: float = 300  # default to 5 min. threshold
         pass

-    def update_values(self, alerting: Optional[List]):
+    def update_values(
+        self, alerting: Optional[List], alerting_threshold: Optional[float]
+    ):
         self.alerting = alerting
+        if alerting_threshold is not None:
+            self.alerting_threshold = alerting_threshold

     def _init_litellm_callbacks(self):
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
@@ -105,18 +110,45 @@ class ProxyLogging:
         except Exception as e:
             raise e

-    async def success_handler(self, *args, **kwargs):
+    async def success_handler(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: Any,
+        call_type: Literal["completion", "embeddings"],
+        start_time,
+        end_time,
+    ):
         """
-        Log successful db read/writes
+        Log successful API calls / db read/writes
         """
+        pass

-    async def response_taking_too_long(self):
-        # Simulate a long-running operation that could take more than 5 minutes
-        await asyncio.sleep(
-            300
-        )  # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
-        await self.alerting_handler(message="Requests are hanging", level="Medium")
+    async def response_taking_too_long(
+        self,
+        start_time: Optional[float] = None,
+        end_time: Optional[float] = None,
+        type: Literal["hanging_request", "slow_response"] = "hanging_request",
+    ):
+        if type == "hanging_request":
+            # Simulate a long-running operation that could take more than 5 minutes
+            await asyncio.sleep(
+                self.alerting_threshold
+            )  # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
+
+            await self.alerting_handler(
+                message=f"Requests are hanging - {self.alerting_threshold}s+ request time",
+                level="Medium",
+            )
+
+        elif (
+            type == "slow_response" and start_time is not None and end_time is not None
+        ):
+            if end_time - start_time > self.alerting_threshold:
+                await self.alerting_handler(
+                    message=f"Responses are slow - {round(end_time-start_time,2)}s response time",
+                    level="Low",
+                )

     async def alerting_handler(
         self, message: str, level: Literal["Low", "Medium", "High"]