(Feat) soft budget alerts on keys (#7623)

* Add `soft_budget_crossed` event to class WebhookEvent(CallInfo)

* handle soft budget alerts

* handle soft budget

* fix budget alerts

* fix CallInfo

* fix _get_user_info_str

* test_soft_budget_alerts

* test_soft_budget_alert
This commit is contained in:
Ishaan Jaff 2025-01-07 21:36:34 -08:00 committed by GitHub
parent 4e69711411
commit 081826a5d6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 201 additions and 33 deletions

View file

@ -570,6 +570,7 @@ class SlackAlerting(CustomBatchLogger):
self,
type: Literal[
"token_budget",
"soft_budget",
"user_budget",
"team_budget",
"proxy_budget",
@ -590,12 +591,14 @@ class SlackAlerting(CustomBatchLogger):
return
_id: Optional[str] = "default_id" # used for caching
user_info_json = user_info.model_dump(exclude_none=True)
user_info_str = ""
for k, v in user_info_json.items():
user_info_str = "\n{}: {}\n".format(k, v)
user_info_str = self._get_user_info_str(user_info)
event: Optional[
Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
Literal[
"budget_crossed",
"threshold_crossed",
"projected_limit_exceeded",
"soft_budget_crossed",
]
] = None
event_group: Optional[
Literal["internal_user", "team", "key", "proxy", "customer"]
@ -605,6 +608,9 @@ class SlackAlerting(CustomBatchLogger):
if type == "proxy_budget":
event_group = "proxy"
event_message += "Proxy Budget: "
elif type == "soft_budget":
event_group = "proxy"
event_message += "Soft Budget Crossed: "
elif type == "user_budget":
event_group = "internal_user"
event_message += "User Budget: "
@ -624,27 +630,31 @@ class SlackAlerting(CustomBatchLogger):
_id = user_info.token
# percent of max_budget left to spend
if user_info.max_budget is None:
if user_info.max_budget is None and user_info.soft_budget is None:
return
percent_left: float = 0
if user_info.max_budget is not None:
if user_info.max_budget > 0:
percent_left = (
user_info.max_budget - user_info.spend
) / user_info.max_budget
else:
percent_left = 0
# check if crossed budget
if user_info.max_budget is not None:
if user_info.spend >= user_info.max_budget:
event = "budget_crossed"
event_message += f"Budget Crossed\n Total Budget:`{user_info.max_budget}`"
event_message += (
f"Budget Crossed\n Total Budget:`{user_info.max_budget}`"
)
elif percent_left <= 0.05:
event = "threshold_crossed"
event_message += "5% Threshold Crossed "
elif percent_left <= 0.15:
event = "threshold_crossed"
event_message += "15% Threshold Crossed"
elif user_info.soft_budget is not None:
if user_info.spend >= user_info.soft_budget:
event = "soft_budget_crossed"
if event is not None and event_group is not None:
_cache_key = "budget_alerts:{}:{}".format(event, _id)
result = await _cache.async_get_cache(key=_cache_key)
@ -671,6 +681,18 @@ class SlackAlerting(CustomBatchLogger):
return
return
def _get_user_info_str(self, user_info: CallInfo) -> str:
    """
    Create a standard message for a budget alert.

    Every non-None field of ``user_info`` except ``token`` is rendered on
    its own line as ``*<field>:* `<value>```.

    Args:
        user_info: The call info whose fields are listed in the message.

    Returns:
        The formatted, newline-terminated field listing; an empty string
        when there is nothing to list.
    """
    _all_fields_as_dict = user_info.model_dump(exclude_none=True)
    # `token` is optional and is already dropped by `exclude_none=True` when
    # it is unset, so pop with a default instead of raising KeyError.
    _all_fields_as_dict.pop("token", None)
    return "".join(f"*{k}:* `{v}`\n" for k, v in _all_fields_as_dict.items())
async def customer_spend_alert(
self,
token: Optional[str],

View file

@ -1637,6 +1637,7 @@ class CallInfo(LiteLLMPydanticObjectBase):
spend: float
max_budget: Optional[float] = None
soft_budget: Optional[float] = None
token: Optional[str] = Field(default=None, description="Hashed value of that key")
customer_id: Optional[str] = None
user_id: Optional[str] = None
@ -1651,6 +1652,7 @@ class CallInfo(LiteLLMPydanticObjectBase):
class WebhookEvent(CallInfo):
event: Literal[
"budget_crossed",
"soft_budget_crossed",
"threshold_crossed",
"projected_limit_exceeded",
"key_created",

View file

@ -1014,6 +1014,30 @@ async def user_api_key_auth( # noqa: PLR0915
current_cost=valid_token.spend,
max_budget=valid_token.max_budget,
)
if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:
verbose_proxy_logger.debug(
"Crossed Soft Budget for token %s, spend %s, soft_budget %s",
valid_token.token,
valid_token.spend,
valid_token.soft_budget,
)
call_info = CallInfo(
token=valid_token.token,
spend=valid_token.spend,
max_budget=valid_token.max_budget,
soft_budget=valid_token.soft_budget,
user_id=valid_token.user_id,
team_id=valid_token.team_id,
team_alias=valid_token.team_alias,
user_email=None,
key_alias=valid_token.key_alias,
)
asyncio.create_task(
proxy_logging_obj.budget_alerts(
type="soft_budget",
user_info=call_info,
)
)
# Check 5. Token Model Spend is under Model budget
max_budget_per_model = valid_token.model_max_budget

View file

@ -1,13 +1,13 @@
model_list:
- model_name: "fake-openai-endpoint"
litellm_params:
model: aiohttp_openai/any
api_base: https://example-openai-endpoint.onrender.com/chat/completions
api_key: "ishaan"
model: openai/gpt-4o
general_settings:
key_management_system: "hashicorp_vault" # 👈 KEY CHANGE
key_management_settings:
store_virtual_keys: true # OPTIONAL. Defaults to False, when True will store virtual keys in secret manager
prefix_for_stored_virtual_keys: "litellm/" # OPTIONAL. If set, this prefix will be used for stored virtual keys in the secret manager
access_mode: "write_only" # Literal["read_only", "write_only", "read_and_write"]
alerting: ["slack"]
# Slack webhook URLs are credentials: read them from the environment instead of committing them.
alert_to_webhook_url: {
  "budget_alerts": ["os.environ/SLACK_BUDGET_ALERTS_WEBHOOK_URL"]
}
alerting_args:
budget_alert_ttl: 10
log_to_console: true

View file

@ -600,6 +600,7 @@ class ProxyLogging:
type: Literal[
"token_budget",
"user_budget",
"soft_budget",
"team_budget",
"proxy_budget",
"projected_limit_exceeded",
@ -1534,7 +1535,8 @@ class PrismaClient:
b.max_budget AS litellm_budget_table_max_budget,
b.tpm_limit AS litellm_budget_table_tpm_limit,
b.rpm_limit AS litellm_budget_table_rpm_limit,
b.model_max_budget as litellm_budget_table_model_max_budget
b.model_max_budget as litellm_budget_table_model_max_budget,
b.soft_budget as litellm_budget_table_soft_budget
FROM "LiteLLM_VerificationToken" AS v
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id

View file

@ -981,3 +981,43 @@ async def test_spend_report_cache(report_type):
else:
await slack_alerting.send_monthly_spend_report()
mock_send_alert.assert_not_called()
@pytest.mark.asyncio
async def test_soft_budget_alerts():
    """
    Test that a soft budget alert is sent once spend reaches the soft budget.

    - spend == soft_budget (80 == 80) must trigger exactly one alert
    - the alert message lists the non-None CallInfo fields, without the token
    """
    slack_alerting = SlackAlerting(alerting=["webhook"])

    with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
        # Spend equal to the soft budget already counts as crossed
        user_info = CallInfo(
            token="test_token",
            spend=80,  # $80 spent
            soft_budget=80,
            user_id="test@test.com",
            user_email="test@test.com",
            key_alias="test-key",
        )

        await slack_alerting.budget_alerts(
            type="soft_budget",
            user_info=user_info,
        )
        mock_send_alert.assert_called_once()

        # Verify the exact message handed to send_alert
        alert_message = mock_send_alert.call_args[1]["message"]
        expected_message = (
            "Soft Budget Crossed: \n\n"
            "*spend:* `80.0`\n"
            "*soft_budget:* `80.0`\n"
            "*user_id:* `test@test.com`\n"
            "*user_email:* `test@test.com`\n"
            "*key_alias:* `test-key`\n"
        )
        assert alert_message == expected_message

View file

@ -551,3 +551,81 @@ async def test_auth_with_form_data_and_model():
# Test user_api_key_auth with form data request
response = await user_api_key_auth(request=request, api_key="Bearer " + user_key)
assert response.models == ["gpt-4"], "Model from virtual key should be preserved"
@pytest.mark.asyncio
async def test_soft_budget_alert():
    """
    Test that when a token's spend exceeds soft_budget, it triggers a budget alert but allows the request
    """
    import asyncio
    import time

    from fastapi import Request
    from starlette.datastructures import URL

    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    # Setup
    user_key = "sk-12345"
    soft_budget = 10
    current_spend = 15  # Spend exceeds soft budget

    # Create a valid token with soft budget
    valid_token = UserAPIKeyAuth(
        token=hash_token(user_key),
        soft_budget=soft_budget,
        spend=current_spend,
        last_refreshed_at=time.time(),
    )

    # Store in cache
    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)

    # Mock proxy server settings
    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "prisma_client", AsyncMock())

    # Create request
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    # The alert is dispatched as a background task, so signal it through an
    # event instead of sleeping for a fixed time and polling a flag.
    alert_triggered = asyncio.Event()
    proxy_logging_obj = litellm.proxy.proxy_server.proxy_logging_obj
    original_budget_alerts = proxy_logging_obj.budget_alerts

    async def mock_budget_alerts(*args, **kwargs):
        if kwargs.get("type") == "soft_budget":
            alert_triggered.set()
        return await original_budget_alerts(*args, **kwargs)

    # Patch the budget_alerts method
    setattr(proxy_logging_obj, "budget_alerts", mock_budget_alerts)

    try:
        # Call user_api_key_auth
        response = await user_api_key_auth(
            request=request, api_key="Bearer " + user_key
        )

        # Assert the request was allowed (no exception raised)
        assert response is not None

        # Assert the alert was triggered; returns as soon as the event is
        # set and waits at most 3 seconds.
        try:
            await asyncio.wait_for(alert_triggered.wait(), timeout=3)
        except asyncio.TimeoutError:
            pytest.fail("Soft budget alert should have been triggered")
    finally:
        # Restore original budget_alerts
        setattr(proxy_logging_obj, "budget_alerts", original_budget_alerts)