feat(proxy_server.py): add slack alerting to proxy server
add alerting for hanging calls, failing calls, and failing DB reads/writes. https://github.com/BerriAI/litellm/issues/1298
This commit is contained in:
parent 10b71b0ff1
commit 940569703e

2 changed files with 79 additions and 5 deletions
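For reference, a minimal self-contained sketch of the wiring this commit introduces: an `alerting` list is read from `general_settings` and handed to the proxy's logging object. The `ProxyLoggingSketch` stand-in below is hypothetical; `update_values`, the `"alerting"` key, and the `["slack"]` / `["sentry"]` client names come from the diff itself.

```python
# Sketch of the new alerting wiring (stand-in class; real one is ProxyLogging).
from typing import List, Optional


class ProxyLoggingSketch:
    def __init__(self):
        self.alerting: Optional[List] = None  # no alerting until configured

    def update_values(self, alerting: Optional[List]):
        self.alerting = alerting


# general_settings would normally come from the proxy config file.
general_settings = {"alerting": ["slack"]}
proxy_logging_obj = ProxyLoggingSketch()
proxy_logging_obj.update_values(alerting=general_settings.get("alerting", None))
assert proxy_logging_obj.alerting == ["slack"]
```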
@@ -640,12 +640,14 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
                 load_google_kms(use_google_kms=True)
             else:
                 raise ValueError("Invalid Key Management System selected")
-        ### [DEPRECATED] LOAD FROM GOOGLE KMS ###
+        ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
         use_google_kms = general_settings.get("use_google_kms", False)
         load_google_kms(use_google_kms=use_google_kms)
-        ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ###
+        ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
         use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
         load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
+        ### ALERTING ###
+        proxy_logging_obj.update_values(alerting=general_settings.get("alerting", None))
         ### CONNECT TO DATABASE ###
         database_url = general_settings.get("database_url", None)
         if database_url and database_url.startswith("os.environ/"):
@@ -655,8 +657,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             prisma_setup(database_url=database_url)
         ## COST TRACKING ##
         cost_tracking()
-        ### START REDIS QUEUE ###
-        use_queue = general_settings.get("use_queue", False)
         ### MASTER KEY ###
         master_key = general_settings.get("master_key", None)
         if master_key and master_key.startswith("os.environ/"):
@@ -1423,6 +1423,7 @@ async def embeddings(
         data = await proxy_logging_obj.pre_call_hook(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )
+
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:
@@ -1529,6 +1530,7 @@ async def image_generation(
         data = await proxy_logging_obj.pre_call_hook(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )
+
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:
@@ -1,5 +1,5 @@
 from typing import Optional, List, Any, Literal
-import os, subprocess, hashlib, importlib, asyncio, copy, json
+import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp
 import litellm, backoff
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
@@ -32,6 +32,9 @@ class ProxyLogging:
         self.max_budget_limiter = MaxBudgetLimiter()
         pass

+    def update_values(self, alerting: Optional[List]):
+        self.alerting = alerting
+
     def _init_litellm_callbacks(self):
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
         litellm.callbacks.append(self.max_parallel_request_limiter)
@@ -74,7 +77,11 @@ class ProxyLogging:
         Covers:
         1. /chat/completions
         2. /embeddings
+        3. /image/generation
         """
+        ### ALERTING ###
+        asyncio.create_task(self.response_taking_too_long())
+
         try:
             for callback in litellm.callbacks:
                 if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
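The hook change above schedules `response_taking_too_long()` as a fire-and-forget task, so a `Medium` alert fires once the 300-second timer elapses; note the committed version sleeps unconditionally rather than cancelling the timer when the request finishes. A runnable sketch of that watchdog pattern, with shortened timings and a stand-in `alert` function in place of `alerting_handler`:

```python
# Hang-watchdog sketch: spawn a task that sleeps for a threshold, then alerts.
import asyncio


async def alert(message: str, level: str):
    # Stand-in for ProxyLogging.alerting_handler.
    print(f"Level: {level}\n\nMessage: {message}")


async def response_taking_too_long(threshold_seconds: float = 0.1):
    await asyncio.sleep(threshold_seconds)  # 300s in the actual commit
    await alert(message="Requests are hanging", level="Medium")


async def handle_request():
    # Fire-and-forget: the request handler does not await the watchdog.
    watchdog = asyncio.create_task(response_taking_too_long())
    await asyncio.sleep(0.2)  # simulate a request that outlives the threshold
    await watchdog  # kept here only so the demo exits cleanly


asyncio.run(handle_request())
```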
@@ -100,12 +107,69 @@ class ProxyLogging:
         """
         pass

+    async def response_taking_too_long(self):
+        # Simulate a long-running operation that could take more than 5 minutes
+        await asyncio.sleep(
+            300
+        )  # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
+        await self.alerting_handler(message="Requests are hanging", level="Medium")
+
+    async def alerting_handler(
+        self, message: str, level: Literal["Low", "Medium", "High"]
+    ):
+        """
+        Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
+
+        - Responses taking too long
+        - Requests are hanging
+        - Calls are failing
+        - DB Read/Writes are failing
+
+        Parameters:
+        level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
+        message: str - what is the alert about
+        """
+        formatted_message = f"Level: {level}\n\nMessage: {message}"
+        if self.alerting is None:
+            return
+
+        for client in self.alerting:
+            if client == "slack":
+                slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
+                if slack_webhook_url is None:
+                    raise Exception("Missing SLACK_WEBHOOK_URL from environment")
+                payload = {"text": formatted_message}
+                headers = {"Content-type": "application/json"}
+                async with aiohttp.ClientSession() as session:
+                    async with session.post(
+                        slack_webhook_url, json=payload, headers=headers
+                    ) as response:
+                        if response.status == 200:
+                            pass
+            elif client == "sentry":
+                if litellm.utils.sentry_sdk_instance is not None:
+                    litellm.utils.sentry_sdk_instance.capture_message(formatted_message)
+                else:
+                    raise Exception("Missing SENTRY_DSN from environment")
+
     async def failure_handler(self, original_exception):
         """
         Log failed db read/writes

         Currently only logs exceptions to sentry
         """
+        ### ALERTING ###
+        if isinstance(original_exception, HTTPException):
+            error_message = original_exception.detail
+        else:
+            error_message = str(original_exception)
+        asyncio.create_task(
+            self.alerting_handler(
+                message=f"DB read/write call failed: {error_message}",
+                level="High",
+            )
+        )

         if litellm.utils.capture_exception:
             litellm.utils.capture_exception(error=original_exception)
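Taken on its own, the Slack branch of `alerting_handler` boils down to a single webhook POST. A self-contained sketch, assuming `SLACK_WEBHOOK_URL` points at a real Slack incoming webhook; the `{"text": ...}` payload shape and the environment-variable name are exactly what the diff uses:

```python
# Minimal Slack alert: POST the formatted message to the incoming-webhook URL.
import asyncio
import os

import aiohttp


async def send_slack_alert(message: str, level: str) -> int:
    slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
    if slack_webhook_url is None:
        raise Exception("Missing SLACK_WEBHOOK_URL from environment")
    payload = {"text": f"Level: {level}\n\nMessage: {message}"}
    headers = {"Content-type": "application/json"}
    async with aiohttp.ClientSession() as session:
        async with session.post(
            slack_webhook_url, json=payload, headers=headers
        ) as response:
            return response.status  # Slack answers 200 on success


if __name__ == "__main__":
    print(asyncio.run(send_slack_alert("Requests are hanging", "Medium")))
```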
@@ -118,8 +182,16 @@ class ProxyLogging:
         Covers:
         1. /chat/completions
         2. /embeddings
+        3. /image/generation
         """

+        ### ALERTING ###
+        asyncio.create_task(
+            self.alerting_handler(
+                message=f"LLM API call failed: {str(original_exception)}", level="High"
+            )
+        )
+
         for callback in litellm.callbacks:
             try:
                 if isinstance(callback, CustomLogger):
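Both failure paths above schedule a `High` alert without blocking the caller. A small sketch of the error normalization `failure_handler` performs before alerting, assuming the `HTTPException` in the diff is FastAPI's (the proxy is a FastAPI app):

```python
# HTTPException carries a structured detail; other exceptions are stringified.
from fastapi import HTTPException


def format_db_error(original_exception: Exception) -> str:
    if isinstance(original_exception, HTTPException):
        error_message = original_exception.detail
    else:
        error_message = str(original_exception)
    return f"DB read/write call failed: {error_message}"


print(format_db_error(HTTPException(status_code=500, detail="connection refused")))
print(format_db_error(RuntimeError("prisma timeout")))
```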