mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(Feat) soft budget alerts on keys (#7623)
* class WebhookEvent(CallInfo): Add * handle soft budget alerts * handle soft budget * fix budget alerts * fix CallInfo * fix _get_user_info_str * test_soft_budget_alerts * test_soft_budget_alert
This commit is contained in:
parent
4e69711411
commit
081826a5d6
7 changed files with 201 additions and 33 deletions
|
@ -570,6 +570,7 @@ class SlackAlerting(CustomBatchLogger):
|
|||
self,
|
||||
type: Literal[
|
||||
"token_budget",
|
||||
"soft_budget",
|
||||
"user_budget",
|
||||
"team_budget",
|
||||
"proxy_budget",
|
||||
|
@ -590,12 +591,14 @@ class SlackAlerting(CustomBatchLogger):
|
|||
return
|
||||
_id: Optional[str] = "default_id" # used for caching
|
||||
user_info_json = user_info.model_dump(exclude_none=True)
|
||||
user_info_str = ""
|
||||
for k, v in user_info_json.items():
|
||||
user_info_str = "\n{}: {}\n".format(k, v)
|
||||
|
||||
user_info_str = self._get_user_info_str(user_info)
|
||||
event: Optional[
|
||||
Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
|
||||
Literal[
|
||||
"budget_crossed",
|
||||
"threshold_crossed",
|
||||
"projected_limit_exceeded",
|
||||
"soft_budget_crossed",
|
||||
]
|
||||
] = None
|
||||
event_group: Optional[
|
||||
Literal["internal_user", "team", "key", "proxy", "customer"]
|
||||
|
@ -605,6 +608,9 @@ class SlackAlerting(CustomBatchLogger):
|
|||
if type == "proxy_budget":
|
||||
event_group = "proxy"
|
||||
event_message += "Proxy Budget: "
|
||||
elif type == "soft_budget":
|
||||
event_group = "proxy"
|
||||
event_message += "Soft Budget Crossed: "
|
||||
elif type == "user_budget":
|
||||
event_group = "internal_user"
|
||||
event_message += "User Budget: "
|
||||
|
@ -624,27 +630,31 @@ class SlackAlerting(CustomBatchLogger):
|
|||
_id = user_info.token
|
||||
|
||||
# percent of max_budget left to spend
|
||||
if user_info.max_budget is None:
|
||||
if user_info.max_budget is None and user_info.soft_budget is None:
|
||||
return
|
||||
|
||||
percent_left: float = 0
|
||||
if user_info.max_budget is not None:
|
||||
if user_info.max_budget > 0:
|
||||
percent_left = (
|
||||
user_info.max_budget - user_info.spend
|
||||
) / user_info.max_budget
|
||||
else:
|
||||
percent_left = 0
|
||||
|
||||
# check if crossed budget
|
||||
if user_info.max_budget is not None:
|
||||
if user_info.spend >= user_info.max_budget:
|
||||
event = "budget_crossed"
|
||||
event_message += f"Budget Crossed\n Total Budget:`{user_info.max_budget}`"
|
||||
event_message += (
|
||||
f"Budget Crossed\n Total Budget:`{user_info.max_budget}`"
|
||||
)
|
||||
elif percent_left <= 0.05:
|
||||
event = "threshold_crossed"
|
||||
event_message += "5% Threshold Crossed "
|
||||
elif percent_left <= 0.15:
|
||||
event = "threshold_crossed"
|
||||
event_message += "15% Threshold Crossed"
|
||||
|
||||
elif user_info.soft_budget is not None:
|
||||
if user_info.spend >= user_info.soft_budget:
|
||||
event = "soft_budget_crossed"
|
||||
if event is not None and event_group is not None:
|
||||
_cache_key = "budget_alerts:{}:{}".format(event, _id)
|
||||
result = await _cache.async_get_cache(key=_cache_key)
|
||||
|
@ -671,6 +681,18 @@ class SlackAlerting(CustomBatchLogger):
|
|||
return
|
||||
return
|
||||
|
||||
def _get_user_info_str(self, user_info: CallInfo) -> str:
|
||||
"""
|
||||
Create a standard message for a budget alert
|
||||
"""
|
||||
_all_fields_as_dict = user_info.model_dump(exclude_none=True)
|
||||
_all_fields_as_dict.pop("token")
|
||||
msg = ""
|
||||
for k, v in _all_fields_as_dict.items():
|
||||
msg += f"*{k}:* `{v}`\n"
|
||||
|
||||
return msg
|
||||
|
||||
async def customer_spend_alert(
|
||||
self,
|
||||
token: Optional[str],
|
||||
|
|
|
@ -1637,6 +1637,7 @@ class CallInfo(LiteLLMPydanticObjectBase):
|
|||
|
||||
spend: float
|
||||
max_budget: Optional[float] = None
|
||||
soft_budget: Optional[float] = None
|
||||
token: Optional[str] = Field(default=None, description="Hashed value of that key")
|
||||
customer_id: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
|
@ -1651,6 +1652,7 @@ class CallInfo(LiteLLMPydanticObjectBase):
|
|||
class WebhookEvent(CallInfo):
|
||||
event: Literal[
|
||||
"budget_crossed",
|
||||
"soft_budget_crossed",
|
||||
"threshold_crossed",
|
||||
"projected_limit_exceeded",
|
||||
"key_created",
|
||||
|
|
|
@ -1014,6 +1014,30 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
current_cost=valid_token.spend,
|
||||
max_budget=valid_token.max_budget,
|
||||
)
|
||||
if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:
|
||||
verbose_proxy_logger.debug(
|
||||
"Crossed Soft Budget for token %s, spend %s, soft_budget %s",
|
||||
valid_token.token,
|
||||
valid_token.spend,
|
||||
valid_token.soft_budget,
|
||||
)
|
||||
call_info = CallInfo(
|
||||
token=valid_token.token,
|
||||
spend=valid_token.spend,
|
||||
max_budget=valid_token.max_budget,
|
||||
soft_budget=valid_token.soft_budget,
|
||||
user_id=valid_token.user_id,
|
||||
team_id=valid_token.team_id,
|
||||
team_alias=valid_token.team_alias,
|
||||
user_email=None,
|
||||
key_alias=valid_token.key_alias,
|
||||
)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.budget_alerts(
|
||||
type="soft_budget",
|
||||
user_info=call_info,
|
||||
)
|
||||
)
|
||||
|
||||
# Check 5. Token Model Spend is under Model budget
|
||||
max_budget_per_model = valid_token.model_max_budget
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
model_list:
|
||||
- model_name: "fake-openai-endpoint"
|
||||
litellm_params:
|
||||
model: aiohttp_openai/any
|
||||
api_base: https://example-openai-endpoint.onrender.com/chat/completions
|
||||
api_key: "ishaan"
|
||||
model: openai/gpt-4o
|
||||
|
||||
general_settings:
|
||||
key_management_system: "hashicorp_vault" # 👈 KEY CHANGE
|
||||
key_management_settings:
|
||||
store_virtual_keys: true # OPTIONAL. Defaults to False, when True will store virtual keys in secret manager
|
||||
prefix_for_stored_virtual_keys: "litellm/" # OPTIONAL. If set, this prefix will be used for stored virtual keys in the secret manager
|
||||
access_mode: "write_only" # Literal["read_only", "write_only", "read_and_write"]
|
||||
alerting: ["slack"]
|
||||
alert_to_webhook_url: {
|
||||
"budget_alerts": ["https://hooks.slack.com/services/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"]
|
||||
}
|
||||
alerting_args:
|
||||
budget_alert_ttl: 10
|
||||
log_to_console: true
|
|
@ -600,6 +600,7 @@ class ProxyLogging:
|
|||
type: Literal[
|
||||
"token_budget",
|
||||
"user_budget",
|
||||
"soft_budget",
|
||||
"team_budget",
|
||||
"proxy_budget",
|
||||
"projected_limit_exceeded",
|
||||
|
@ -1534,7 +1535,8 @@ class PrismaClient:
|
|||
b.max_budget AS litellm_budget_table_max_budget,
|
||||
b.tpm_limit AS litellm_budget_table_tpm_limit,
|
||||
b.rpm_limit AS litellm_budget_table_rpm_limit,
|
||||
b.model_max_budget as litellm_budget_table_model_max_budget
|
||||
b.model_max_budget as litellm_budget_table_model_max_budget,
|
||||
b.soft_budget as litellm_budget_table_soft_budget
|
||||
FROM "LiteLLM_VerificationToken" AS v
|
||||
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
|
||||
LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id
|
||||
|
|
|
@ -981,3 +981,43 @@ async def test_spend_report_cache(report_type):
|
|||
else:
|
||||
await slack_alerting.send_monthly_spend_report()
|
||||
mock_send_alert.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_soft_budget_alerts():
    """
    Test that a soft budget alert is sent once spend reaches the soft budget

    - Build a CallInfo whose spend equals its soft_budget (no max_budget set)
    - Call budget_alerts with type="soft_budget"
    - Expect exactly one send_alert call carrying the expected message
    """
    slack_alerting = SlackAlerting(alerting=["webhook"])

    with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
        # spend == soft_budget, so the soft budget counts as crossed
        user_info = CallInfo(
            token="test_token",
            spend=80,  # $80 spent
            soft_budget=80,
            user_id="test@test.com",
            user_email="test@test.com",
            key_alias="test-key",
        )

        await slack_alerting.budget_alerts(
            type="soft_budget",
            user_info=user_info,
        )
        mock_send_alert.assert_called_once()

        # Verify the exact alert text: the "Soft Budget Crossed" prefix followed
        # by the CallInfo fields set above (the token is not expected in it)
        alert_message = mock_send_alert.call_args[1]["message"]

        expected_message = (
            "Soft Budget Crossed: \n\n"
            "*spend:* `80.0`\n"
            "*soft_budget:* `80.0`\n"
            "*user_id:* `test@test.com`\n"
            "*user_email:* `test@test.com`\n"
            "*key_alias:* `test-key`\n"
        )
        assert alert_message == expected_message
|
||||
|
|
|
@ -551,3 +551,81 @@ async def test_auth_with_form_data_and_model():
|
|||
# Test user_api_key_auth with form data request
|
||||
response = await user_api_key_auth(request=request, api_key="Bearer " + user_key)
|
||||
assert response.models == ["gpt-4"], "Model from virtual key should be preserved"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_soft_budget_alert():
    """
    A key whose spend is above its soft_budget must still pass auth, and a
    budget alert of type "soft_budget" must be fired for it
    """
    import asyncio
    import time

    from fastapi import Request
    from starlette.datastructures import URL

    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    user_key = "sk-12345"

    # Cached key record: spend (15) is above the soft budget (10)
    key_record = UserAPIKeyAuth(
        token=hash_token(user_key),
        soft_budget=10,
        spend=15,
        last_refreshed_at=time.time(),
    )
    user_api_key_cache.set_cache(key=hash_token(user_key), value=key_record)

    # Point the proxy server at the cache, a master key and a stub DB client
    for attr_name, attr_value in (
        ("user_api_key_cache", user_api_key_cache),
        ("master_key", "sk-1234"),
        ("prisma_client", AsyncMock()),
    ):
        setattr(litellm.proxy.proxy_server, attr_name, attr_value)

    # Request for the chat completions route
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    # Wrap budget_alerts so each call's `type` kwarg is recorded before the
    # real implementation runs
    logging_obj = litellm.proxy.proxy_server.proxy_logging_obj
    original_budget_alerts = logging_obj.budget_alerts
    seen_alert_types = []

    async def recording_budget_alerts(*args, **kwargs):
        seen_alert_types.append(kwargs.get("type"))
        return await original_budget_alerts(*args, **kwargs)

    setattr(logging_obj, "budget_alerts", recording_budget_alerts)

    try:
        response = await user_api_key_auth(
            request=request, api_key="Bearer " + user_key
        )

        # The request was allowed (no exception raised)
        assert response is not None

        # The alert is scheduled as a background task; wait for it to run
        await asyncio.sleep(3)
        assert (
            "soft_budget" in seen_alert_types
        ), "Soft budget alert should have been triggered"
    finally:
        # Put the real budget_alerts back
        setattr(logging_obj, "budget_alerts", original_budget_alerts)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue