feat(proxy_server.py): add slack alerting to proxy server

add alerting for calls hanging, calls failing, and db read/writes failing

https://github.com/BerriAI/litellm/issues/1298
Author: Krrish Dholakia 2024-01-02 17:44:32 +05:30
parent 10b71b0ff1
commit 940569703e
2 changed files with 79 additions and 5 deletions

litellm/proxy/proxy_server.py

@@ -640,12 +640,14 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
                 load_google_kms(use_google_kms=True)
             else:
                 raise ValueError("Invalid Key Management System selected")
-        ### [DEPRECATED] LOAD FROM GOOGLE KMS ###
+        ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
         use_google_kms = general_settings.get("use_google_kms", False)
         load_google_kms(use_google_kms=use_google_kms)
-        ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ###
+        ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
         use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
         load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
+        ### ALERTING ###
+        proxy_logging_obj.update_values(alerting=general_settings.get("alerting", None))
         ### CONNECT TO DATABASE ###
         database_url = general_settings.get("database_url", None)
         if database_url and database_url.startswith("os.environ/"):
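The new "alerting" key is read from general_settings, so enabling Slack alerts is a config-plus-env change. A minimal sketch of the wiring, assuming the proxy's parsed config behaves like a plain dict (the stub class and placeholder values here are illustrative, not litellm's real code):

import os

class ProxyLogging:
    """Stub of the class extended in this commit (see the utils.py hunks below)."""
    def __init__(self):
        self.alerting = None
    def update_values(self, alerting):
        self.alerting = alerting

# Hypothetical parsed config; the real proxy loads general_settings from its config file
general_settings = {"alerting": ["slack"]}
os.environ["SLACK_WEBHOOK_URL"] = "https://hooks.slack.com/services/XXX"  # placeholder

proxy_logging_obj = ProxyLogging()
# Mirrors the new call above: a missing "alerting" key yields None and disables alerts
proxy_logging_obj.update_values(alerting=general_settings.get("alerting", None))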
@@ -655,8 +657,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             prisma_setup(database_url=database_url)
         ## COST TRACKING ##
         cost_tracking()
-        ### START REDIS QUEUE ###
-        use_queue = general_settings.get("use_queue", False)
         ### MASTER KEY ###
         master_key = general_settings.get("master_key", None)
         if master_key and master_key.startswith("os.environ/"):
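Both database_url and master_key support the same "os.environ/" indirection: a value of the form os.environ/NAME is resolved from the environment instead of being stored in the config file. A hedged sketch of that resolution (the helper name is illustrative; litellm has its own secret-resolution path):

import os

def resolve_secret(value: str) -> str:
    """Illustrative resolver: 'os.environ/DATABASE_URL' -> os.environ['DATABASE_URL']."""
    prefix = "os.environ/"
    if value.startswith(prefix):
        return os.environ[value[len(prefix):]]
    return value

os.environ["PROXY_MASTER_KEY"] = "sk-1234"  # example value
assert resolve_secret("os.environ/PROXY_MASTER_KEY") == "sk-1234"
assert resolve_secret("sk-direct") == "sk-direct"  # plain values pass through unchanged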
@@ -1423,6 +1423,7 @@ async def embeddings(
         data = await proxy_logging_obj.pre_call_hook(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )
+
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:
@@ -1529,6 +1530,7 @@ async def image_generation(
         data = await proxy_logging_obj.pre_call_hook(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )
+
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:

litellm/proxy/utils.py

@@ -1,5 +1,5 @@
 from typing import Optional, List, Any, Literal
-import os, subprocess, hashlib, importlib, asyncio, copy, json
+import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp
 import litellm, backoff
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
@@ -32,6 +32,9 @@ class ProxyLogging:
         self.max_budget_limiter = MaxBudgetLimiter()
         pass

+    def update_values(self, alerting: Optional[List]):
+        self.alerting = alerting
+
     def _init_litellm_callbacks(self):
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
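update_values is the only setter added here; it lets the server push the configured alert sinks into the already-constructed ProxyLogging instance once the config file has been parsed. Given the stub from the first sketch above, usage is simply:

proxy_logging_obj.update_values(alerting=["slack", "sentry"])  # enable both sinks
proxy_logging_obj.update_values(alerting=None)                 # turn alerting back off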
@@ -74,7 +77,11 @@ class ProxyLogging:
         Covers:
         1. /chat/completions
         2. /embeddings
+        3. /image/generation
         """
+        ### ALERTING ###
+        asyncio.create_task(self.response_taking_too_long())
+
         try:
             for callback in litellm.callbacks:
                 if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
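asyncio.create_task makes the hang watchdog fire-and-forget: pre_call_hook returns immediately while the timer runs in the background. As committed, the task always fires after 300s, even for requests that finished; a common refinement (not in this commit) is to keep the task handle and cancel it when the call completes. A self-contained sketch of that cancellable-watchdog pattern:

import asyncio

async def watchdog(timeout: float):
    await asyncio.sleep(timeout)
    print("ALERT: request still hanging")  # stand-in for alerting_handler(...)

async def handle_request():
    task = asyncio.create_task(watchdog(300))
    try:
        await asyncio.sleep(0.1)  # stand-in for the actual LLM call
    finally:
        task.cancel()  # request finished (or failed): no hang alert needed
        try:
            await task
        except asyncio.CancelledError:
            pass

asyncio.run(handle_request())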
@@ -100,12 +107,69 @@ class ProxyLogging:
         """
         pass

+    async def response_taking_too_long(self):
+        # Simulate a long-running operation that could take more than 5 minutes
+        await asyncio.sleep(
+            300
+        )  # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
+
+        await self.alerting_handler(message="Requests are hanging", level="Medium")
+
+    async def alerting_handler(
+        self, message: str, level: Literal["Low", "Medium", "High"]
+    ):
+        """
+        Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
+
+        - Responses taking too long
+        - Requests are hanging
+        - Calls are failing
+        - DB Read/Writes are failing
+
+        Parameters:
+            level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
+            message: str - what is the alert about
+        """
+        formatted_message = f"Level: {level}\n\nMessage: {message}"
+        if self.alerting is None:
+            return
+
+        for client in self.alerting:
+            if client == "slack":
+                slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
+                if slack_webhook_url is None:
+                    raise Exception("Missing SLACK_WEBHOOK_URL from environment")
+                payload = {"text": formatted_message}
+                headers = {"Content-type": "application/json"}
+                async with aiohttp.ClientSession() as session:
+                    async with session.post(
+                        slack_webhook_url, json=payload, headers=headers
+                    ) as response:
+                        if response.status == 200:
+                            pass
+            elif client == "sentry":
+                if litellm.utils.sentry_sdk_instance is not None:
+                    litellm.utils.sentry_sdk_instance.capture_message(formatted_message)
+                else:
+                    raise Exception("Missing SENTRY_DSN from environment")
+
     async def failure_handler(self, original_exception):
         """
         Log failed db read/writes

         Currently only logs exceptions to sentry
         """
+        ### ALERTING ###
+        if isinstance(original_exception, HTTPException):
+            error_message = original_exception.detail
+        else:
+            error_message = str(original_exception)
+        asyncio.create_task(
+            self.alerting_handler(
+                message=f"DB read/write call failed: {error_message}",
+                level="High",
+            )
+        )
+
         if litellm.utils.capture_exception:
             litellm.utils.capture_exception(error=original_exception)
@@ -118,8 +182,16 @@ class ProxyLogging:
         Covers:
         1. /chat/completions
         2. /embeddings
+        3. /image/generation
         """
+        ### ALERTING ###
+        asyncio.create_task(
+            self.alerting_handler(
+                message=f"LLM API call failed: {str(original_exception)}", level="High"
+            )
+        )
+
         for callback in litellm.callbacks:
             try:
                 if isinstance(callback, CustomLogger):
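With the pieces above, the new alert path can be smoke-tested without running the proxy. A hedged sketch, assuming litellm is installed at this commit and that ProxyLogging is importable from litellm.proxy.utils with a user_api_key_cache constructor argument (both the import path and the constructor signature are assumptions here):

import asyncio, os

os.environ["SLACK_WEBHOOK_URL"] = "https://hooks.slack.com/services/XXX"  # placeholder

from litellm.proxy.utils import ProxyLogging  # import path is an assumption
from litellm.caching import DualCache

async def main():
    pl = ProxyLogging(user_api_key_cache=DualCache())  # signature assumed
    pl.update_values(alerting=["slack"])
    await pl.alerting_handler(message="LLM API call failed: test", level="High")

asyncio.run(main())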