mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(Feat) soft budget alerts on keys (#7623)
* class WebhookEvent(CallInfo): Add * handle soft budget alerts * handle soft budget * fix budget alerts * fix CallInfo * fix _get_user_info_str * test_soft_budget_alerts * test_soft_budget_alert
This commit is contained in:
parent
4e69711411
commit
081826a5d6
7 changed files with 201 additions and 33 deletions
|
@ -570,6 +570,7 @@ class SlackAlerting(CustomBatchLogger):
|
||||||
self,
|
self,
|
||||||
type: Literal[
|
type: Literal[
|
||||||
"token_budget",
|
"token_budget",
|
||||||
|
"soft_budget",
|
||||||
"user_budget",
|
"user_budget",
|
||||||
"team_budget",
|
"team_budget",
|
||||||
"proxy_budget",
|
"proxy_budget",
|
||||||
|
@ -590,12 +591,14 @@ class SlackAlerting(CustomBatchLogger):
|
||||||
return
|
return
|
||||||
_id: Optional[str] = "default_id" # used for caching
|
_id: Optional[str] = "default_id" # used for caching
|
||||||
user_info_json = user_info.model_dump(exclude_none=True)
|
user_info_json = user_info.model_dump(exclude_none=True)
|
||||||
user_info_str = ""
|
user_info_str = self._get_user_info_str(user_info)
|
||||||
for k, v in user_info_json.items():
|
|
||||||
user_info_str = "\n{}: {}\n".format(k, v)
|
|
||||||
|
|
||||||
event: Optional[
|
event: Optional[
|
||||||
Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
|
Literal[
|
||||||
|
"budget_crossed",
|
||||||
|
"threshold_crossed",
|
||||||
|
"projected_limit_exceeded",
|
||||||
|
"soft_budget_crossed",
|
||||||
|
]
|
||||||
] = None
|
] = None
|
||||||
event_group: Optional[
|
event_group: Optional[
|
||||||
Literal["internal_user", "team", "key", "proxy", "customer"]
|
Literal["internal_user", "team", "key", "proxy", "customer"]
|
||||||
|
@ -605,6 +608,9 @@ class SlackAlerting(CustomBatchLogger):
|
||||||
if type == "proxy_budget":
|
if type == "proxy_budget":
|
||||||
event_group = "proxy"
|
event_group = "proxy"
|
||||||
event_message += "Proxy Budget: "
|
event_message += "Proxy Budget: "
|
||||||
|
elif type == "soft_budget":
|
||||||
|
event_group = "proxy"
|
||||||
|
event_message += "Soft Budget Crossed: "
|
||||||
elif type == "user_budget":
|
elif type == "user_budget":
|
||||||
event_group = "internal_user"
|
event_group = "internal_user"
|
||||||
event_message += "User Budget: "
|
event_message += "User Budget: "
|
||||||
|
@ -624,27 +630,31 @@ class SlackAlerting(CustomBatchLogger):
|
||||||
_id = user_info.token
|
_id = user_info.token
|
||||||
|
|
||||||
# percent of max_budget left to spend
|
# percent of max_budget left to spend
|
||||||
if user_info.max_budget is None:
|
if user_info.max_budget is None and user_info.soft_budget is None:
|
||||||
return
|
return
|
||||||
|
percent_left: float = 0
|
||||||
|
if user_info.max_budget is not None:
|
||||||
if user_info.max_budget > 0:
|
if user_info.max_budget > 0:
|
||||||
percent_left = (
|
percent_left = (
|
||||||
user_info.max_budget - user_info.spend
|
user_info.max_budget - user_info.spend
|
||||||
) / user_info.max_budget
|
) / user_info.max_budget
|
||||||
else:
|
|
||||||
percent_left = 0
|
|
||||||
|
|
||||||
# check if crossed budget
|
# check if crossed budget
|
||||||
|
if user_info.max_budget is not None:
|
||||||
if user_info.spend >= user_info.max_budget:
|
if user_info.spend >= user_info.max_budget:
|
||||||
event = "budget_crossed"
|
event = "budget_crossed"
|
||||||
event_message += f"Budget Crossed\n Total Budget:`{user_info.max_budget}`"
|
event_message += (
|
||||||
|
f"Budget Crossed\n Total Budget:`{user_info.max_budget}`"
|
||||||
|
)
|
||||||
elif percent_left <= 0.05:
|
elif percent_left <= 0.05:
|
||||||
event = "threshold_crossed"
|
event = "threshold_crossed"
|
||||||
event_message += "5% Threshold Crossed "
|
event_message += "5% Threshold Crossed "
|
||||||
elif percent_left <= 0.15:
|
elif percent_left <= 0.15:
|
||||||
event = "threshold_crossed"
|
event = "threshold_crossed"
|
||||||
event_message += "15% Threshold Crossed"
|
event_message += "15% Threshold Crossed"
|
||||||
|
elif user_info.soft_budget is not None:
|
||||||
|
if user_info.spend >= user_info.soft_budget:
|
||||||
|
event = "soft_budget_crossed"
|
||||||
if event is not None and event_group is not None:
|
if event is not None and event_group is not None:
|
||||||
_cache_key = "budget_alerts:{}:{}".format(event, _id)
|
_cache_key = "budget_alerts:{}:{}".format(event, _id)
|
||||||
result = await _cache.async_get_cache(key=_cache_key)
|
result = await _cache.async_get_cache(key=_cache_key)
|
||||||
|
@ -671,6 +681,18 @@ class SlackAlerting(CustomBatchLogger):
|
||||||
return
|
return
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def _get_user_info_str(self, user_info: CallInfo) -> str:
    """
    Create a standard message for a budget alert.

    Every non-None field of ``user_info`` is rendered on its own line as
    ``*<field>:* `<value>```. The ``token`` field is never included in
    the message.

    Args:
        user_info: Call information for the entity the alert is about.

    Returns:
        The formatted fields, one per line, each terminated by a newline.
        An empty string when no fields remain after removing ``token``.
    """
    _all_fields_as_dict = user_info.model_dump(exclude_none=True)
    # `exclude_none=True` already drops `token` when it is None, so a plain
    # pop("token") would raise KeyError for callers that pass no token.
    _all_fields_as_dict.pop("token", None)
    return "".join(f"*{k}:* `{v}`\n" for k, v in _all_fields_as_dict.items())
|
||||||
|
|
||||||
async def customer_spend_alert(
|
async def customer_spend_alert(
|
||||||
self,
|
self,
|
||||||
token: Optional[str],
|
token: Optional[str],
|
||||||
|
|
|
@ -1637,6 +1637,7 @@ class CallInfo(LiteLLMPydanticObjectBase):
|
||||||
|
|
||||||
spend: float
|
spend: float
|
||||||
max_budget: Optional[float] = None
|
max_budget: Optional[float] = None
|
||||||
|
soft_budget: Optional[float] = None
|
||||||
token: Optional[str] = Field(default=None, description="Hashed value of that key")
|
token: Optional[str] = Field(default=None, description="Hashed value of that key")
|
||||||
customer_id: Optional[str] = None
|
customer_id: Optional[str] = None
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
|
@ -1651,6 +1652,7 @@ class CallInfo(LiteLLMPydanticObjectBase):
|
||||||
class WebhookEvent(CallInfo):
|
class WebhookEvent(CallInfo):
|
||||||
event: Literal[
|
event: Literal[
|
||||||
"budget_crossed",
|
"budget_crossed",
|
||||||
|
"soft_budget_crossed",
|
||||||
"threshold_crossed",
|
"threshold_crossed",
|
||||||
"projected_limit_exceeded",
|
"projected_limit_exceeded",
|
||||||
"key_created",
|
"key_created",
|
||||||
|
|
|
@ -1014,6 +1014,30 @@ async def user_api_key_auth( # noqa: PLR0915
|
||||||
current_cost=valid_token.spend,
|
current_cost=valid_token.spend,
|
||||||
max_budget=valid_token.max_budget,
|
max_budget=valid_token.max_budget,
|
||||||
)
|
)
|
||||||
|
if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
"Crossed Soft Budget for token %s, spend %s, soft_budget %s",
|
||||||
|
valid_token.token,
|
||||||
|
valid_token.spend,
|
||||||
|
valid_token.soft_budget,
|
||||||
|
)
|
||||||
|
call_info = CallInfo(
|
||||||
|
token=valid_token.token,
|
||||||
|
spend=valid_token.spend,
|
||||||
|
max_budget=valid_token.max_budget,
|
||||||
|
soft_budget=valid_token.soft_budget,
|
||||||
|
user_id=valid_token.user_id,
|
||||||
|
team_id=valid_token.team_id,
|
||||||
|
team_alias=valid_token.team_alias,
|
||||||
|
user_email=None,
|
||||||
|
key_alias=valid_token.key_alias,
|
||||||
|
)
|
||||||
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.budget_alerts(
|
||||||
|
type="soft_budget",
|
||||||
|
user_info=call_info,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Check 5. Token Model Spend is under Model budget
|
# Check 5. Token Model Spend is under Model budget
|
||||||
max_budget_per_model = valid_token.model_max_budget
|
max_budget_per_model = valid_token.model_max_budget
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: "fake-openai-endpoint"
|
- model_name: "fake-openai-endpoint"
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: aiohttp_openai/any
|
model: openai/gpt-4o
|
||||||
api_base: https://example-openai-endpoint.onrender.com/chat/completions
|
|
||||||
api_key: "ishaan"
|
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
key_management_system: "hashicorp_vault" # 👈 KEY CHANGE
|
alerting: ["slack"]
|
||||||
key_management_settings:
|
alert_to_webhook_url: {
|
||||||
store_virtual_keys: true # OPTIONAL. Defaults to False, when True will store virtual keys in secret manager
|
"budget_alerts": ["https://hooks.slack.com/services/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"]
|
||||||
prefix_for_stored_virtual_keys: "litellm/" # OPTIONAL. If set, this prefix will be used for stored virtual keys in the secret manager
|
}
|
||||||
access_mode: "write_only" # Literal["read_only", "write_only", "read_and_write"]
|
alerting_args:
|
||||||
|
budget_alert_ttl: 10
|
||||||
|
log_to_console: true
|
|
@ -600,6 +600,7 @@ class ProxyLogging:
|
||||||
type: Literal[
|
type: Literal[
|
||||||
"token_budget",
|
"token_budget",
|
||||||
"user_budget",
|
"user_budget",
|
||||||
|
"soft_budget",
|
||||||
"team_budget",
|
"team_budget",
|
||||||
"proxy_budget",
|
"proxy_budget",
|
||||||
"projected_limit_exceeded",
|
"projected_limit_exceeded",
|
||||||
|
@ -1534,7 +1535,8 @@ class PrismaClient:
|
||||||
b.max_budget AS litellm_budget_table_max_budget,
|
b.max_budget AS litellm_budget_table_max_budget,
|
||||||
b.tpm_limit AS litellm_budget_table_tpm_limit,
|
b.tpm_limit AS litellm_budget_table_tpm_limit,
|
||||||
b.rpm_limit AS litellm_budget_table_rpm_limit,
|
b.rpm_limit AS litellm_budget_table_rpm_limit,
|
||||||
b.model_max_budget as litellm_budget_table_model_max_budget
|
b.model_max_budget as litellm_budget_table_model_max_budget,
|
||||||
|
b.soft_budget as litellm_budget_table_soft_budget
|
||||||
FROM "LiteLLM_VerificationToken" AS v
|
FROM "LiteLLM_VerificationToken" AS v
|
||||||
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
|
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
|
||||||
LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id
|
LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id
|
||||||
|
|
|
@ -981,3 +981,43 @@ async def test_spend_report_cache(report_type):
|
||||||
else:
|
else:
|
||||||
await slack_alerting.send_monthly_spend_report()
|
await slack_alerting.send_monthly_spend_report()
|
||||||
mock_send_alert.assert_not_called()
|
mock_send_alert.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_soft_budget_alerts():
    """
    Check that a soft budget alert is sent once spend reaches the soft budget.

    The CallInfo below has spend equal to soft_budget and no max_budget, so
    budget_alerts(type="soft_budget") is expected to call send_alert exactly
    once, with a message that lists the populated fields (the token is not
    part of the expected message).
    """
    alerter = SlackAlerting(alerting=["webhook"])

    # Spend has just reached the soft budget.
    key_info = CallInfo(
        token="test_token",
        spend=80,  # $80 spent
        soft_budget=80,
        user_id="test@test.com",
        user_email="test@test.com",
        key_alias="test-key",
    )

    expected_message = "".join(
        [
            "Soft Budget Crossed: \n\n",
            "*spend:* `80.0`\n",
            "*soft_budget:* `80.0`\n",
            "*user_id:* `test@test.com`\n",
            "*user_email:* `test@test.com`\n",
            "*key_alias:* `test-key`\n",
        ]
    )

    with patch.object(alerter, "send_alert", new=AsyncMock()) as send_alert_mock:
        await alerter.budget_alerts(type="soft_budget", user_info=key_info)
        send_alert_mock.assert_called_once()

        # The alert text is passed to send_alert as the `message` keyword.
        sent_message = send_alert_mock.call_args[1]["message"]
        print(sent_message)

        assert sent_message == expected_message
|
||||||
|
|
|
@ -551,3 +551,81 @@ async def test_auth_with_form_data_and_model():
|
||||||
# Test user_api_key_auth with form data request
|
# Test user_api_key_auth with form data request
|
||||||
response = await user_api_key_auth(request=request, api_key="Bearer " + user_key)
|
response = await user_api_key_auth(request=request, api_key="Bearer " + user_key)
|
||||||
assert response.models == ["gpt-4"], "Model from virtual key should be preserved"
|
assert response.models == ["gpt-4"], "Model from virtual key should be preserved"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_soft_budget_alert():
    """
    A key whose spend exceeds its soft_budget must still be authenticated,
    and a "soft_budget" budget alert must be triggered for it.
    """
    import asyncio
    import time

    from fastapi import Request
    from starlette.datastructures import URL

    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    proxy_server = litellm.proxy.proxy_server

    # Key under test: spend (15) is already above the soft budget (10).
    api_key = "sk-12345"
    cached_key = UserAPIKeyAuth(
        token=hash_token(api_key),
        soft_budget=10,
        spend=15,
        last_refreshed_at=time.time(),
    )
    user_api_key_cache.set_cache(key=hash_token(api_key), value=cached_key)

    # Point the proxy server at the cache above and stub out the database.
    proxy_server.user_api_key_cache = user_api_key_cache
    proxy_server.master_key = "sk-1234"
    proxy_server.prisma_client = AsyncMock()

    http_request = Request(scope={"type": "http"})
    http_request._url = URL(url="/chat/completions")

    # Wrap budget_alerts so the test can see whether a soft budget alert fired.
    soft_budget_alert_seen = False
    real_budget_alerts = proxy_server.proxy_logging_obj.budget_alerts

    async def recording_budget_alerts(*args, **kwargs):
        nonlocal soft_budget_alert_seen
        if kwargs.get("type") == "soft_budget":
            soft_budget_alert_seen = True
        return await real_budget_alerts(*args, **kwargs)

    proxy_server.proxy_logging_obj.budget_alerts = recording_budget_alerts

    try:
        auth_result = await user_api_key_auth(
            request=http_request, api_key="Bearer " + api_key
        )

        # The request is allowed through: no exception, a response is returned.
        assert auth_result is not None

        # Give the alert time to run before checking the flag.
        await asyncio.sleep(3)
        assert soft_budget_alert_seen, "Soft budget alert should have been triggered"
    finally:
        # Always put the real budget_alerts back.
        proxy_server.proxy_logging_obj.budget_alerts = real_budget_alerts
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue