From 940569703e4f7c5ddfacc20b423ce4225997317f Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 2 Jan 2024 17:44:32 +0530
Subject: [PATCH] feat(proxy_server.py): add slack alerting to proxy server

add alerting for calls hanging, failing and db read/writes failing

https://github.com/BerriAI/litellm/issues/1298
---
 litellm/proxy/proxy_server.py | 10 +++--
 litellm/proxy/utils.py        | 74 ++++++++++++++++++++++++++++++++++-
 2 files changed, 79 insertions(+), 5 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index e5ad6595b..5543b13d3 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -640,12 +640,14 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
                 load_google_kms(use_google_kms=True)
             else:
                 raise ValueError("Invalid Key Management System selected")
-        ### [DEPRECATED] LOAD FROM GOOGLE KMS ###
+        ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
         use_google_kms = general_settings.get("use_google_kms", False)
         load_google_kms(use_google_kms=use_google_kms)
-        ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ###
+        ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
         use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
         load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
+        ### ALERTING ###
+        proxy_logging_obj.update_values(alerting=general_settings.get("alerting", None))
         ### CONNECT TO DATABASE ###
         database_url = general_settings.get("database_url", None)
         if database_url and database_url.startswith("os.environ/"):
@@ -655,8 +657,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             prisma_setup(database_url=database_url)
         ## COST TRACKING ##
         cost_tracking()
-        ### START REDIS QUEUE ###
-        use_queue = general_settings.get("use_queue", False)
         ### MASTER KEY ###
         master_key = general_settings.get("master_key", None)
         if master_key and master_key.startswith("os.environ/"):
@@ -1423,6 +1423,7 @@ async def embeddings(
         data = await proxy_logging_obj.pre_call_hook(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )
+
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:
@@ -1529,6 +1530,7 @@ async def image_generation(
         data = await proxy_logging_obj.pre_call_hook(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )
+
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index ea73891c4..d8d921a24 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -1,5 +1,5 @@
 from typing import Optional, List, Any, Literal
-import os, subprocess, hashlib, importlib, asyncio, copy, json
+import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp
 import litellm, backoff
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
@@ -32,6 +32,9 @@ class ProxyLogging:
         self.max_budget_limiter = MaxBudgetLimiter()
         pass

+    def update_values(self, alerting: Optional[List]):
+        self.alerting = alerting
+
     def _init_litellm_callbacks(self):
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
@@ -74,7 +77,11 @@ class ProxyLogging:
         Covers:
         1. /chat/completions
         2. /embeddings
+        3. /image/generation
         """
+        ### ALERTING ###
+        asyncio.create_task(self.response_taking_too_long())
+
         try:
             for callback in litellm.callbacks:
                 if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
@@ -100,12 +107,69 @@ class ProxyLogging:
         """
         pass

+    async def response_taking_too_long(self):
+        # wait 5 minutes, then alert that the request appears to be hanging
+        await asyncio.sleep(
+            300
+        )  # 5-minute threshold - this may need to differ for streaming, non-streaming, and non-completion (embedding + image) requests
+        await self.alerting_handler(message="Requests are hanging", level="Medium")
+
+    async def alerting_handler(
+        self, message: str, level: Literal["Low", "Medium", "High"]
+    ):
+        """
+        Alerting based on thresholds: https://github.com/BerriAI/litellm/issues/1298
+
+        - Responses taking too long
+        - Requests are hanging
+        - Calls are failing
+        - DB Read/Writes are failing
+
+        Parameters:
+            level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); currently, no alerts are sent as 'Low'.
+            message: str - what the alert is about
+        """
+        formatted_message = f"Level: {level}\n\nMessage: {message}"
+        if self.alerting is None:
+            return
+
+        for client in self.alerting:
+            if client == "slack":
+                slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
+                if slack_webhook_url is None:
+                    raise Exception("Missing SLACK_WEBHOOK_URL from environment")
+                payload = {"text": formatted_message}
+                headers = {"Content-type": "application/json"}
+                async with aiohttp.ClientSession() as session:
+                    async with session.post(
+                        slack_webhook_url, json=payload, headers=headers
+                    ) as response:
+                        if response.status == 200:
+                            pass  # alert delivered; non-200 responses are silently ignored
+            elif client == "sentry":
+                if litellm.utils.sentry_sdk_instance is not None:
+                    litellm.utils.sentry_sdk_instance.capture_message(formatted_message)
+                else:
+                    raise Exception("Missing SENTRY_DSN from environment")
+
     async def failure_handler(self, original_exception):
         """
         Log failed db read/writes

         Currently only logs exceptions to sentry
         """
+        ### ALERTING ###
+        if isinstance(original_exception, HTTPException):
+            error_message = original_exception.detail
+        else:
+            error_message = str(original_exception)
+        asyncio.create_task(
+            self.alerting_handler(
+                message=f"DB read/write call failed: {error_message}",
+                level="High",
+            )
+        )
+
         if litellm.utils.capture_exception:
             litellm.utils.capture_exception(error=original_exception)

@@ -118,8 +182,16 @@ class ProxyLogging:
         Covers:
         1. /chat/completions
         2. /embeddings
+        3. /image/generation
         """
+        ### ALERTING ###
+        asyncio.create_task(
+            self.alerting_handler(
+                message=f"LLM API call failed: {str(original_exception)}", level="High"
+            )
+        )
+
         for callback in litellm.callbacks:
             try:
                 if isinstance(callback, CustomLogger):
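
Usage note (not part of the diff above): a minimal sketch of how the new alerting path could be exercised directly, e.g. to smoke-test the Slack webhook. The ProxyLogging constructor argument is an assumption based on how the proxy wires this object up elsewhere; SLACK_WEBHOOK_URL must be exported for the "slack" client to deliver anything.

    # hypothetical usage sketch - the constructor signature is an assumption,
    # everything else comes from the patch above
    import asyncio

    from litellm.caching import DualCache
    from litellm.proxy.utils import ProxyLogging

    async def main():
        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())  # assumed signature
        # mirrors load_router_config: general_settings["alerting"] = ["slack"]
        proxy_logging.update_values(alerting=["slack"])
        # requires SLACK_WEBHOOK_URL in the environment
        await proxy_logging.alerting_handler(message="test alert", level="Medium")

    asyncio.run(main())

On the proxy itself, the same path is enabled by setting alerting: ["slack"] under general_settings in the proxy config, which load_router_config forwards to proxy_logging_obj.update_values(...).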