fix(slack_alerting.py): support region based outage alerting

Krrish Dholakia 2024-05-24 16:59:16 -07:00
parent 18f8287e29
commit f8350b9461
5 changed files with 309 additions and 61 deletions
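In outline, the change keys outage state by (provider + region): every failed call appends to a cached OutageModel, a **Minor Service Outage** alert fires once more than 3 errors accumulate, a **Major** one past 10, and the cache entry expires after an hour. Below is a minimal standalone sketch of that state machine using the thresholds from this commit — the plain dict stands in for litellm's DualCache, and all names here are illustrative, not the shipped API:

```python
# Illustrative sketch of the region-outage state machine, not litellm's API.
import time

MINOR_THRESHOLD = 3   # mirrors minor_outage_alert_threshold
MAJOR_THRESHOLD = 10  # mirrors major_outage_alert_threshold

outage_cache: dict = {}  # (provider + region) -> state; real code uses DualCache + TTL


def record_failure(provider: str, region: str, error_msg: str) -> None:
    key = provider + region
    state = outage_cache.get(key)
    if state is None:
        # first failure for this provider/region: start tracking, alert nothing yet
        outage_cache[key] = {
            "alerts": [error_msg],
            "minor_alert_sent": False,
            "major_alert_sent": False,
            "last_updated_at": time.time(),
        }
        return
    state["alerts"].append(error_msg)
    state["last_updated_at"] = time.time()
    if not state["minor_alert_sent"] and len(state["alerts"]) > MINOR_THRESHOLD:
        print(f"MINOR outage: {provider} {region}")  # real code: send_alert(level='Medium')
        state["minor_alert_sent"] = True
    elif not state["major_alert_sent"] and len(state["alerts"]) > MAJOR_THRESHOLD:
        print(f"MAJOR outage: {provider} {region}")  # real code: send_alert(level='High')
        state["major_alert_sent"] = True


for _ in range(12):
    record_failure("azure", "east-us-1", "503 Service Unavailable")
```

With twelve consecutive failures, the sketch fires the minor alert on the 4th failure and the major alert on the 11th — the same escalation the diff below implements.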


@@ -1,6 +1,6 @@
 #### What this does ####
 # Class for sending Slack Alerts #
-import dotenv, os
+import dotenv, os, traceback
 from litellm.proxy._types import UserAPIKeyAuth, CallInfo
 from litellm._logging import verbose_logger, verbose_proxy_logger
 import litellm, threading
@@ -15,6 +15,35 @@ from enum import Enum
 from datetime import datetime as dt, timedelta, timezone
 from litellm.integrations.custom_logger import CustomLogger
 import random
+from typing import TypedDict
+from openai import APIError
+import litellm.types
+import litellm.types.router
+
+
+class OutageModel(TypedDict):
+    provider: str
+    region_name: str
+    alerts: List[str]
+    deployment_ids: List[str]
+    minor_alert_sent: bool
+    major_alert_sent: bool
+    last_updated_at: float
+
+
+AlertType = Literal[
+    "llm_exceptions",
+    "llm_too_slow",
+    "llm_requests_hanging",
+    "budget_alerts",
+    "db_exceptions",
+    "daily_reports",
+    "spend_reports",
+    "cooldown_deployment",
+    "new_model_added",
+    "outage_alerts",
+]
+
+
 class LiteLLMBase(BaseModel):
@@ -37,6 +66,10 @@ class SlackAlertingArgs(LiteLLMBase):
     )
     report_check_interval: int = 5 * 60  # 5 minutes
     budget_alert_ttl: int = 24 * 60 * 60  # 24 hours
+    outage_alert_ttl: int = 1 * 60 * 60  # 1 hour
+    minor_outage_alert_threshold: int = 3
+    major_outage_alert_threshold: int = 10
+    max_outage_alert_list_size: int = 10  # prevent memory leak
 
 
 class WebhookEvent(CallInfo):
@@ -86,19 +119,7 @@ class SlackAlerting(CustomLogger):
         internal_usage_cache: Optional[DualCache] = None,
         alerting_threshold: float = 300,  # threshold for slow / hanging llm responses (in seconds)
         alerting: Optional[List] = [],
-        alert_types: List[
-            Literal[
-                "llm_exceptions",
-                "llm_too_slow",
-                "llm_requests_hanging",
-                "budget_alerts",
-                "db_exceptions",
-                "daily_reports",
-                "spend_reports",
-                "cooldown_deployment",
-                "new_model_added",
-            ]
-        ] = [
+        alert_types: List[AlertType] = [
             "llm_exceptions",
             "llm_too_slow",
             "llm_requests_hanging",
@@ -108,6 +129,7 @@ class SlackAlerting(CustomLogger):
             "spend_reports",
             "cooldown_deployment",
             "new_model_added",
+            "outage_alerts",
         ],
         alert_to_webhook_url: Optional[
             Dict
@@ -696,6 +718,99 @@ class SlackAlerting(CustomLogger):
                 return
         return
 
+    async def outage_alerts(
+        self,
+        provider: str,
+        region_name: str,
+        exception: APIError,
+        deployment_id: str,
+    ) -> None:
+        """
+        Send a slack alert if a provider region (e.g. azure east-us-1) is having an outage (408 or >=500 errors).
+
+        key = (provider + region)
+
+        value = {
+        - provider
+        - region
+        - threshold
+        - alerts []
+        }
+
+        ttl = 1hr
+        max_alerts_size = 10
+        """
+        _id = provider + region_name
+
+        outage_value: Optional[OutageModel] = await self.internal_usage_cache.async_get_cache(key=_id)  # type: ignore
+
+        if (
+            getattr(exception, "status_code", None) is not None
+            and exception.status_code != 408  # type: ignore
+            and exception.status_code < 500  # type: ignore
+        ):
+            return
+
+        if outage_value is None:
+            outage_value = OutageModel(
+                provider=provider,
+                region_name=region_name,
+                alerts=[exception.message],
+                deployment_ids=[deployment_id],
+                minor_alert_sent=False,
+                major_alert_sent=False,
+                last_updated_at=time.time(),
+            )
+
+            ## add to cache ##
+            await self.internal_usage_cache.async_set_cache(
+                key=_id, value=outage_value, ttl=self.alerting_args.outage_alert_ttl
+            )
+            return
+
+        outage_value["alerts"].append(exception.message)
+        outage_value["deployment_ids"].append(deployment_id)
+        outage_value["last_updated_at"] = time.time()
+
+        ## MINOR OUTAGE ALERT SENT ##
+        if (
+            outage_value["minor_alert_sent"] == False
+            and len(outage_value["alerts"])
+            > self.alerting_args.minor_outage_alert_threshold
+        ):
+            msg = "{} {} is having a **Minor Service Outage**.\n\n**Errors**\n{}\n\nLast Check: {}".format(
+                provider,
+                region_name,
+                outage_value["alerts"],
+                outage_value["last_updated_at"],
+            )
+            # send minor alert
+            _result_val = self.send_alert(
+                message=msg, level="Medium", alert_type="outage_alerts"
+            )
+            if _result_val is not None:
+                await _result_val
+            # set to true
+            outage_value["minor_alert_sent"] = True
+        ## MAJOR OUTAGE ALERT SENT ##
+        elif (
+            outage_value["major_alert_sent"] == False
+            and len(outage_value["alerts"])
+            > self.alerting_args.major_outage_alert_threshold
+        ):
+            msg = "{} {} is having a **Major Service Outage**.\n\n**Errors**\n{}\n\nLast Check: {}".format(
+                provider,
+                region_name,
+                outage_value["alerts"],
+                outage_value["last_updated_at"],
+            )
+            # send major alert
+            await self.send_alert(message=msg, level="High", alert_type="outage_alerts")
+            # set to true
+            outage_value["major_alert_sent"] = True
+
+        ## update cache ##
+        await self.internal_usage_cache.async_set_cache(key=_id, value=outage_value)
+
     async def model_added_alert(
         self, model_name: str, litellm_model_name: str, passed_model_info: Any
     ):
@@ -745,10 +860,12 @@ Model Info:
 ```
 """
 
-        await self.send_alert(
+        alert_val = self.send_alert(
             message=message, level="Low", alert_type="new_model_added"
         )
-        pass
+
+        if alert_val is not None and asyncio.iscoroutine(alert_val):
+            await alert_val
 
     async def model_removed_alert(self, model_name: str):
         pass
@@ -795,6 +912,7 @@ Model Info:
             "spend_reports",
             "new_model_added",
             "cooldown_deployment",
+            "outage_alerts",
         ],
         user_info: Optional[WebhookEvent] = None,
         **kwargs,
@@ -910,10 +1028,12 @@ Model Info:
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         """Log failure + deployment latency"""
+        _litellm_params = kwargs.get("litellm_params", {})
+        _model_info = _litellm_params.get("model_info", {}) or {}
+        model_id = _model_info.get("id", "")
+        try:
             if "daily_reports" in self.alert_types:
-            model_id = (
-                kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
-            )
+                try:
                     await self.async_update_daily_reports(
                         DeploymentMetrics(
                             id=model_id,
@@ -922,6 +1042,41 @@ Model Info:
                             updated_at=litellm.utils.get_utc_datetime(),
                         )
                     )
+                except Exception as e:
+                    verbose_logger.debug(f"Exception raises -{str(e)}")
+
+            if "outage_alerts" in self.alert_types and isinstance(
+                kwargs.get("exception", ""), APIError
+            ):
+                _litellm_params = litellm.types.router.LiteLLM_Params(
+                    model=kwargs.get("model", ""),
+                    **kwargs.get("litellm_params", {}),
+                    **kwargs.get("optional_params", {}),
+                )
+                _region_name = litellm.utils._get_model_region(
+                    custom_llm_provider=kwargs.get("custom_llm_provider", ""),
+                    litellm_params=_litellm_params,
+                )
+                # if region name not known, default to api base #
+                if _region_name is None:
+                    _region_name = litellm.get_api_base(
+                        model=kwargs.get("model", ""),
+                        optional_params={
+                            **kwargs.get("litellm_params", {}),
+                            **kwargs.get("optional_params", {}),
+                        },
+                    )
+                if _region_name is None:
+                    _region_name = ""
+
+                await self.outage_alerts(
+                    provider=kwargs.get("custom_llm_provider", "") or "",
+                    region_name=_region_name,
+                    exception=kwargs["exception"],
+                    deployment_id=model_id,
+                )
+        except Exception as e:
+            pass
 
     async def _run_scheduler_helper(self, llm_router) -> bool:
         """


@@ -420,6 +420,8 @@ def mock_completion(
             api_key="mock-key",
         )
     if isinstance(mock_response, Exception):
+        if isinstance(mock_response, openai.APIError):
+            raise mock_response
         raise litellm.APIError(
             status_code=500,  # type: ignore
             message=str(mock_response),
@@ -463,7 +465,9 @@
         return model_response
-    except:
+    except Exception as e:
+        if isinstance(e, openai.APIError):
+            raise e
         traceback.print_exc()
         raise Exception("Mock completion response failed")
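The two mock_completion hunks above let injected openai.APIError subclasses propagate unchanged instead of being wrapped in a generic 500 APIError; this is what allows the outage test further below to drive typed provider failures through litellm's mock_response parameter. A hedged usage sketch — the Timeout constructor kwargs are copied from that test; treat the exact call as an assumption:

```python
# Sketch: a typed error passed via mock_response now surfaces as-is,
# because litellm.Timeout subclasses openai.APIError.
import litellm

try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey!"}],
        mock_response=litellm.Timeout(
            message="A timeout occurred",
            model="gpt-3.5-turbo",
            llm_provider="openai",
        ),
    )
except litellm.Timeout as e:
    print(f"re-raised unchanged: {e}")  # previously: a wrapped 500 litellm.APIError
```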


@@ -42,7 +42,7 @@ import smtplib, re
 from email.mime.text import MIMEText
 from email.mime.multipart import MIMEMultipart
 from datetime import datetime, timedelta
-from litellm.integrations.slack_alerting import SlackAlerting
+from litellm.integrations.slack_alerting import SlackAlerting, AlertType
 from typing_extensions import overload
@@ -78,19 +78,7 @@ class ProxyLogging:
         self.cache_control_check = _PROXY_CacheControlCheck()
         self.alerting: Optional[List] = None
         self.alerting_threshold: float = 300  # default to 5 min. threshold
-        self.alert_types: List[
-            Literal[
-                "llm_exceptions",
-                "llm_too_slow",
-                "llm_requests_hanging",
-                "budget_alerts",
-                "db_exceptions",
-                "daily_reports",
-                "spend_reports",
-                "cooldown_deployment",
-                "new_model_added",
-            ]
-        ] = [
+        self.alert_types: List[AlertType] = [
             "llm_exceptions",
             "llm_too_slow",
             "llm_requests_hanging",
@@ -100,6 +88,7 @@ class ProxyLogging:
             "spend_reports",
             "cooldown_deployment",
             "new_model_added",
+            "outage_alerts",
         ]
         self.slack_alerting_instance = SlackAlerting(
             alerting_threshold=self.alerting_threshold,
@@ -113,21 +102,7 @@
         alerting: Optional[List],
         alerting_threshold: Optional[float],
         redis_cache: Optional[RedisCache],
-        alert_types: Optional[
-            List[
-                Literal[
-                    "llm_exceptions",
-                    "llm_too_slow",
-                    "llm_requests_hanging",
-                    "budget_alerts",
-                    "db_exceptions",
-                    "daily_reports",
-                    "spend_reports",
-                    "cooldown_deployment",
-                    "new_model_added",
-                ]
-            ]
-        ] = None,
+        alert_types: Optional[List[AlertType]] = None,
         alerting_args: Optional[dict] = None,
     ):
         self.alerting = alerting
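With the Literal union now centralized in the AlertType alias, both the SlackAlerting constructor and the ProxyLogging setter above accept the new category without repeating the full list. A configuration sketch based on the signatures in the diffs above (Slack delivery assumes SLACK_WEBHOOK_URL is set in the environment — litellm's usual mechanism, not shown in this commit):

```python
# Sketch: opting into only the new alert category; arguments mirror the
# SlackAlerting signature shown above.
from litellm.integrations.slack_alerting import SlackAlerting

slack_alerting = SlackAlerting(
    alerting=["slack"],             # or ["webhook"], as in the test below
    alert_types=["outage_alerts"],  # skip the other alert categories
)
```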


@@ -1,10 +1,11 @@
 # What is this?
 ## Tests slack alerting on proxy logging object
-import sys, json, uuid, random
+import sys, json, uuid, random, httpx
 import os
 import io, asyncio
 from datetime import datetime, timedelta
+from typing import Optional
 
 # import logging
 # logging.basicConfig(level=logging.DEBUG)
@@ -23,6 +24,7 @@ from unittest.mock import AsyncMock
 import pytest
 from litellm.router import AlertingConfig, Router
 from litellm.proxy._types import CallInfo
+from openai import APIError
 
 
 @pytest.mark.parametrize(
@@ -495,3 +497,96 @@ async def test_webhook_alerting(alerting_type):
             user_info=user_info,
         )
         mock_send_alert.assert_awaited_once()
+
+
+@pytest.mark.parametrize(
+    "model, api_base, llm_provider, vertex_project, vertex_location",
+    [
+        ("gpt-3.5-turbo", None, "openai", None, None),
+        (
+            "azure/gpt-3.5-turbo",
+            "https://openai-gpt-4-test-v-1.openai.azure.com",
+            "azure",
+            None,
+            None,
+        ),
+        ("gemini-pro", None, "vertex_ai", "hardy-device-38811", "us-central1"),
+    ],
+)
+@pytest.mark.parametrize("error_code", [500, 408, 400])
+@pytest.mark.asyncio
+async def test_outage_alerting_called(
+    model, api_base, llm_provider, vertex_project, vertex_location, error_code
+):
+    """
+    If a call fails, the outage alert handler is called.
+    If multiple calls fail, an outage alert is sent.
+    """
+    slack_alerting = SlackAlerting(alerting=["webhook"])
+
+    litellm.callbacks = [slack_alerting]
+
+    error_to_raise: Optional[APIError] = None
+
+    if error_code == 400:
+        print("RAISING 400 ERROR CODE")
+        error_to_raise = litellm.BadRequestError(
+            message="this is a bad request",
+            model=model,
+            llm_provider=llm_provider,
+        )
+    elif error_code == 408:
+        print("RAISING 408 ERROR CODE")
+        error_to_raise = litellm.Timeout(
+            message="A timeout occurred", model=model, llm_provider=llm_provider
+        )
+    elif error_code == 500:
+        print("RAISING 500 ERROR CODE")
+        error_to_raise = litellm.ServiceUnavailableError(
+            message="API is unavailable",
+            model=model,
+            llm_provider=llm_provider,
+            response=httpx.Response(
+                status_code=503,
+                request=httpx.Request(
+                    method="completion",
+                    url="https://github.com/BerriAI/litellm",
+                ),
+            ),
+        )
+    with patch.object(
+        slack_alerting, "outage_alerts", new=AsyncMock()
+    ) as mock_send_alert:
+        try:
+            await litellm.acompletion(
+                model=model,
+                messages=[{"role": "user", "content": "Hey!"}],
+                api_base=api_base,
+                vertex_location=vertex_location,
+                vertex_project=vertex_project,
+                mock_response=error_to_raise,
+            )
+        except Exception as e:
+            pass
+
+        mock_send_alert.assert_called_once()
+
+    with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
+        for _ in range(3):
+            try:
+                await litellm.acompletion(
+                    model=model,
+                    messages=[{"role": "user", "content": "Hey!"}],
+                    api_base=api_base,
+                    vertex_location=vertex_location,
+                    vertex_project=vertex_project,
+                    mock_response=error_to_raise,
+                )
+            except Exception as e:
+                pass
+
+        if error_code == 500 or error_code == 408:
+            mock_send_alert.assert_called_once()
+        else:
+            mock_send_alert.assert_not_called()
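The new test runs 3 models × 3 error codes and asserts that send_alert fires only for the 408/500 cases; 400s are filtered out by the status-code guard in outage_alerts. Locally, the matrix can be selected by name with `pytest -k test_outage_alerting_called`.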


@@ -8632,7 +8632,16 @@
             )
         elif hasattr(original_exception, "status_code"):
             exception_mapping_worked = True
-            if original_exception.status_code == 401:
+            if original_exception.status_code == 400:
+                exception_mapping_worked = True
+                raise BadRequestError(
+                    message=f"{exception_provider} - {message}",
+                    llm_provider=custom_llm_provider,
+                    model=model,
+                    response=original_exception.response,
+                    litellm_debug_info=extra_information,
+                )
+            elif original_exception.status_code == 401:
                 exception_mapping_worked = True
                 raise AuthenticationError(
                     message=f"{exception_provider} - {message}",
@@ -9145,6 +9154,7 @@
                 ),
             ),
         )
+
     if hasattr(original_exception, "status_code"):
         if original_exception.status_code == 400:
             exception_mapping_worked = True
@@ -9825,7 +9835,16 @@
             )
         elif hasattr(original_exception, "status_code"):
             exception_mapping_worked = True
-            if original_exception.status_code == 401:
+            if original_exception.status_code == 400:
+                exception_mapping_worked = True
+                raise BadRequestError(
+                    message=f"AzureException - {original_exception.message}",
+                    llm_provider="azure",
+                    model=model,
+                    litellm_debug_info=extra_information,
+                    response=original_exception.response,
+                )
+            elif original_exception.status_code == 401:
                 exception_mapping_worked = True
                 raise AuthenticationError(
                     message=f"AzureException - {original_exception.message}",
@@ -9842,7 +9861,7 @@
                     litellm_debug_info=extra_information,
                     llm_provider="azure",
                 )
-            if original_exception.status_code == 422:
+            elif original_exception.status_code == 422:
                 exception_mapping_worked = True
                 raise BadRequestError(
                     message=f"AzureException - {original_exception.message}",