(e2e testing + minor refactor) - Virtual Key Max budget check (#7888)

* use helper _virtual_key_max_budget_check

* e2e testing for budget exceeded errors

* e2e budget testing

* test_chat_completion_budget_update

* test_chat_completion_high_budget
This commit is contained in:
Ishaan Jaff 2025-01-21 06:47:26 -08:00 committed by GitHub
parent 64e1df1f14
commit 0295f494b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 267 additions and 43 deletions

View file

@ -9,6 +9,7 @@ Run checks for:
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
""" """
import asyncio
import time import time
import traceback import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
@ -21,6 +22,7 @@ from litellm.caching.caching import DualCache
from litellm.caching.dual_cache import LimitedSizeOrderedDict from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.proxy._types import ( from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES, DB_CONNECTION_ERROR_TYPES,
CallInfo,
LiteLLM_EndUserTable, LiteLLM_EndUserTable,
LiteLLM_JWTAuth, LiteLLM_JWTAuth,
LiteLLM_OrganizationTable, LiteLLM_OrganizationTable,
@ -931,3 +933,51 @@ async def is_valid_fallback_model(
) )
return True return True
async def _virtual_key_max_budget_check(
    valid_token: UserAPIKeyAuth,
    proxy_logging_obj: ProxyLogging,
    user_obj: Optional[LiteLLM_UserTable] = None,
):
    """
    Enforce the max budget configured on a virtual key.

    Nothing happens when the key has no recorded spend or no max budget.
    Otherwise a "token_budget" alert is scheduled through
    `proxy_logging_obj.budget_alerts` (on every call, not only when the key
    is over budget) and the spend is then compared against the budget.

    Args:
        valid_token: The authenticated key whose spend / max_budget are checked.
        proxy_logging_obj: Logging object used to dispatch the budget alert.
        user_obj: Optional user record; only its email is used, for the alert.

    Raises:
        BudgetExceededError: if the key's spend is at or above its max budget.
    """
    spend = valid_token.spend
    max_budget = valid_token.max_budget
    if spend is None or max_budget is None:
        # No spend tracked or no budget set: nothing to enforce.
        return

    # The email is only available when a user object was resolved for the key.
    user_email = user_obj.user_email if user_obj is not None else None

    alert_payload = CallInfo(
        token=valid_token.token,
        spend=spend,
        max_budget=max_budget,
        user_id=valid_token.user_id,
        team_id=valid_token.team_id,
        user_email=user_email,
        key_alias=valid_token.key_alias,
    )

    # Scheduled as a background task so the check is not delayed by alerting.
    # NOTE(review): the task handle is not retained; confirm that is acceptable.
    asyncio.create_task(
        proxy_logging_obj.budget_alerts(
            type="token_budget",
            user_info=alert_payload,
        )
    )

    if spend >= max_budget:
        raise litellm.BudgetExceededError(
            current_cost=spend,
            max_budget=max_budget,
        )

View file

@ -24,6 +24,7 @@ from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import ( from litellm.proxy.auth.auth_checks import (
_cache_key_object, _cache_key_object,
_handle_failed_db_connection_for_get_key_object, _handle_failed_db_connection_for_get_key_object,
_virtual_key_max_budget_check,
allowed_routes_check, allowed_routes_check,
can_key_call_model, can_key_call_model,
common_checks, common_checks,
@ -1092,41 +1093,10 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
) )
# Check 4. Token Spend is under budget # Check 4. Token Spend is under budget
if valid_token.spend is not None and valid_token.max_budget is not None: await _virtual_key_max_budget_check(
valid_token=valid_token,
#################################### proxy_logging_obj=proxy_logging_obj,
# collect information for alerting # user_obj=user_obj,
####################################
user_email = None
# Check if the token has any user id information
if user_obj is not None:
user_email = user_obj.user_email
call_info = CallInfo(
token=valid_token.token,
spend=valid_token.spend,
max_budget=valid_token.max_budget,
user_id=valid_token.user_id,
team_id=valid_token.team_id,
user_email=user_email,
key_alias=valid_token.key_alias,
)
asyncio.create_task(
proxy_logging_obj.budget_alerts(
type="token_budget",
user_info=call_info,
)
)
####################################
# collect information for alerting #
####################################
if valid_token.spend >= valid_token.max_budget:
raise litellm.BudgetExceededError(
current_cost=valid_token.spend,
max_budget=valid_token.max_budget,
) )
if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget: if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:

View file

@ -3,14 +3,10 @@ model_list:
litellm_params: litellm_params:
model: openai/* model: openai/*
api_key: os.environ/OPENAI_API_KEY api_key: os.environ/OPENAI_API_KEY
model_info: - model_name: text-embedding-ada-002
health_check_model: openai/gpt-4o-mini
- model_name: anthropic/*
litellm_params: litellm_params:
model: anthropic/* model: openai/text-embedding-ada-002
api_key: os.environ/ANTHROPIC_API_KEY api_key: os.environ/OPENAI_API_KEY
model_info:
health_check_model: anthropic/claude-3-5-sonnet-20240620
- model_name: fake-openai-endpoint - model_name: fake-openai-endpoint
litellm_params: litellm_params:
model: openai/fake model: openai/fake

View file

@ -0,0 +1,208 @@
import pytest
import asyncio
import aiohttp
import json
async def make_calls_until_budget_exceeded(session, key: str, call_function, **kwargs):
    """
    Call `call_function` repeatedly until it raises a budget-exceeded error.

    The raised exception must carry a dict `body` attribute describing the
    error; its code, type and message are asserted against the expected
    budget-exceeded shape.

    Args:
        session: Passed through to `call_function` as `session=`.
        key: Passed through to `call_function` as `key=`.
        call_function: Async callable invoked as
            `call_function(session=..., key=..., **kwargs)`.
        **kwargs: Extra keyword arguments forwarded to `call_function`.

    Returns:
        The number of calls that completed successfully before the
        budget-exceeded error was raised.

    Raises:
        Exception: any exception from `call_function` that has no dict `body`
            is re-raised unchanged, so unrelated failures (connection errors,
            bugs in the call helper) are not masked by an AttributeError.
        AssertionError: if the error body does not look like a budget error.
    """
    MAX_CALLS = 50
    call_count = 0

    while call_count < MAX_CALLS:
        try:
            await call_function(session=session, key=key, **kwargs)
        except Exception as e:
            error_dict = getattr(e, "body", None)
            if not isinstance(error_dict, dict):
                # Not a structured API error: surface the real failure.
                raise

            print("vars: ", vars(e))
            print("e.body: ", error_dict)
            print("error_dict: ", error_dict)

            # Check error structure and values that should be consistent
            assert (
                error_dict.get("code") == "400"
            ), f"Expected error code 400, got: {error_dict.get('code')}"
            assert (
                error_dict.get("type") == "budget_exceeded"
            ), f"Expected error type budget_exceeded, got: {error_dict.get('type')}"

            # Check message contains required parts without checking specific values
            message = error_dict.get("message") or ""
            assert (
                "Budget has been exceeded!" in message
            ), f"Expected message to start with 'Budget has been exceeded!', got: {message}"
            assert (
                "Current cost:" in message
            ), f"Expected message to contain 'Current cost:', got: {message}"
            assert (
                "Max budget:" in message
            ), f"Expected message to contain 'Max budget:', got: {message}"

            return call_count
        call_count += 1

    # Every call succeeded: the budget was never enforced.
    pytest.fail(f"Budget was not exceeded after {MAX_CALLS} calls")
async def generate_key(session, max_budget=None):
    """
    Create a new key on the local proxy via POST /key/generate.

    Args:
        session: Client session object exposing an async `post(...)` context manager.
        max_budget: Value sent as the key's "max_budget" (sent as-is, even when None).

    Returns:
        The parsed JSON body of the response.
    """
    request_headers = {
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    }
    payload = {"max_budget": max_budget}

    async with session.post(
        "http://0.0.0.0:4000/key/generate",
        headers=request_headers,
        json=payload,
    ) as response:
        return await response.json()
async def chat_completion(session, key: str, model: str):
    """
    Send one chat completion request to the local proxy with the OpenAI SDK.

    Args:
        session: Unused here; accepted so all call helpers share one signature.
        key: API key used to authenticate against the proxy.
        model: Model name to request.

    Returns:
        The response object returned by the SDK.
    """
    import uuid

    from openai import AsyncOpenAI

    # A random UUID is embedded and the whole string is repeated 100 times.
    # NOTE(review): presumably to keep requests unique and large enough to
    # accrue measurable spend; confirm.
    prompt = f"Say hello! {uuid.uuid4()}" * 100

    # Point the SDK at our local proxy instead of the public API.
    proxy_client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000/v1")

    return await proxy_client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
async def update_key_budget(session, key: str, max_budget: float):
    """
    Change an existing key's max budget via POST /key/update on the local proxy.

    Args:
        session: Client session object exposing an async `post(...)` context manager.
        key: The key to update.
        max_budget: New max budget for the key.

    Returns:
        The parsed JSON body of the response.
    """
    request_headers = {
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    }
    payload = {"key": key, "max_budget": max_budget}

    async with session.post(
        "http://0.0.0.0:4000/key/update",
        headers=request_headers,
        json=payload,
    ) as response:
        return await response.json()
@pytest.mark.asyncio
async def test_chat_completion_low_budget():
    """
    Budget enforcement with a tiny, non-zero budget:
    1. Create a key with max_budget=0.0000000005
    2. Make chat completion calls until the budget is exceeded
    3. Verify the budget exceeded error, and that at least one call succeeded first
    """
    async with aiohttp.ClientSession() as session:
        # Key with a very small (but non-zero) budget
        generated = await generate_key(session=session, max_budget=0.0000000005)
        print("response from key generation: ", generated)
        api_key = generated["key"]

        # Keep calling until the proxy rejects the key for exceeding its budget
        successful_calls = await make_calls_until_budget_exceeded(
            session=session,
            key=api_key,
            call_function=chat_completion,
            model="fake-openai-endpoint",
        )

        assert (
            successful_calls > 0
        ), "Should make at least one successful call before budget exceeded"
@pytest.mark.asyncio
async def test_chat_completion_zero_budget():
    """
    Budget enforcement with a zero budget:
    1. Create a key with max_budget=0
    2. Attempt chat completion calls
    3. Verify the budget exceeded error is returned before any call succeeds
    """
    async with aiohttp.ClientSession() as session:
        # Key with a zero budget
        generated = await generate_key(session=session, max_budget=0.000000000)
        print("response from key generation: ", generated)
        api_key = generated["key"]

        # The very first call is expected to be rejected
        successful_calls = await make_calls_until_budget_exceeded(
            session=session,
            key=api_key,
            call_function=chat_completion,
            model="fake-openai-endpoint",
        )

        assert successful_calls == 0, "Should make no calls before budget exceeded"
@pytest.mark.asyncio
async def test_chat_completion_high_budget():
    """
    Budget enforcement with a larger budget:
    1. Create a key with max_budget=0.001
    2. Make chat completion calls until the budget is exceeded
    3. Verify the budget exceeded error, and that at least one call succeeded first
    """
    async with aiohttp.ClientSession() as session:
        # Key with the larger budget
        generated = await generate_key(session=session, max_budget=0.001)
        print("response from key generation: ", generated)
        api_key = generated["key"]

        # Keep calling until the proxy rejects the key for exceeding its budget
        successful_calls = await make_calls_until_budget_exceeded(
            session=session,
            key=api_key,
            call_function=chat_completion,
            model="fake-openai-endpoint",
        )

        assert (
            successful_calls > 0
        ), "Should make at least one successful call before budget exceeded"
@pytest.mark.asyncio
async def test_chat_completion_budget_update():
    """
    Requests should work again once an exhausted key's budget is raised:
    1. Create a key with max_budget=0.0000000005
    2. Make calls until the budget is exceeded
    3. Update the key's max_budget to 0.001
    4. Verify three further calls succeed
    """
    async with aiohttp.ClientSession() as session:
        # Key with a very low budget
        generated = await generate_key(session=session, max_budget=0.0000000005)
        api_key = generated["key"]

        # Exhaust the budget
        successful_calls = await make_calls_until_budget_exceeded(
            session=session,
            key=api_key,
            call_function=chat_completion,
            model="fake-openai-endpoint",
        )

        assert (
            successful_calls > 0
        ), "Should make at least one successful call before budget exceeded"

        # Raise the budget on the same key
        await update_key_budget(session, api_key, max_budget=0.001)

        # Calls should go through again; the first failure aborts the test.
        try:
            for _ in range(3):
                response = await chat_completion(
                    session=session, key=api_key, model="fake-openai-endpoint"
                )
                print("response: ", response)
                assert (
                    response is not None
                ), "Should get valid response after budget update"
        except Exception as e:
            pytest.fail(
                f"Request should succeed after budget update but got error: {e}"
            )