feat(proxy_server.py): add slack alerting to proxy server

add alerting for calls hanging, failing and db read/writes failing

https://github.com/BerriAI/litellm/issues/1298
Krrish Dholakia 2024-01-02 17:44:32 +05:30
parent 10b71b0ff1
commit 940569703e
2 changed files with 79 additions and 5 deletions
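
For reference, a minimal standalone sketch of the Slack alert this commit sends. It assumes only that aiohttp is installed and SLACK_WEBHOOK_URL is exported; send_slack_alert is a hypothetical name, but the payload and webhook flow mirror the new alerting_handler in the diff below:

    import asyncio, os
    import aiohttp

    async def send_slack_alert(message: str, level: str) -> None:
        # Same message format, payload, and headers as alerting_handler in this commit
        webhook_url = os.getenv("SLACK_WEBHOOK_URL")
        if webhook_url is None:
            raise Exception("Missing SLACK_WEBHOOK_URL from environment")
        payload = {"text": f"Level: {level}\n\nMessage: {message}"}
        headers = {"Content-type": "application/json"}
        async with aiohttp.ClientSession() as session:
            async with session.post(webhook_url, json=payload, headers=headers) as response:
                response.raise_for_status()  # Slack incoming webhooks answer 200 "ok" on success

    asyncio.run(send_slack_alert("Requests are hanging", "Medium"))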

litellm/proxy/proxy_server.py

@@ -640,12 +640,14 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
                 load_google_kms(use_google_kms=True)
             else:
                 raise ValueError("Invalid Key Management System selected")
         ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
         use_google_kms = general_settings.get("use_google_kms", False)
         load_google_kms(use_google_kms=use_google_kms)
         ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
         use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
         load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
+        ### ALERTING ###
+        proxy_logging_obj.update_values(alerting=general_settings.get("alerting", None))
         ### CONNECT TO DATABASE ###
         database_url = general_settings.get("database_url", None)
         if database_url and database_url.startswith("os.environ/"):
@@ -655,8 +657,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             prisma_setup(database_url=database_url)
         ## COST TRACKING ##
         cost_tracking()
-        ### START REDIS QUEUE ###
-        use_queue = general_settings.get("use_queue", False)
         ### MASTER KEY ###
         master_key = general_settings.get("master_key", None)
         if master_key and master_key.startswith("os.environ/"):
@@ -1423,6 +1423,7 @@ async def embeddings(
         data = await proxy_logging_obj.pre_call_hook(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )
+
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:
@@ -1529,6 +1530,7 @@ async def image_generation(
         data = await proxy_logging_obj.pre_call_hook(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )
+
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:

litellm/proxy/utils.py

@@ -1,5 +1,5 @@
 from typing import Optional, List, Any, Literal
-import os, subprocess, hashlib, importlib, asyncio, copy, json
+import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp
 import litellm, backoff
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
@@ -32,6 +32,9 @@ class ProxyLogging:
         self.max_budget_limiter = MaxBudgetLimiter()
         pass

+    def update_values(self, alerting: Optional[List]):
+        self.alerting = alerting
+
     def _init_litellm_callbacks(self):
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
@@ -74,7 +77,11 @@ class ProxyLogging:
         Covers:
         1. /chat/completions
         2. /embeddings
+        3. /image/generation
         """
+        ### ALERTING ###
+        asyncio.create_task(self.response_taking_too_long())
+
         try:
             for callback in litellm.callbacks:
                 if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
@@ -100,12 +107,69 @@ class ProxyLogging:
         """
         pass

+    async def response_taking_too_long(self):
+        # Simulate a long-running operation that could take more than 5 minutes
+        await asyncio.sleep(
+            300
+        )  # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
+        await self.alerting_handler(message="Requests are hanging", level="Medium")
+
+    async def alerting_handler(
+        self, message: str, level: Literal["Low", "Medium", "High"]
+    ):
+        """
+        Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
+
+        - Responses taking too long
+        - Requests are hanging
+        - Calls are failing
+        - DB Read/Writes are failing
+
+        Parameters:
+            level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
+            message: str - what is the alert about
+        """
+        formatted_message = f"Level: {level}\n\nMessage: {message}"
+        if self.alerting is None:
+            return
+
+        for client in self.alerting:
+            if client == "slack":
+                slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
+                if slack_webhook_url is None:
+                    raise Exception("Missing SLACK_WEBHOOK_URL from environment")
+                payload = {"text": formatted_message}
+                headers = {"Content-type": "application/json"}
+                async with aiohttp.ClientSession() as session:
+                    async with session.post(
+                        slack_webhook_url, json=payload, headers=headers
+                    ) as response:
+                        if response.status == 200:
+                            pass
+            elif client == "sentry":
+                if litellm.utils.sentry_sdk_instance is not None:
+                    litellm.utils.sentry_sdk_instance.capture_message(formatted_message)
+                else:
+                    raise Exception("Missing SENTRY_DSN from environment")
+
     async def failure_handler(self, original_exception):
         """
         Log failed db read/writes

         Currently only logs exceptions to sentry
         """
+        ### ALERTING ###
+        if isinstance(original_exception, HTTPException):
+            error_message = original_exception.detail
+        else:
+            error_message = str(original_exception)
+        asyncio.create_task(
+            self.alerting_handler(
+                message=f"DB read/write call failed: {error_message}",
+                level="High",
+            )
+        )
+
         if litellm.utils.capture_exception:
             litellm.utils.capture_exception(error=original_exception)
@@ -118,8 +182,16 @@ class ProxyLogging:
         Covers:
         1. /chat/completions
         2. /embeddings
+        3. /image/generation
         """
+        ### ALERTING ###
+        asyncio.create_task(
+            self.alerting_handler(
+                message=f"LLM API call failed: {str(original_exception)}", level="High"
+            )
+        )
+
         for callback in litellm.callbacks:
             try:
                 if isinstance(callback, CustomLogger):
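
Note that every alert above is scheduled with asyncio.create_task rather than awaited, so neither the pre-call path nor the failure path blocks on the Slack POST. A minimal runnable sketch of that fire-and-forget pattern (stand-in handler, not litellm code):

    import asyncio

    async def alerting_handler(message: str, level: str) -> None:
        await asyncio.sleep(0.1)  # stand-in for the aiohttp POST to Slack
        print(f"ALERT [{level}]: {message}")

    async def post_call_failure_hook(original_exception: Exception) -> None:
        # Scheduled, not awaited: the failure callbacks below run immediately,
        # while the alert is delivered in the background.
        asyncio.create_task(
            alerting_handler(f"LLM API call failed: {original_exception}", "High")
        )
        print("running failure callbacks without waiting on the alert")
        await asyncio.sleep(0.2)  # demo only: keep the event loop alive for the task

    asyncio.run(post_call_failure_hook(Exception("rate limit hit")))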