[Feat-Proxy] Slack Alerting - allow using os.environ/ vars for alert to webhook url (#5726)

* allow using os.environ for slack urls

* use env vars for webhook urls

* fix types for get_secret

* fix linting

* fix linting

* fix linting

* linting fixes

* linting fix

* docs alerting slack

* fix get data
Ishaan Jaff 2024-09-16 18:03:37 -07:00 committed by GitHub
parent 8103e2b2da
commit b6ae2204a8
23 changed files with 286 additions and 84 deletions


@ -1,4 +1,6 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 🚨 Alerting / Webhooks
@ -149,6 +151,10 @@ spend_reports -> go to slack channel #llm-spend-reports
Set `alert_to_webhook_url` on your config.yaml
<Tabs>
<TabItem label="1 channel per alert" value="1">
```yaml
model_list:
- model_name: gpt-4
@ -177,6 +183,44 @@ general_settings:
litellm_settings:
success_callback: ["langfuse"]
```
</TabItem>
<TabItem label="multiple channels per alert" value="2">
Provide multiple slack channels for a given alert type
```yaml
model_list:
- model_name: gpt-4
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
general_settings:
master_key: sk-1234
alerting: ["slack"]
alerting_threshold: 0.0001 # (Seconds) set an artificially low threshold for testing alerting
alert_to_webhook_url: {
"llm_exceptions": ["os.environ/SLACK_WEBHOOK_URL", "os.environ/SLACK_WEBHOOK_URL_2"],
"llm_too_slow": ["https://webhook.site/7843a980-a494-4967-80fb-d502dbc16886", "https://webhook.site/28cfb179-f4fb-4408-8129-729ff55cf213"],
"llm_requests_hanging": ["os.environ/SLACK_WEBHOOK_URL_5", "os.environ/SLACK_WEBHOOK_URL_6"],
"budget_alerts": ["os.environ/SLACK_WEBHOOK_URL_7", "os.environ/SLACK_WEBHOOK_URL_8"],
"db_exceptions": ["os.environ/SLACK_WEBHOOK_URL_9", "os.environ/SLACK_WEBHOOK_URL_10"],
"daily_reports": ["os.environ/SLACK_WEBHOOK_URL_11", "os.environ/SLACK_WEBHOOK_URL_12"],
"spend_reports": ["os.environ/SLACK_WEBHOOK_URL_13", "os.environ/SLACK_WEBHOOK_URL_14"],
"cooldown_deployment": ["os.environ/SLACK_WEBHOOK_URL_15", "os.environ/SLACK_WEBHOOK_URL_16"],
"new_model_added": ["os.environ/SLACK_WEBHOOK_URL_17", "os.environ/SLACK_WEBHOOK_URL_18"],
"outage_alerts": ["os.environ/SLACK_WEBHOOK_URL_19", "os.environ/SLACK_WEBHOOK_URL_20"],
}
litellm_settings:
success_callback: ["langfuse"]
```
</TabItem>
</Tabs>
Test it - send a valid LLM request and expect to see an `llm_too_slow` alert in its own slack channel
@ -193,36 +237,6 @@ curl -i http://localhost:4000/v1/chat/completions \
```
### Provide multiple slack channels for a given alert type
Just add it like this - `alert_type: [<hook_url_channel_1>, <hook_url_channel_2>]`.
1. Setup config.yaml
```yaml
general_settings:
master_key: sk-1234
alerting: ["slack"]
alert_to_webhook_url: {
"spend_reports": ["https://webhook.site/7843a980-a494-4967-80fb-d502dbc16886", "https://webhook.site/28cfb179-f4fb-4408-8129-729ff55cf213"]
}
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -X GET 'http://0.0.0.0:4000/health/services?service=slack' \
-H 'Authorization: Bearer sk-1234'
```
In case of error, check server logs for the error message!
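For reference, the same test alerts can be triggered from Python. This is a minimal sketch, assuming the proxy from the config above is running on `localhost:4000` with master key `sk-1234`:

```python
# Calls the proxy's /health/services endpoint, which sends a test alert
# to the webhook(s) configured for each alert type.
import httpx

response = httpx.get(
    "http://localhost:4000/health/services",
    params={"service": "slack"},
    headers={"Authorization": "Bearer sk-1234"},
)
print(response.status_code, response.text)  # then check the Slack channels for the test messages
```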
### Using MS Teams Webhooks
MS Teams provides a Slack-compatible webhook URL that you can use for alerting
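A hypothetical sketch of wiring a Teams webhook in via an env var reference, the same way as the Slack URLs above (the variable name and URL are placeholders, not part of this commit):

```python
# Export the Teams incoming-webhook URL, then reference it from config.yaml, e.g.
#   alert_to_webhook_url: {"spend_reports": ["os.environ/TEAMS_ALERT_WEBHOOK_URL"]}
import os

os.environ["TEAMS_ALERT_WEBHOOK_URL"] = "https://<org>.webhook.office.com/webhookb2/<id>"  # placeholder
```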


@ -57,7 +57,7 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal["completion", "embeddings", "image_generation", "moderation", "audio_transcription"],
):
pass
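The widened `call_type` literal means custom handlers now also see moderation and audio-transcription calls. A minimal sketch of a handler that branches on it (the class and branching below are illustrative, not from this commit):

```python
from typing import Literal, Optional


class MyModerationHandler:  # a real handler would subclass litellm's CustomLogger
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict,  # UserAPIKeyAuth in the real signature
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ) -> Optional[bool]:
        if call_type in ("completion", "moderation"):
            # inspect data["messages"] / data["input"] here
            return None
        # nothing to check for the remaining call types
        return None
```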


@ -84,7 +84,7 @@ class myCustomGuardrail(CustomGuardrail):
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal["completion", "embeddings", "image_generation", "moderation", "audio_transcription"],
):
"""
Runs in parallel to LLM API call


@ -174,7 +174,13 @@ class AporiaGuardrail(CustomGuardrail):
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,


@ -97,7 +97,13 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
"""
- Calls Google's Text Moderation API


@ -100,7 +100,13 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
"""
- Calls the Llama Guard Endpoint


@ -127,7 +127,13 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
"""
- Calls the LLM Guard Endpoint


@ -43,7 +43,13 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
if "messages" in data and isinstance(data["messages"], list):
text = ""


@ -41,6 +41,7 @@ from litellm.types.router import LiteLLM_Params
from ..email_templates.templates import *
from .batching_handler import send_to_webhook, squash_payloads
from .types import *
from .utils import process_slack_alerting_variables
class SlackAlerting(CustomBatchLogger):
@ -70,7 +71,7 @@ class SlackAlerting(CustomBatchLogger):
"outage_alerts",
],
alert_to_webhook_url: Optional[
Dict
Dict[AlertType, Union[List[str], str]]
] = None, # if user wants to separate alerts to diff channels
alerting_args={},
default_webhook_url: Optional[str] = None,
@ -85,7 +86,9 @@ class SlackAlerting(CustomBatchLogger):
self.async_http_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
self.alert_to_webhook_url = alert_to_webhook_url
self.alert_to_webhook_url = process_slack_alerting_variables(
alert_to_webhook_url=alert_to_webhook_url
)
self.is_running = False
self.alerting_args = SlackAlertingArgs(**alerting_args)
self.default_webhook_url = default_webhook_url
@ -97,7 +100,7 @@ class SlackAlerting(CustomBatchLogger):
alerting: Optional[List] = None,
alerting_threshold: Optional[float] = None,
alert_types: Optional[List] = None,
alert_to_webhook_url: Optional[Dict] = None,
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]] = None,
alerting_args: Optional[Dict] = None,
llm_router: Optional[litellm.Router] = None,
):
@ -113,9 +116,17 @@ class SlackAlerting(CustomBatchLogger):
if alert_to_webhook_url is not None:
# update the dict
if self.alert_to_webhook_url is None:
self.alert_to_webhook_url = alert_to_webhook_url
self.alert_to_webhook_url = process_slack_alerting_variables(
alert_to_webhook_url=alert_to_webhook_url
)
else:
self.alert_to_webhook_url.update(alert_to_webhook_url)
_new_values = (
process_slack_alerting_variables(
alert_to_webhook_url=alert_to_webhook_url
)
or {}
)
self.alert_to_webhook_url.update(_new_values)
if llm_router is not None:
self.llm_router = llm_router
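Both `__init__` and `update_values` now pass the mapping through `process_slack_alerting_variables`, so `os.environ/...` entries are resolved up front and later updates are merged into the existing dict. A rough sketch of that behavior (import path follows this commit's layout and may differ by version; the webhook values are hypothetical):

```python
import os

from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting

# the referenced env var must exist, otherwise resolution fails
os.environ["SLACK_WEBHOOK_URL"] = "https://hooks.slack.com/services/T000/B000/XXXX"

alerting = SlackAlerting(
    alert_to_webhook_url={"llm_exceptions": ["os.environ/SLACK_WEBHOOK_URL"]},
)

# later updates (e.g. when the proxy re-reads its config) are resolved and merged in
alerting.update_values(
    alert_to_webhook_url={"llm_too_slow": "os.environ/SLACK_WEBHOOK_URL"},
)
print(alerting.alert_to_webhook_url)  # both alert types present, env references replaced
```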


@ -0,0 +1,51 @@
"""
Utils used for slack alerting
"""
from typing import Dict, List, Optional, Union
import litellm
from litellm.proxy._types import AlertType
from litellm.secret_managers.main import get_secret
def process_slack_alerting_variables(
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
) -> Optional[Dict[AlertType, Union[List[str], str]]]:
"""
Process `alert_to_webhook_url`:
- if any url is given as `os.environ/<VAR>` (e.g. os.environ/SLACK_WEBHOOK_URL_1), read the env var and replace the reference with its value
"""
if alert_to_webhook_url is None:
return None
for alert_type, webhook_urls in alert_to_webhook_url.items():
if isinstance(webhook_urls, list):
_webhook_values: List[str] = []
for webhook_url in webhook_urls:
if "os.environ/" in webhook_url:
_env_value = get_secret(secret_name=webhook_url)
if not isinstance(_env_value, str):
raise ValueError(
f"Invalid webhook url value for: {webhook_url}. Got type={type(_env_value)}"
)
_webhook_values.append(_env_value)
else:
_webhook_values.append(webhook_url)
alert_to_webhook_url[alert_type] = _webhook_values
else:
_webhook_value_str: str = webhook_urls
if "os.environ/" in webhook_urls:
_env_value = get_secret(secret_name=webhook_urls)
if not isinstance(_env_value, str):
raise ValueError(
f"Invalid webhook url value for: {webhook_urls}. Got type={type(_env_value)}"
)
_webhook_value_str = _env_value
else:
_webhook_value_str = webhook_urls
alert_to_webhook_url[alert_type] = _webhook_value_str
return alert_to_webhook_url
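A quick sketch of what the helper does with list and single-string values, assuming the referenced env var is set (import path per this commit's layout; values hypothetical). Plain URLs pass through unchanged:

```python
import os

from litellm.integrations.SlackAlerting.utils import process_slack_alerting_variables

os.environ["SLACK_WEBHOOK_URL_1"] = "https://hooks.slack.com/services/T000/B000/XXXX"  # hypothetical

resolved = process_slack_alerting_variables(
    alert_to_webhook_url={
        "llm_exceptions": ["os.environ/SLACK_WEBHOOK_URL_1"],
        "db_exceptions": "https://webhook.site/7843a980-a494-4967-80fb-d502dbc16886",
    }
)
# {"llm_exceptions": ["https://hooks.slack.com/..."], "db_exceptions": "https://webhook.site/..."}
print(resolved)
```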


@ -17,6 +17,11 @@ class CustomGuardrail(CustomLogger):
self.event_hook: Optional[GuardrailEventHooks] = event_hook
super().__init__(**kwargs)
# older v1 implementation - not used, just kept for backward compatibility
self.moderation_check: Literal["pre_call", "in_parallel"] = kwargs.get(
"moderation_check", "pre_call"
)
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
metadata = data.get("metadata") or {}
requested_guardrails = metadata.get("guardrails") or []


@ -151,7 +151,13 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
pass


@ -52,7 +52,13 @@ class MyCustomHandler(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
pass


@ -61,7 +61,13 @@ class myCustomGuardrail(CustomGuardrail):
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
"""
Runs in parallel to LLM API call


@ -61,7 +61,13 @@ class myCustomGuardrail(CustomGuardrail):
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
"""
Runs in parallel to LLM API call


@ -177,7 +177,13 @@ class AporiaGuardrail(CustomGuardrail):
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,


@ -218,7 +218,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
response = await self.async_handler.post(
url=prepared_request.url,
json=request_data, # type: ignore
headers=prepared_request.headers,
headers=dict(prepared_request.headers),
)
verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
if response.status_code == 200:
@ -243,7 +243,13 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,


@ -61,7 +61,13 @@ class myCustomGuardrail(CustomGuardrail):
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
"""
Runs in parallel to LLM API call


@ -143,6 +143,7 @@ class lakeraAI_Moderation(CustomGuardrail):
):
return
text = ""
_json_data: str = ""
if "messages" in data and isinstance(data["messages"], list):
prompt_injection_obj: Optional[GuardrailItem] = (
litellm.guardrail_name_config_map.get("prompt_injection")
@ -320,7 +321,13 @@ class lakeraAI_Moderation(CustomGuardrail):
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
if self.event_hook is None:
if self.moderation_check == "pre_call":


@ -174,6 +174,7 @@ async def health_services_endpoint(
not in proxy_logging_obj.slack_alerting_instance.alert_types
):
continue
test_message = "default test message"
if alert_type == "llm_exceptions":
test_message = f"LLM Exception test alert"
@ -189,6 +190,8 @@ async def health_services_endpoint(
test_message = f"Outage Alert Exception test alert"
elif alert_type == "daily_reports":
test_message = f"Daily Reports test alert"
else:
test_message = f"Budget Alert test alert"
await proxy_logging_obj.alerting_handler(
message=test_message, level="Low", alert_type=alert_type
@ -354,7 +357,7 @@ async def health_endpoint(
db_health_cache = {"status": "unknown", "last_updated": datetime.now()}
def _db_health_readiness_check():
async def _db_health_readiness_check():
from litellm.proxy.proxy_server import prisma_client
global db_health_cache
@ -365,7 +368,12 @@ def _db_health_readiness_check():
time_diff = datetime.now() - db_health_cache["last_updated"]
if db_health_cache["status"] != "unknown" and time_diff < timedelta(minutes=2):
return db_health_cache
prisma_client.health_check()
if prisma_client is None:
db_health_cache = {"status": "disconnected", "last_updated": datetime.now()}
return db_health_cache
await prisma_client.health_check()
db_health_cache = {"status": "connected", "last_updated": datetime.now()}
return db_health_cache
@ -478,7 +486,7 @@ async def health_readiness():
# check DB
if prisma_client is not None: # if db passed in, check if it's connected
db_health_status = _db_health_readiness_check()
db_health_status = await _db_health_readiness_check()
return {
"status": "healthy",
"db": "connected",


@ -7,18 +7,22 @@
## Reject a call if it contains a prompt injection attack.
from typing import Optional, Literal
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth, LiteLLMPromptInjectionParams
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_proxy_logger
from litellm.utils import get_formatted_prompt
from litellm.llms.prompt_templates.factory import prompt_injection_detection_default_pt
from fastapi import HTTPException
import json, traceback, re
import json
import re
import traceback
from difflib import SequenceMatcher
from typing import List
from typing import List, Literal, Optional
from fastapi import HTTPException
from typing_extensions import overload
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.prompt_templates.factory import prompt_injection_detection_default_pt
from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth
from litellm.utils import get_formatted_prompt
class _OPTIONAL_PromptInjectionDetection(CustomLogger):
@ -201,7 +205,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
and self.prompt_injection_params is not None
and self.prompt_injection_params.reject_as_response
):
return e.detail["error"]
return e.detail.get("error")
raise e
except Exception as e:
verbose_proxy_logger.error(
@ -211,18 +215,24 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
)
verbose_proxy_logger.debug(traceback.format_exc())
async def async_moderation_hook(
async def async_moderation_hook( # type: ignore
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
) -> Optional[bool]:
self.print_verbose(
f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}"
)
if self.prompt_injection_params is None:
return
return None
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
is_prompt_attack = False


@ -20,9 +20,20 @@ model_list:
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app
general_settings:
master_key: sk-1234
alerting: ["slack"]
alerting_threshold: 0.0001 # (Seconds) set an artificially low threshold for testing alerting
alert_to_webhook_url: {
"llm_too_slow": [
"os.environ/SLACK_WEBHOOK_URL",
"os.environ/SLACK_WEBHOOK_URL_2",
],
}
key_management_system: "azure_key_vault"
litellm_settings:
success_callback: ["prometheus"]


@ -93,6 +93,7 @@ def safe_deep_copy(data):
return data
# Step 1: Remove the litellm_parent_otel_span
litellm_parent_otel_span = None
if isinstance(data, dict):
# remove litellm_parent_otel_span since this is not picklable
if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
@ -519,13 +520,7 @@ class ProxyLogging:
self,
message: str,
level: Literal["Low", "Medium", "High"],
alert_type: Literal[
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
],
alert_type: AlertType,
request_data: Optional[dict] = None,
):
"""
@ -1302,6 +1297,7 @@ class PrismaClient:
table_name is not None and table_name == "key"
):
# check if plain text or hash
hashed_token = None
if token is not None:
if isinstance(token, str):
hashed_token = token
@ -1712,7 +1708,7 @@ class PrismaClient:
updated_table_row = self.db.litellm_config.upsert(
where={"param_name": k},
data={
"create": {"param_name": k, "param_value": updated_data},
"create": {"param_name": k, "param_value": updated_data}, # type: ignore
"update": {"param_value": updated_data},
},
)
@ -2265,11 +2261,15 @@ class DBClient:
"""
For closing connection on server shutdown
"""
return await self.db.disconnect()
if self.db is not None:
return await self.db.disconnect() # type: ignore
return asyncio.sleep(0) # Return a dummy coroutine if self.db is None
### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
module_name = value
instance_name = None
try:
print_verbose(f"value: {value}")
# Split the path by dots to separate module from instance
@ -2363,6 +2363,15 @@ async def send_email(receiver_email, subject, html):
"sending email from %s to %s", sender_email, receiver_email
)
if smtp_host is None:
raise ValueError("Trying to use SMTP, but SMTP_HOST is not set")
if smtp_username is None:
raise ValueError("Trying to use SMTP, but SMTP_USERNAME is not set")
if smtp_password is None:
raise ValueError("Trying to use SMTP, but SMTP_PASSWORD is not set")
# Attach the body to the email
email_message.attach(MIMEText(html, "html"))
@ -2555,6 +2564,7 @@ async def update_spend(
spend_logs: list,
"""
n_retry_times = 3
i = None
### UPDATE USER TABLE ###
if len(prisma_client.user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
@ -2930,6 +2940,8 @@ async def update_spend(
)
break
except httpx.ReadTimeout:
if i is None:
i = 0
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
@ -3044,10 +3056,11 @@ def get_error_message_str(e: Exception) -> str:
elif isinstance(e.detail, dict):
error_message = json.dumps(e.detail)
elif hasattr(e, "message"):
if isinstance(e.message, "str"):
error_message = e.message
elif isinstance(e.message, dict):
error_message = json.dumps(e.message)
_error = getattr(e, "message", None)
if isinstance(_error, str):
error_message = _error
elif isinstance(_error, dict):
error_message = json.dumps(_error)
else:
error_message = str(e)
else: