feat(proxy_server.py): schedule slack daily report if enabled

If the user has enabled daily_reports, send them a Slack report every 12 hours (default; configurable via alerting_args.daily_report_frequency)
Krrish Dholakia 2024-05-06 18:25:48 -07:00
parent 718f423d7d
commit 6b9b4f05ba
4 changed files with 114 additions and 35 deletions
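For context on how the new alerting_args are consumed: a minimal sketch, not taken from this diff — only the field names and defaults below mirror it — of the pydantic parsing that SlackAlertingArgs performs. Both values are in seconds.

from pydantic import BaseModel


class SlackAlertingArgs(BaseModel):
    daily_report_frequency: int = 12 * 60 * 60  # 12 hours between reports
    report_check_interval: int = 5 * 60  # poll the cache every 5 minutes


# general_settings.alerting_args from the proxy config is passed through as a plain dict
args = SlackAlertingArgs(**{"daily_report_frequency": 60, "report_check_interval": 5})
print(args.daily_report_frequency, args.report_check_interval)  # 60 5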

litellm/integrations/slack_alerting.py

@@ -17,6 +17,7 @@ from pydantic import BaseModel
 from enum import Enum
 from datetime import datetime as dt, timedelta
 from litellm.integrations.custom_logger import CustomLogger
+import random


 class LiteLLMBase(BaseModel):
@@ -32,8 +33,9 @@ class LiteLLMBase(BaseModel):
         return self.dict()


-class SlackArgs(LiteLLMBase):
+class SlackAlertingArgs(LiteLLMBase):
     daily_report_frequency: int = 12 * 60 * 60  # 12 hours
+    report_check_interval: int = 5 * 60  # 5 minutes


 class DeploymentMetrics(LiteLLMBase):
@@ -63,6 +65,7 @@ class SlackAlertingCacheKeys(Enum):
     failed_requests_key = "failed_requests_daily_metrics"
     latency_key = "latency_daily_metrics"
+    report_sent_key = "daily_metrics_report_sent"


 class SlackAlerting(CustomLogger):
@@ -94,6 +97,7 @@ class SlackAlerting(CustomLogger):
         alert_to_webhook_url: Optional[
             Dict
         ] = None,  # if user wants to separate alerts to diff channels
+        alerting_args={},
     ):
         self.alerting_threshold = alerting_threshold
         self.alerting = alerting
@@ -102,6 +106,7 @@ class SlackAlerting(CustomLogger):
         self.async_http_handler = AsyncHTTPHandler()
         self.alert_to_webhook_url = alert_to_webhook_url
         self.is_running = False
+        self.alerting_args = SlackAlertingArgs(**alerting_args)

     def update_values(
         self,
@@ -109,6 +114,7 @@ class SlackAlerting(CustomLogger):
         alerting_threshold: Optional[float] = None,
         alert_types: Optional[List] = None,
         alert_to_webhook_url: Optional[Dict] = None,
+        alerting_args: Optional[Dict] = None,
     ):
         if alerting is not None:
             self.alerting = alerting
@@ -116,7 +122,8 @@ class SlackAlerting(CustomLogger):
             self.alerting_threshold = alerting_threshold
         if alert_types is not None:
             self.alert_types = alert_types
+        if alerting_args is not None:
+            self.alerting_args = SlackAlertingArgs(**alerting_args)
         if alert_to_webhook_url is not None:
             # update the dict
             if self.alert_to_webhook_url is None:
@@ -356,7 +363,7 @@
         # find top 5 slowest
         # Replace None values with a placeholder value (-1 in this case)
-        placeholder_value = -1
+        placeholder_value = 0
         replaced_slowest_values = [
             value if value is not None else placeholder_value
             for value in latency_values
@@ -406,8 +413,8 @@
                     _deployment["litellm_params"] if _deployment is not None else {}
                 ),
             )
-            value = replaced_slowest_values[top_5_slowest[i]]
-            message += f"\t{i+1}. Deployment: `{deployment_name}`, Latency: `{value}`, API Base: `{api_base}`\n\n"
+            value = round(replaced_slowest_values[top_5_slowest[i]], 3)
+            message += f"\t{i+1}. Deployment: `{deployment_name}`, Latency per output token: `{value}s/token`, API Base: `{api_base}`\n\n"

         # cache cleanup -> reset values to 0
         latency_cache_keys = [(key, 0) for key in latency_keys]
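The value surfaced in the report is latency per output token (computed in async_log_success_event below) and is now rounded to three decimals before rendering. A small worked example with hypothetical numbers:

from datetime import timedelta

# hypothetical request: 2.4s end-to-end, 120 completion tokens
response_s = timedelta(seconds=2.4)
completion_tokens = 120

final_value = float(response_s.total_seconds() / completion_tokens)
print(round(final_value, 3))  # 0.02 -> rendered as `0.02s/token` in the Slack message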
@@ -698,33 +705,82 @@ class SlackAlerting(CustomLogger):
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         """Log deployment latency"""
-        model_id = kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
-        response_ms: timedelta = end_time - start_time
-
-        final_value = response_ms
-        total_tokens = 0
-
-        if isinstance(response_obj, litellm.ModelResponse):
-            completion_tokens = response_obj.usage.completion_tokens
-            final_value = float(response_ms.total_seconds() / completion_tokens)
-
-        await self.async_update_daily_reports(
-            DeploymentMetrics(
-                id=model_id,
-                failed_request=False,
-                latency_per_output_token=final_value,
-                updated_at=litellm.utils.get_utc_datetime(),
-            )
-        )
+        if "daily_reports" in self.alert_types:
+            model_id = (
+                kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
+            )
+            response_s: timedelta = end_time - start_time
+
+            final_value = response_s
+            total_tokens = 0
+
+            if isinstance(response_obj, litellm.ModelResponse):
+                completion_tokens = response_obj.usage.completion_tokens
+                final_value = float(response_s.total_seconds() / completion_tokens)
+
+            await self.async_update_daily_reports(
+                DeploymentMetrics(
+                    id=model_id,
+                    failed_request=False,
+                    latency_per_output_token=final_value,
+                    updated_at=litellm.utils.get_utc_datetime(),
+                )
+            )

     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         """Log failure + deployment latency"""
-        model_id = kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
-        await self.async_update_daily_reports(
-            DeploymentMetrics(
-                id=model_id,
-                failed_request=True,
-                latency_per_output_token=None,
-                updated_at=litellm.utils.get_utc_datetime(),
-            )
-        )
+        if "daily_reports" in self.alert_types:
+            model_id = (
+                kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
+            )
+            await self.async_update_daily_reports(
+                DeploymentMetrics(
+                    id=model_id,
+                    failed_request=True,
+                    latency_per_output_token=None,
+                    updated_at=litellm.utils.get_utc_datetime(),
+                )
+            )
+
+    async def _run_scheduled_daily_report(self, llm_router: Optional[litellm.Router]):
+        """
+        If 'daily_reports' enabled
+
+        Ping redis cache every 5 minutes to check if we should send the report
+
+        If yes -> call send_daily_report()
+        """
+        if llm_router is None or self.alert_types is None:
+            return
+
+        if "daily_reports" in self.alert_types:
+            while True:
+                report_sent = await self.internal_usage_cache.async_get_cache(
+                    key=SlackAlertingCacheKeys.report_sent_key.value
+                )  # None | datetime
+
+                if report_sent is None:
+                    await self.internal_usage_cache.async_set_cache(
+                        key=SlackAlertingCacheKeys.report_sent_key.value,
+                        value=litellm.utils.get_utc_datetime(),
+                    )
+                else:
+                    # check if current time - interval >= time last sent
+                    current_time = litellm.utils.get_utc_datetime()
+                    delta = current_time - timedelta(
+                        seconds=self.alerting_args.daily_report_frequency
+                    )
+                    if delta >= report_sent:
+                        # Sneak in the reporting logic here
+                        await self.send_daily_reports(router=llm_router)
+
+                        # Also, don't forget to update the report_sent time after sending the report!
+                        await self.internal_usage_cache.async_set_cache(
+                            key=SlackAlertingCacheKeys.report_sent_key.value,
+                            value=litellm.utils.get_utc_datetime(),
+                        )
+
+                interval = random.randint(
+                    self.alerting_args.report_check_interval - 3,
+                    self.alerting_args.report_check_interval + 3,
+                )  # shuffle to prevent collisions
+                await asyncio.sleep(interval)
+        return
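The scheduler's due-check compares the last-sent timestamp against "now minus the report frequency". A standalone sketch of that comparison with hypothetical timestamps (the real loop reads report_sent from the internal usage cache):

from datetime import datetime, timedelta, timezone

daily_report_frequency = 12 * 60 * 60  # seconds; default from SlackAlertingArgs
report_sent = datetime(2024, 5, 6, 6, 0, tzinfo=timezone.utc)     # hypothetical last send
current_time = datetime(2024, 5, 6, 18, 30, tzinfo=timezone.utc)  # hypothetical "now"

# a report is due once the last-sent time has fallen behind (now - frequency),
# i.e. at least daily_report_frequency seconds have elapsed since the last send
delta = current_time - timedelta(seconds=daily_report_frequency)
print(delta >= report_sent)  # True -> send_daily_reports() would run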

litellm/proxy/proxy_config.yaml

@@ -4,6 +4,16 @@ model_list:
       api_key: my-fake-key
       model: openai/my-fake-model
     model_name: fake-openai-endpoint
+  - litellm_params:
+      api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
+      api_key: my-fake-key-2
+      model: openai/my-fake-model-2
+    model_name: fake-openai-endpoint
+  - litellm_params:
+      api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
+      api_key: my-fake-key-3
+      model: openai/my-fake-model-3
+    model_name: fake-openai-endpoint
 router_settings:
   num_retries: 0
   enable_pre_call_checks: true
@@ -19,4 +29,7 @@ litellm_settings:
 general_settings:
   alerting: ["slack"]
   alert_types: ["llm_exceptions", "daily_reports"]
+  alerting_args:
+    daily_report_frequency: 60 # every minute
+    report_check_interval: 5 # every 5s
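Both alerting_args values are read as seconds, so this test config asks for a report roughly every minute and a due-check roughly every 5 seconds. The check interval is also jittered by ±3 seconds before each sleep, e.g.:

import random

report_check_interval = 5  # seconds, from alerting_args above
interval = random.randint(report_check_interval - 3, report_check_interval + 3)
print(interval)  # somewhere in [2, 8]; the jitter keeps multiple workers from checking in lockstep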

litellm/proxy/proxy_server.py

@@ -1900,9 +1900,6 @@ async def _run_background_health_check():
         await asyncio.sleep(health_check_interval)


-semaphore = asyncio.Semaphore(1)
-
-
 class ProxyConfig:
     """
     Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
@@ -2377,6 +2374,7 @@ class ProxyConfig:
                 alerting=general_settings.get("alerting", None),
                 alerting_threshold=general_settings.get("alerting_threshold", 600),
                 alert_types=general_settings.get("alert_types", None),
+                alerting_args=general_settings.get("alerting_args", None),
                 redis_cache=redis_usage_cache,
             )
             ### CONNECT TO DATABASE ###
@@ -2501,7 +2499,7 @@ class ProxyConfig:
         for k, v in router_settings.items():
             if k in available_args:
                 router_params[k] = v
-        router = litellm.Router(**router_params, semaphore=semaphore)  # type:ignore
+        router = litellm.Router(**router_params)  # type:ignore
         return router, model_list, general_settings

     def get_model_info_with_id(self, model) -> RouterModelInfo:
@@ -3273,6 +3271,13 @@ async def startup_event():
     proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made

+    if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types:
+        asyncio.create_task(
+            proxy_logging_obj.slack_alerting_instance._run_scheduled_daily_report(
+                llm_router=llm_router
+            )
+        )  # RUN DAILY REPORT (if scheduled)
+
     ## JWT AUTH ##
     if general_settings.get("litellm_jwtauth", None) is not None:
         for k, v in general_settings["litellm_jwtauth"].items():
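The report loop is started as a fire-and-forget task during the proxy's startup event, so it never blocks request handling. A minimal standalone sketch of the same asyncio pattern (names here are placeholders, not litellm's):

import asyncio


async def run_scheduled_daily_report() -> None:
    # stand-in for SlackAlerting._run_scheduled_daily_report's polling loop
    while True:
        print("checking whether a daily report is due...")
        await asyncio.sleep(5)


async def startup_event() -> None:
    # create_task schedules the coroutine on the running loop without awaiting it,
    # so startup returns immediately and the server keeps serving requests
    asyncio.create_task(run_scheduled_daily_report())
    await asyncio.sleep(12)  # keep the demo alive long enough for a couple of checks


if __name__ == "__main__":
    asyncio.run(startup_event())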

litellm/proxy/utils.py

@@ -107,6 +107,7 @@ class ProxyLogging:
                 ]
             ]
         ] = None,
+        alerting_args: Optional[dict] = None,
     ):
         self.alerting = alerting
         if alerting_threshold is not None:
@@ -118,8 +119,12 @@ class ProxyLogging:
             alerting=self.alerting,
             alerting_threshold=self.alerting_threshold,
             alert_types=self.alert_types,
+            alerting_args=alerting_args,
         )
+        if "daily_reports" in self.alert_types:
+            litellm.callbacks.append(self.slack_alerting_instance)  # type: ignore
+
         if redis_cache is not None:
             self.internal_usage_cache.redis_cache = redis_cache
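Appending the SlackAlerting instance to litellm.callbacks is what feeds per-request success/failure events into the daily metrics. A hedged sketch of that hook using a bare CustomLogger subclass (the class name and print statements are illustrative only):

import litellm
from litellm.integrations.custom_logger import CustomLogger


class DailyMetricsLogger(CustomLogger):
    """Illustrative stand-in for SlackAlerting's logging hooks."""

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # in the commit, this is where latency-per-output-token is written to the cache
        model_id = kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
        print("success:", model_id)

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        # failed requests feed the failed_requests_daily_metrics counter
        model_id = kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
        print("failure:", model_id)


litellm.callbacks.append(DailyMetricsLogger())  # mirrors the registration above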