# Strong references to in-flight budget-alert tasks. The event loop only keeps
# weak references to tasks, so a fire-and-forget task that nothing else refers
# to may be garbage-collected before it completes. Each task removes itself
# from this set when it finishes.
_budget_alert_tasks: set = set()


async def _virtual_key_max_budget_check(
    valid_token: UserAPIKeyAuth,
    proxy_logging_obj: ProxyLogging,
    user_obj: Optional[LiteLLM_UserTable] = None,
) -> None:
    """
    Enforce the max budget of a virtual key (token).

    Does nothing when either `valid_token.spend` or `valid_token.max_budget`
    is unset. Otherwise it schedules a `token_budget` alert in the background
    and then compares the spend against the budget.

    Note: the alert coroutine is scheduled whenever both values are set, not
    only when the budget is exceeded.

    Args:
        valid_token: The authenticated key; its spend, max_budget, token,
            user_id, team_id and key_alias are read.
        proxy_logging_obj: Object whose `budget_alerts` coroutine is scheduled.
        user_obj: Optional user record; only `user_email` is read from it.

    Raises:
        litellm.BudgetExceededError: if the token's spend is greater than or
            equal to its max budget.
    """
    if valid_token.spend is None or valid_token.max_budget is None:
        return

    ####################################
    # collect information for alerting #
    ####################################

    # The user's email is only known when a user object was resolved for the key.
    user_email = user_obj.user_email if user_obj is not None else None

    call_info = CallInfo(
        token=valid_token.token,
        spend=valid_token.spend,
        max_budget=valid_token.max_budget,
        user_id=valid_token.user_id,
        team_id=valid_token.team_id,
        user_email=user_email,
        key_alias=valid_token.key_alias,
    )

    # Run the alert in the background so it does not delay the request, while
    # keeping a reference to the task until it is done.
    alert_task = asyncio.create_task(
        proxy_logging_obj.budget_alerts(
            type="token_budget",
            user_info=call_info,
        )
    )
    _budget_alert_tasks.add(alert_task)
    alert_task.add_done_callback(_budget_alert_tasks.discard)

    ####################################
    # enforce the key's max budget     #
    ####################################

    if valid_token.spend >= valid_token.max_budget:
        raise litellm.BudgetExceededError(
            current_cost=valid_token.spend,
            max_budget=valid_token.max_budget,
        )
Token Spend is under budget - if valid_token.spend is not None and valid_token.max_budget is not None: - - #################################### - # collect information for alerting # - #################################### - - user_email = None - # Check if the token has any user id information - if user_obj is not None: - user_email = user_obj.user_email - - call_info = CallInfo( - token=valid_token.token, - spend=valid_token.spend, - max_budget=valid_token.max_budget, - user_id=valid_token.user_id, - team_id=valid_token.team_id, - user_email=user_email, - key_alias=valid_token.key_alias, - ) - asyncio.create_task( - proxy_logging_obj.budget_alerts( - type="token_budget", - user_info=call_info, - ) - ) - - #################################### - # collect information for alerting # - #################################### - - if valid_token.spend >= valid_token.max_budget: - raise litellm.BudgetExceededError( - current_cost=valid_token.spend, - max_budget=valid_token.max_budget, - ) + await _virtual_key_max_budget_check( + valid_token=valid_token, + proxy_logging_obj=proxy_logging_obj, + user_obj=user_obj, + ) if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget: verbose_proxy_logger.debug( diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 49e8dd6934..09df4ec481 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -3,14 +3,10 @@ model_list: litellm_params: model: openai/* api_key: os.environ/OPENAI_API_KEY - model_info: - health_check_model: openai/gpt-4o-mini - - model_name: anthropic/* + - model_name: text-embedding-ada-002 litellm_params: - model: anthropic/* - api_key: os.environ/ANTHROPIC_API_KEY - model_info: - health_check_model: anthropic/claude-3-5-sonnet-20240620 + model: openai/text-embedding-ada-002 + api_key: os.environ/OPENAI_API_KEY - model_name: fake-openai-endpoint litellm_params: model: openai/fake diff --git a/tests/otel_tests/test_e2e_budgeting.py 
import pytest
import asyncio
import aiohttp
import json

# Address of the locally running proxy that these end-to-end tests talk to.
PROXY_BASE_URL = "http://0.0.0.0:4000"
# Headers used for the key-management endpoints (/key/generate, /key/update).
ADMIN_HEADERS = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
# Model name requested by every chat completion in this module.
FAKE_MODEL = "fake-openai-endpoint"


async def make_calls_until_budget_exceeded(session, key: str, call_function, **kwargs):
    """
    Call `call_function` repeatedly until it raises, then verify that the
    error is a budget-exceeded error.

    Args:
        session: Passed through to `call_function`.
        key: API key passed through to `call_function`.
        call_function: Coroutine function invoked as
            `call_function(session=session, key=key, **kwargs)`.
        **kwargs: Extra keyword arguments forwarded to `call_function`.

    Returns:
        The number of calls that succeeded before the budget error was raised.

    Fails the test if no call raises within MAX_CALLS attempts. An exception
    that carries no dict `body` (e.g. a connection failure) is re-raised
    unchanged so the real cause is visible.
    """
    MAX_CALLS = 50
    call_count = 0
    while call_count < MAX_CALLS:
        try:
            await call_function(session=session, key=key, **kwargs)
        except Exception as e:
            error_dict = getattr(e, "body", None)
            print("error_dict: ", error_dict)
            if not isinstance(error_dict, dict):
                # Not an API error with a structured body; do not mask it
                # behind an AttributeError/TypeError.
                raise

            # Check error structure and values that should be consistent
            assert (
                error_dict["code"] == "400"
            ), f"Expected error code 400, got: {error_dict['code']}"
            assert (
                error_dict["type"] == "budget_exceeded"
            ), f"Expected error type budget_exceeded, got: {error_dict['type']}"

            # Check message contains required parts without checking specific values
            message = error_dict["message"]
            assert (
                "Budget has been exceeded!" in message
            ), f"Expected message to contain 'Budget has been exceeded!', got: {message}"
            assert (
                "Current cost:" in message
            ), f"Expected message to contain 'Current cost:', got: {message}"
            assert (
                "Max budget:" in message
            ), f"Expected message to contain 'Max budget:', got: {message}"

            return call_count
        call_count += 1

    pytest.fail(f"Budget was not exceeded after {MAX_CALLS} calls")


async def generate_key(
    session,
    max_budget=None,
):
    """POST to /key/generate with the given max budget and return the parsed JSON response."""
    url = f"{PROXY_BASE_URL}/key/generate"
    data = {
        "max_budget": max_budget,
    }
    async with session.post(url, headers=ADMIN_HEADERS, json=data) as response:
        return await response.json()


async def chat_completion(session, key: str, model: str):
    """
    Make a chat completion request through the proxy using the OpenAI SDK.

    `session` is unused; it is accepted so this function matches the
    `call_function(session=..., key=..., **kwargs)` calling convention of
    `make_calls_until_budget_exceeded`.
    """
    from openai import AsyncOpenAI
    import uuid

    client = AsyncOpenAI(
        api_key=key, base_url=f"{PROXY_BASE_URL}/v1"  # Point to our local proxy
    )

    # A fresh UUID makes every prompt unique.
    response = await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": f"Say hello! {uuid.uuid4()}" * 100}],
    )
    return response


async def update_key_budget(session, key: str, max_budget: float):
    """POST to /key/update to set a key's max budget and return the parsed JSON response."""
    url = f"{PROXY_BASE_URL}/key/update"
    data = {
        "key": key,
        "max_budget": max_budget,
    }
    async with session.post(url, headers=ADMIN_HEADERS, json=data) as response:
        return await response.json()


@pytest.mark.asyncio
async def test_chat_completion_low_budget():
    """
    Test budget enforcement for chat completions with a very low budget:
    1. Create key with a $0.0000000005 budget
    2. Make chat completion calls until budget exceeded
    3. Verify budget exceeded error after at least one successful call
    """
    async with aiohttp.ClientSession() as session:
        # Create key with a tiny, non-zero budget
        key_gen = await generate_key(session=session, max_budget=0.0000000005)
        print("response from key generation: ", key_gen)
        key = key_gen["key"]

        # Make calls until budget exceeded
        calls_made = await make_calls_until_budget_exceeded(
            session=session,
            key=key,
            call_function=chat_completion,
            model=FAKE_MODEL,
        )

        assert (
            calls_made > 0
        ), "Should make at least one successful call before budget exceeded"


@pytest.mark.asyncio
async def test_chat_completion_zero_budget():
    """
    Test budget enforcement for chat completions with a zero budget:
    1. Create key with a $0 budget
    2. Attempt a chat completion call
    3. Verify the very first call is rejected with a budget exceeded error
    """
    async with aiohttp.ClientSession() as session:
        # Create key with a zero budget
        key_gen = await generate_key(session=session, max_budget=0.0)
        print("response from key generation: ", key_gen)
        key = key_gen["key"]

        # Make calls until budget exceeded
        calls_made = await make_calls_until_budget_exceeded(
            session=session,
            key=key,
            call_function=chat_completion,
            model=FAKE_MODEL,
        )

        assert calls_made == 0, "Should make no calls before budget exceeded"


@pytest.mark.asyncio
async def test_chat_completion_high_budget():
    """
    Test budget enforcement for chat completions with a higher budget:
    1. Create key with a $0.001 budget
    2. Make chat completion calls until budget exceeded
    3. Verify budget exceeded error after at least one successful call
    """
    async with aiohttp.ClientSession() as session:
        # Create key with a $0.001 budget
        key_gen = await generate_key(session=session, max_budget=0.001)
        print("response from key generation: ", key_gen)
        key = key_gen["key"]

        # Make calls until budget exceeded
        calls_made = await make_calls_until_budget_exceeded(
            session=session,
            key=key,
            call_function=chat_completion,
            model=FAKE_MODEL,
        )

        assert (
            calls_made > 0
        ), "Should make at least one successful call before budget exceeded"


@pytest.mark.asyncio
async def test_chat_completion_budget_update():
    """
    Test that requests continue working after updating a key's budget:
    1. Create key with low budget
    2. Make calls until budget exceeded
    3. Update key with higher budget
    4. Verify calls work again
    """
    async with aiohttp.ClientSession() as session:
        # Create key with very low budget
        key_gen = await generate_key(session=session, max_budget=0.0000000005)
        key = key_gen["key"]

        # Make calls until budget exceeded
        calls_made = await make_calls_until_budget_exceeded(
            session=session,
            key=key,
            call_function=chat_completion,
            model=FAKE_MODEL,
        )

        assert (
            calls_made > 0
        ), "Should make at least one successful call before budget exceeded"

        # Update key with higher budget
        await update_key_budget(session, key, max_budget=0.001)

        # Verify calls work again
        for _ in range(3):
            try:
                response = await chat_completion(
                    session=session, key=key, model=FAKE_MODEL
                )
                print("response: ", response)
                assert (
                    response is not None
                ), "Should get valid response after budget update"
            except Exception as e:
                pytest.fail(
                    f"Request should succeed after budget update but got error: {e}"
                )