diff --git a/docs/my-website/docs/proxy/users.md b/docs/my-website/docs/proxy/users.md index 04f6e8c945..1db749e83e 100644 --- a/docs/my-website/docs/proxy/users.md +++ b/docs/my-website/docs/proxy/users.md @@ -10,7 +10,7 @@ Requirements: ## Set Budgets -You can set budgets at 3 levels: +You can set budgets at 5 levels: - For the proxy - For an internal user - For a customer (end-user) @@ -392,32 +392,88 @@ curl --location 'http://0.0.0.0:4000/key/generate' \ -Apply model specific budgets on a key. - -**Expected Behaviour** -- `model_spend` gets auto-populated in `LiteLLM_VerificationToken` Table -- After the key crosses the budget set for the `model` in `model_max_budget`, calls fail - -By default the `model_max_budget` is set to `{}` and is not checked for keys - -:::info - -- LiteLLM will track the cost/budgets for the `model` passed to LLM endpoints (`/chat/completions`, `/embeddings`) - - -::: +Apply model specific budgets on a key. Example: +- Budget for `gpt-4o` is $0.0000001, for time period `1d` for `key = "sk-12345"` +- Budget for `gpt-4o-mini` is $10, for time period `30d` for `key = "sk-12345"` #### **Add model specific budgets to keys** +The spec for `model_max_budget` is **[`Dict[str, GenericBudgetInfo]`](#genericbudgetinfo)** + ```bash curl 'http://0.0.0.0:4000/key/generate' \ --header 'Authorization: Bearer ' \ --header 'Content-Type: application/json' \ --data-raw '{ - model_max_budget={"gpt4": 0.5, "gpt-5": 0.01} + "model_max_budget": {"gpt-4o": {"budget_limit": "0.0000001", "time_period": "1d"}} }' ``` + +#### Make a test request + +We expect the first request to succeed, and the second request to fail since we cross the budget for `gpt-4o` on the Virtual Key + +**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)** + + + + +```shell +curl --location 'http://0.0.0.0:4000/chat/completions' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: Bearer ' \ +--data ' { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "testing request" + } + ] + } +' +``` + + + + +Expect this to fail since since we cross the budget `model=gpt-4o` on the Virtual Key + +```shell +curl --location 'http://0.0.0.0:4000/chat/completions' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: Bearer ' \ +--data ' { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "testing request" + } + ] + } +' +``` + +Expected response on failure + +```json +{ + "error": { + "message": "LiteLLM Virtual Key: 9769f3f6768a199f76cc29xxxx, key_alias: None, exceeded budget for model=gpt-4o", + "type": "budget_exceeded", + "param": null, + "code": "400" + } +} +``` + + + + + + @@ -783,3 +839,32 @@ curl --location 'http://0.0.0.0:4000/key/generate' \ --header 'Content-Type: application/json' \ --data '{"models": ["azure-models"], "user_id": "krrish@berri.ai"}' ``` + + +## API Specification + +### `GenericBudgetInfo` + +A Pydantic model that defines budget information with a time period and limit. + +```python +class GenericBudgetInfo(BaseModel): + budget_limit: float # The maximum budget amount in USD + time_period: str # Duration string like "1d", "30d", etc. +``` + +#### Fields: +- `budget_limit` (float): The maximum budget amount in USD +- `time_period` (str): Duration string specifying the time period for the budget. 
Supported formats: + - Seconds: "30s" + - Minutes: "30m" + - Hours: "30h" + - Days: "30d" + +#### Example: +```json +{ + "budget_limit": "0.0001", + "time_period": "1d" +} +``` \ No newline at end of file diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index e9c56d7950..acd7b9ebba 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -262,6 +262,7 @@ async def user_api_key_auth( # noqa: PLR0915 llm_model_list, llm_router, master_key, + model_max_budget_limiter, open_telemetry_logger, prisma_client, proxy_logging_obj, @@ -1053,37 +1054,10 @@ async def user_api_key_auth( # noqa: PLR0915 and valid_token.token is not None ): ## GET THE SPEND FOR THIS MODEL - twenty_eight_days_ago = datetime.now() - timedelta(days=28) - model_spend = await prisma_client.db.litellm_spendlogs.group_by( - by=["model"], - sum={"spend": True}, - where={ - "AND": [ - {"api_key": valid_token.token}, - {"startTime": {"gt": twenty_eight_days_ago}}, - {"model": current_model}, - ] - }, # type: ignore + await model_max_budget_limiter.is_key_within_model_budget( + user_api_key_dict=valid_token, + model=current_model, ) - if ( - len(model_spend) > 0 - and max_budget_per_model.get(current_model, None) is not None - ): - if ( - "model" in model_spend[0] - and model_spend[0].get("model") == current_model - and "_sum" in model_spend[0] - and "spend" in model_spend[0]["_sum"] - and model_spend[0]["_sum"]["spend"] - >= max_budget_per_model[current_model] - ): - current_model_spend = model_spend[0]["_sum"]["spend"] - current_model_budget = max_budget_per_model[current_model] - raise litellm.BudgetExceededError( - current_cost=current_model_spend, - max_budget=current_model_budget, - ) - # Check 6. 
Team spend is under Team budget if ( hasattr(valid_token, "team_spend") diff --git a/litellm/proxy/hooks/model_max_budget_limiter.py b/litellm/proxy/hooks/model_max_budget_limiter.py new file mode 100644 index 0000000000..8ce6da8d19 --- /dev/null +++ b/litellm/proxy/hooks/model_max_budget_limiter.py @@ -0,0 +1,196 @@ +import json +import traceback +from typing import List, Optional + +from fastapi import HTTPException + +import litellm +from litellm import verbose_logger +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger, Span +from litellm.proxy._types import UserAPIKeyAuth +from litellm.router_strategy.budget_limiter import RouterBudgetLimiting +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ( + GenericBudgetConfigType, + GenericBudgetInfo, + StandardLoggingPayload, +) + +VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend" + + +class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting): + """ + Handles budgets for model + virtual key + + Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d + """ + + def __init__(self, dual_cache: DualCache): + self.dual_cache = dual_cache + self.redis_increment_operation_queue = [] + + async def is_key_within_model_budget( + self, + user_api_key_dict: UserAPIKeyAuth, + model: str, + ) -> bool: + """ + Check if the user_api_key_dict is within the model budget + + Raises: + BudgetExceededError: If the user_api_key_dict has exceeded the model budget + """ + _model_max_budget = user_api_key_dict.model_max_budget + internal_model_max_budget: GenericBudgetConfigType = {} + + # case each element in _model_max_budget to GenericBudgetInfo + for _model, _budget_info in _model_max_budget.items(): + internal_model_max_budget[_model] = GenericBudgetInfo( + time_period=_budget_info.get("time_period"), + budget_limit=float(_budget_info.get("budget_limit")), + ) + + verbose_proxy_logger.debug( + "internal_model_max_budget %s", + json.dumps(internal_model_max_budget, indent=4, default=str), + ) + + # check if current model is in internal_model_max_budget + _current_model_budget_info = self._get_request_model_budget_config( + model=model, internal_model_max_budget=internal_model_max_budget + ) + if _current_model_budget_info is None: + verbose_proxy_logger.debug( + f"Model {model} not found in internal_model_max_budget" + ) + return True + + # check if current model is within budget + if _current_model_budget_info.budget_limit > 0: + _current_spend = await self._get_virtual_key_spend_for_model( + user_api_key_hash=user_api_key_dict.token, + model=model, + key_budget_config=_current_model_budget_info, + ) + if ( + _current_spend is not None + and _current_spend > _current_model_budget_info.budget_limit + ): + raise litellm.BudgetExceededError( + message=f"LiteLLM Virtual Key: {user_api_key_dict.token}, key_alias: {user_api_key_dict.key_alias}, exceeded budget for model={model}", + current_cost=_current_spend, + max_budget=_current_model_budget_info.budget_limit, + ) + + return True + + async def _get_virtual_key_spend_for_model( + self, + user_api_key_hash: Optional[str], + model: str, + key_budget_config: GenericBudgetInfo, + ) -> Optional[float]: + """ + Get the current spend for a virtual key for a model + + Lookup model in this order: + 1. model: directly look up `model` + 2. If 1, does not exist, check if passed as {custom_llm_provider}/model + """ + + # 1. 
model: directly look up `model` + virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{key_budget_config.time_period}" + _current_spend = await self.dual_cache.async_get_cache( + key=virtual_key_model_spend_cache_key, + ) + + if _current_spend is None: + # 2. If 1, does not exist, check if passed as {custom_llm_provider}/model + # if "/" in model, remove first part before "/" - eg. openai/o1-preview -> o1-preview + virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.time_period}" + _current_spend = await self.dual_cache.async_get_cache( + key=virtual_key_model_spend_cache_key, + ) + return _current_spend + + def _get_request_model_budget_config( + self, model: str, internal_model_max_budget: GenericBudgetConfigType + ) -> Optional[GenericBudgetInfo]: + """ + Get the budget config for the request model + + 1. Check if `model` is in `internal_model_max_budget` + 2. If not, check if `model` without custom llm provider is in `internal_model_max_budget` + """ + return internal_model_max_budget.get( + model, None + ) or internal_model_max_budget.get( + self._get_model_without_custom_llm_provider(model), None + ) + + def _get_model_without_custom_llm_provider(self, model: str) -> str: + if "/" in model: + return model.split("/")[-1] + return model + + async def async_filter_deployments( + self, + model: str, + healthy_deployments: List, + messages: Optional[List[AllMessageValues]], + request_kwargs: Optional[dict] = None, + parent_otel_span: Optional[Span] = None, # type: ignore + ) -> List[dict]: + return healthy_deployments + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + """ + Track spend for virtual key + model in DualCache + + Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d + """ + verbose_proxy_logger.debug("in RouterBudgetLimiting.async_log_success_event") + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + if standard_logging_payload is None: + raise ValueError("standard_logging_payload is required") + + _litellm_params = kwargs.get("litellm_params", {}) + _metadata = _litellm_params.get("metadata", {}) + user_api_key_model_max_budget: Optional[dict] = _metadata.get( + "user_api_key_model_max_budget", None + ) + if ( + user_api_key_model_max_budget is None + or len(user_api_key_model_max_budget) == 0 + ): + verbose_proxy_logger.debug( + "Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget is None or empty. 
`user_api_key_model_max_budget`=%s", + user_api_key_model_max_budget, + ) + return + response_cost: float = standard_logging_payload.get("response_cost", 0) + model = standard_logging_payload.get("model") + + virtual_key = standard_logging_payload.get("metadata").get("user_api_key_hash") + model = standard_logging_payload.get("model") + if virtual_key is not None: + budget_config = GenericBudgetInfo(time_period="1d", budget_limit=0.1) + virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_config.time_period}" + virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}" + await self._increment_spend_for_key( + budget_config=budget_config, + spend_key=virtual_spend_key, + start_time_key=virtual_start_time_key, + response_cost=response_cost, + ) + verbose_proxy_logger.debug( + "current state of in memory cache %s", + json.dumps( + self.dual_cache.in_memory_cache.cache_dict, indent=4, default=str + ), + ) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index f690f517c7..325aff881d 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -499,6 +499,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915 data[_metadata_variable_name][ "user_api_key_max_budget" ] = user_api_key_dict.max_budget + data[_metadata_variable_name][ + "user_api_key_model_max_budget" + ] = user_api_key_dict.model_max_budget data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata _headers = dict(request.headers) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index ee1b9bd8b3..93613c4bc2 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -40,7 +40,11 @@ from litellm.proxy.utils import ( handle_exception_on_proxy, ) from litellm.secret_managers.main import get_secret -from litellm.types.utils import PersonalUIKeyGenerationConfig, TeamUIKeyGenerationConfig +from litellm.types.utils import ( + GenericBudgetInfo, + PersonalUIKeyGenerationConfig, + TeamUIKeyGenerationConfig, +) def _is_team_key(data: GenerateKeyRequest): @@ -246,7 +250,7 @@ async def generate_key_fn( # noqa: PLR0915 - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } - guardrails: Optional[List[str]] - List of active guardrails for the key - permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false} - - model_max_budget: Optional[dict] - key-specific model budget in USD. Example - {"text-davinci-002": 0.5, "gpt-3.5-turbo": 0.5}. IF null or {} then no model specific budget. + - model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}}. IF null or {} then no model specific budget. - model_rpm_limit: Optional[dict] - key-specific model rpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific rpm limit. - model_tpm_limit: Optional[dict] - key-specific model tpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific tpm limit. - allowed_cache_controls: Optional[list] - List of allowed cache control values. 
Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request @@ -515,6 +519,10 @@ def prepare_key_update_data( _metadata = existing_key_row.metadata or {} + # validate model_max_budget + if "model_max_budget" in non_default_values: + validate_model_max_budget(non_default_values["model_max_budget"]) + non_default_values = prepare_metadata_fields( data=data, non_default_values=non_default_values, existing_metadata=_metadata ) @@ -548,7 +556,7 @@ async def update_key_fn( - enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests) - spend: Optional[float] - Amount spent by key - max_budget: Optional[float] - Max budget for key - - model_max_budget: Optional[dict] - Model-specific budgets {"gpt-4": 0.5, "claude-v1": 1.0} + - model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}} - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.) - soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached. - max_parallel_requests: Optional[int] - Rate limit for parallel requests @@ -1035,6 +1043,7 @@ async def generate_key_helper_fn( # noqa: PLR0915 metadata["guardrails"] = guardrails metadata_json = json.dumps(metadata) + validate_model_max_budget(model_max_budget) model_max_budget_json = json.dumps(model_max_budget) user_role = user_role tpm_limit = tpm_limit @@ -1266,7 +1275,7 @@ async def regenerate_key_fn( - tags: Optional[List[str]] - Tags for organizing keys (Enterprise only) - spend: Optional[float] - Amount spent by key - max_budget: Optional[float] - Max budget for key - - model_max_budget: Optional[dict] - Model-specific budgets {"gpt-4": 0.5, "claude-v1": 1.0} + - model_max_budget: Optional[Dict[str, GenericBudgetInfo]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}} - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.) - soft_budget: Optional[float] - Soft budget limit (warning vs. hard stop). Will trigger a slack alert when this soft budget is reached. - max_parallel_requests: Optional[int] - Rate limit for parallel requests @@ -1293,8 +1302,7 @@ async def regenerate_key_fn( --data-raw '{ "max_budget": 100, "metadata": {"team": "core-infra"}, - "models": ["gpt-4", "gpt-3.5-turbo"], - "model_max_budget": {"gpt-4": 50, "gpt-3.5-turbo": 50} + "models": ["gpt-4", "gpt-3.5-turbo"] }' ``` @@ -1949,3 +1957,29 @@ async def _enforce_unique_key_alias( param="key_alias", code=status.HTTP_400_BAD_REQUEST, ) + + +def validate_model_max_budget(model_max_budget: Optional[Dict]) -> None: + """ + Validate the model_max_budget is GenericBudgetConfigType + + Raises: + Exception: If model_max_budget is not a valid GenericBudgetConfigType + """ + try: + if model_max_budget is None: + return + if len(model_max_budget) == 0: + return + if model_max_budget is not None: + for _model, _budget_info in model_max_budget.items(): + assert isinstance(_model, str) + + # /CRUD endpoints can pass budget_limit as a string, so we need to convert it to a float + if "budget_limit" in _budget_info: + _budget_info["budget_limit"] = float(_budget_info["budget_limit"]) + GenericBudgetInfo(**_budget_info) + except Exception as e: + raise ValueError( + f"Invalid model_max_budget: {str(e)}. 
Example of valid model_max_budget: https://docs.litellm.ai/docs/proxy/users" + ) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 5a6aef7d2a..063101d598 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,42 +1,10 @@ model_list: - - model_name: fake-openai-endpoint - litellm_params: - model: openai/fake - api_key: fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - tags: ["teamA"] - model_info: - id: "team-a-model" - - model_name: fake-openai-endpoint - litellm_params: - model: openai/fake - api_key: fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - tags: ["teamB"] - model_info: - id: "team-b-model" - - model_name: rerank-english-v3.0 - litellm_params: - model: cohere/rerank-english-v3.0 - api_key: os.environ/COHERE_API_KEY - - model_name: fake-azure-endpoint - litellm_params: - model: openai/429 - api_key: fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app - - model_name: llava-hf - litellm_params: - model: openai/llava-hf/llava-v1.6-vicuna-7b-hf - api_base: http://localhost:8000 - api_key: fake-key - model_info: - supports_vision: True + - model_name: openai/o1-preview + litellm_params: + model: openai/o1-preview + api_key: os.environ/OPENAI_API_KEY + - model_name: openai/* + litellm_params: + model: openai/* + api_key: os.environ/OPENAI_API_KEY - - -litellm_settings: - cache: true - callbacks: ["otel", "prometheus"] - -router_settings: - enable_tag_filtering: True # 👈 Key Change \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 2052082c4b..18ba012d72 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -173,6 +173,9 @@ from litellm.proxy.guardrails.init_guardrails import ( ) from litellm.proxy.health_check import perform_health_check from litellm.proxy.health_endpoints._health_endpoints import router as health_router +from litellm.proxy.hooks.model_max_budget_limiter import ( + _PROXY_VirtualKeyModelMaxBudgetLimiter, +) from litellm.proxy.hooks.prompt_injection_detection import ( _OPTIONAL_PromptInjectionDetection, ) @@ -495,6 +498,10 @@ prisma_client: Optional[PrismaClient] = None user_api_key_cache = DualCache( default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value ) +model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter( + dual_cache=user_api_key_cache +) +litellm.callbacks.append(model_max_budget_limiter) redis_usage_cache: Optional[RedisCache] = ( None # redis cache used for tracking spend, tpm/rpm limits ) diff --git a/litellm/router.py b/litellm/router.py index 9e1fb7d9f4..6832ffae94 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -631,7 +631,7 @@ class Router: _callback = PromptCachingDeploymentCheck(cache=self.cache) elif pre_call_check == "router_budget_limiting": _callback = RouterBudgetLimiting( - router_cache=self.cache, + dual_cache=self.cache, provider_budget_config=self.provider_budget_config, model_list=self.model_list, ) @@ -5292,14 +5292,6 @@ class Router: healthy_deployments=healthy_deployments, ) - # if self.router_budget_logger: - # healthy_deployments = ( - # await self.router_budget_logger.async_filter_deployments( - # healthy_deployments=healthy_deployments, - # request_kwargs=request_kwargs, - # ) - # ) - if len(healthy_deployments) == 0: exception = await async_raise_no_deployment_exception( litellm_router_instance=self, diff --git a/litellm/router_strategy/budget_limiter.py 
b/litellm/router_strategy/budget_limiter.py index 8e4d675750..0452a174b5 100644 --- a/litellm/router_strategy/budget_limiter.py +++ b/litellm/router_strategy/budget_limiter.py @@ -49,13 +49,13 @@ DEFAULT_REDIS_SYNC_INTERVAL = 1 class RouterBudgetLimiting(CustomLogger): def __init__( self, - router_cache: DualCache, + dual_cache: DualCache, provider_budget_config: Optional[dict], model_list: Optional[ Union[List[DeploymentTypedDict], List[Dict[str, Any]]] ] = None, ): - self.router_cache = router_cache + self.dual_cache = dual_cache self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = [] asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis()) self.provider_budget_config: Optional[GenericBudgetConfigType] = ( @@ -108,7 +108,7 @@ class RouterBudgetLimiting(CustomLogger): # Single cache read for all spend values if len(cache_keys) > 0: - _current_spends = await self.router_cache.async_batch_get_cache( + _current_spends = await self.dual_cache.async_batch_get_cache( keys=cache_keys, parent_otel_span=parent_otel_span, ) @@ -286,9 +286,9 @@ class RouterBudgetLimiting(CustomLogger): If it does, return the value. If it does not, set the key to `current_time` and return the value. """ - budget_start = await self.router_cache.async_get_cache(start_time_key) + budget_start = await self.dual_cache.async_get_cache(start_time_key) if budget_start is None: - await self.router_cache.async_set_cache( + await self.dual_cache.async_set_cache( key=start_time_key, value=current_time, ttl=ttl_seconds ) return current_time @@ -314,10 +314,10 @@ class RouterBudgetLimiting(CustomLogger): - stores key: `provider_budget_start_time:{provider}`, value: current_time. This stores the start time of the new budget window """ - await self.router_cache.async_set_cache( + await self.dual_cache.async_set_cache( key=spend_key, value=response_cost, ttl=ttl_seconds ) - await self.router_cache.async_set_cache( + await self.dual_cache.async_set_cache( key=start_time_key, value=current_time, ttl=ttl_seconds ) return current_time @@ -333,7 +333,7 @@ class RouterBudgetLimiting(CustomLogger): - Increments the spend in memory cache (so spend instantly updated in memory) - Queues the increment operation to Redis Pipeline (using batched pipeline to optimize performance. Using Redis for multi instance environment of LiteLLM) """ - await self.router_cache.in_memory_cache.async_increment( + await self.dual_cache.in_memory_cache.async_increment( key=spend_key, value=response_cost, ttl=ttl, @@ -481,7 +481,7 @@ class RouterBudgetLimiting(CustomLogger): Only runs if Redis is initialized """ try: - if not self.router_cache.redis_cache: + if not self.dual_cache.redis_cache: return # Redis is not initialized verbose_router_logger.debug( @@ -490,7 +490,7 @@ class RouterBudgetLimiting(CustomLogger): ) if len(self.redis_increment_operation_queue) > 0: asyncio.create_task( - self.router_cache.redis_cache.async_increment_pipeline( + self.dual_cache.redis_cache.async_increment_pipeline( increment_list=self.redis_increment_operation_queue, ) ) @@ -517,7 +517,7 @@ class RouterBudgetLimiting(CustomLogger): try: # No need to sync if Redis cache is not initialized - if self.router_cache.redis_cache is None: + if self.dual_cache.redis_cache is None: return # 1. 
Push all provider spend increments to Redis @@ -547,7 +547,7 @@ class RouterBudgetLimiting(CustomLogger): cache_keys.append(f"tag_spend:{tag}:{config.time_period}") # Batch fetch current spend values from Redis - redis_values = await self.router_cache.redis_cache.async_batch_get_cache( + redis_values = await self.dual_cache.redis_cache.async_batch_get_cache( key_list=cache_keys ) @@ -555,7 +555,7 @@ class RouterBudgetLimiting(CustomLogger): if isinstance(redis_values, dict): # Check if redis_values is a dictionary for key, value in redis_values.items(): if value is not None: - await self.router_cache.in_memory_cache.async_set_cache( + await self.dual_cache.in_memory_cache.async_set_cache( key=key, value=float(value) ) verbose_router_logger.debug( @@ -639,14 +639,12 @@ class RouterBudgetLimiting(CustomLogger): spend_key = f"provider_spend:{provider}:{budget_config.time_period}" - if self.router_cache.redis_cache: + if self.dual_cache.redis_cache: # use Redis as source of truth since that has spend across all instances - current_spend = await self.router_cache.redis_cache.async_get_cache( - spend_key - ) + current_spend = await self.dual_cache.redis_cache.async_get_cache(spend_key) else: # use in-memory cache if Redis is not initialized - current_spend = await self.router_cache.async_get_cache(spend_key) + current_spend = await self.dual_cache.async_get_cache(spend_key) return float(current_spend) if current_spend is not None else 0.0 async def _get_current_provider_budget_reset_at( @@ -657,10 +655,10 @@ class RouterBudgetLimiting(CustomLogger): return None spend_key = f"provider_spend:{provider}:{budget_config.time_period}" - if self.router_cache.redis_cache: - ttl_seconds = await self.router_cache.redis_cache.async_get_ttl(spend_key) + if self.dual_cache.redis_cache: + ttl_seconds = await self.dual_cache.redis_cache.async_get_ttl(spend_key) else: - ttl_seconds = await self.router_cache.async_get_ttl(spend_key) + ttl_seconds = await self.dual_cache.async_get_ttl(spend_key) if ttl_seconds is None: return None @@ -679,16 +677,16 @@ class RouterBudgetLimiting(CustomLogger): spend_key = f"provider_spend:{provider}:{budget_config.time_period}" start_time_key = f"provider_budget_start_time:{provider}" ttl_seconds = duration_in_seconds(budget_config.time_period) - budget_start = await self.router_cache.async_get_cache(start_time_key) + budget_start = await self.dual_cache.async_get_cache(start_time_key) if budget_start is None: budget_start = datetime.now(timezone.utc).timestamp() - await self.router_cache.async_set_cache( + await self.dual_cache.async_set_cache( key=start_time_key, value=budget_start, ttl=ttl_seconds ) - _spend_key = await self.router_cache.async_get_cache(spend_key) + _spend_key = await self.dual_cache.async_get_cache(spend_key) if _spend_key is None: - await self.router_cache.async_set_cache( + await self.dual_cache.async_set_cache( key=spend_key, value=0.0, ttl=ttl_seconds ) diff --git a/litellm/types/router.py b/litellm/types/router.py index 974c7085fc..e5d6511359 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -11,6 +11,8 @@ import httpx from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Required, TypedDict +from litellm.types.utils import GenericBudgetConfigType, GenericBudgetInfo + from ..exceptions import RateLimitError from .completion import CompletionRequest from .embedding import EmbeddingRequest @@ -647,14 +649,6 @@ class RoutingStrategy(enum.Enum): PROVIDER_BUDGET_LIMITING = "provider-budget-routing" -class 
GenericBudgetInfo(BaseModel): - time_period: str # e.g., '1d', '30d' - budget_limit: float - - -GenericBudgetConfigType = Dict[str, GenericBudgetInfo] - - class RouterCacheEnum(enum.Enum): TPM = "global_router:{id}:{model}:tpm:{current_minute}" RPM = "global_router:{id}:{model}:rpm:{current_minute}" diff --git a/litellm/types/utils.py b/litellm/types/utils.py index ca28b15b71..9e299e62f6 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1665,6 +1665,14 @@ class StandardKeyGenerationConfig(TypedDict, total=False): personal_key_generation: PersonalUIKeyGenerationConfig +class GenericBudgetInfo(BaseModel): + time_period: str # e.g., '1d', '30d' + budget_limit: float + + +GenericBudgetConfigType = Dict[str, GenericBudgetInfo] + + class BudgetConfig(BaseModel): max_budget: float budget_duration: str diff --git a/tests/local_testing/test_router_budget_limiter.py b/tests/local_testing/test_router_budget_limiter.py index 305db6ccf7..8ca1f4e767 100644 --- a/tests/local_testing/test_router_budget_limiter.py +++ b/tests/local_testing/test_router_budget_limiter.py @@ -183,7 +183,7 @@ async def test_get_llm_provider_for_deployment(): """ cleanup_redis() provider_budget = RouterBudgetLimiting( - router_cache=DualCache(), provider_budget_config={} + dual_cache=DualCache(), provider_budget_config={} ) # Test OpenAI deployment @@ -220,7 +220,7 @@ async def test_get_budget_config_for_provider(): } provider_budget = RouterBudgetLimiting( - router_cache=DualCache(), provider_budget_config=config + dual_cache=DualCache(), provider_budget_config=config ) # Test existing providers @@ -252,7 +252,7 @@ async def test_prometheus_metric_tracking(): # Setup provider budget limiting provider_budget = RouterBudgetLimiting( - router_cache=DualCache(), + dual_cache=DualCache(), provider_budget_config={ "openai": GenericBudgetInfo(time_period="1d", budget_limit=100) }, @@ -316,7 +316,7 @@ async def test_handle_new_budget_window(): """ cleanup_redis() provider_budget = RouterBudgetLimiting( - router_cache=DualCache(), provider_budget_config={} + dual_cache=DualCache(), provider_budget_config={} ) spend_key = "provider_spend:openai:7d" @@ -337,12 +337,12 @@ async def test_handle_new_budget_window(): assert new_start_time == current_time # Verify the spend was set correctly - spend = await provider_budget.router_cache.async_get_cache(spend_key) + spend = await provider_budget.dual_cache.async_get_cache(spend_key) print("spend in cache for key", spend_key, "is", spend) assert float(spend) == response_cost # Verify start time was set correctly - start_time = await provider_budget.router_cache.async_get_cache(start_time_key) + start_time = await provider_budget.dual_cache.async_get_cache(start_time_key) print("start time in cache for key", start_time_key, "is", start_time) assert float(start_time) == current_time @@ -357,7 +357,7 @@ async def test_get_or_set_budget_start_time(): """ cleanup_redis() provider_budget = RouterBudgetLimiting( - router_cache=DualCache(), provider_budget_config={} + dual_cache=DualCache(), provider_budget_config={} ) start_time_key = "test_start_time" @@ -398,7 +398,7 @@ async def test_increment_spend_in_current_window(): """ cleanup_redis() provider_budget = RouterBudgetLimiting( - router_cache=DualCache(), provider_budget_config={} + dual_cache=DualCache(), provider_budget_config={} ) spend_key = "provider_spend:openai:1d" @@ -406,9 +406,7 @@ async def test_increment_spend_in_current_window(): ttl = 86400 # 1 day # Set initial spend - await 
provider_budget.router_cache.async_set_cache( - key=spend_key, value=1.0, ttl=ttl - ) + await provider_budget.dual_cache.async_set_cache(key=spend_key, value=1.0, ttl=ttl) # Test incrementing spend await provider_budget._increment_spend_in_current_window( @@ -418,7 +416,7 @@ async def test_increment_spend_in_current_window(): ) # Verify the spend was incremented correctly in memory - spend = await provider_budget.router_cache.async_get_cache(spend_key) + spend = await provider_budget.dual_cache.async_get_cache(spend_key) assert float(spend) == 1.5 # Verify the increment operation was queued for Redis @@ -449,7 +447,7 @@ async def test_sync_in_memory_spend_with_redis(): } provider_budget = RouterBudgetLimiting( - router_cache=DualCache( + dual_cache=DualCache( redis_cache=RedisCache( host=os.getenv("REDIS_HOST"), port=int(os.getenv("REDIS_PORT")), @@ -463,10 +461,10 @@ async def test_sync_in_memory_spend_with_redis(): spend_key_openai = "provider_spend:openai:1d" spend_key_anthropic = "provider_spend:anthropic:1d" - await provider_budget.router_cache.redis_cache.async_set_cache( + await provider_budget.dual_cache.redis_cache.async_set_cache( key=spend_key_openai, value=50.0 ) - await provider_budget.router_cache.redis_cache.async_set_cache( + await provider_budget.dual_cache.redis_cache.async_set_cache( key=spend_key_anthropic, value=75.0 ) @@ -474,13 +472,11 @@ async def test_sync_in_memory_spend_with_redis(): await provider_budget._sync_in_memory_spend_with_redis() # Verify in-memory cache was updated - openai_spend = await provider_budget.router_cache.in_memory_cache.async_get_cache( + openai_spend = await provider_budget.dual_cache.in_memory_cache.async_get_cache( spend_key_openai ) - anthropic_spend = ( - await provider_budget.router_cache.in_memory_cache.async_get_cache( - spend_key_anthropic - ) + anthropic_spend = await provider_budget.dual_cache.in_memory_cache.async_get_cache( + spend_key_anthropic ) assert float(openai_spend) == 50.0 @@ -499,7 +495,7 @@ async def test_get_current_provider_spend(): """ cleanup_redis() provider_budget = RouterBudgetLimiting( - router_cache=DualCache(), + dual_cache=DualCache(), provider_budget_config={ "openai": GenericBudgetInfo(time_period="1d", budget_limit=100), }, @@ -515,7 +511,7 @@ async def test_get_current_provider_spend(): # Test provider with budget config and spend spend_key = "provider_spend:openai:1d" - await provider_budget.router_cache.async_set_cache(key=spend_key, value=50.5) + await provider_budget.dual_cache.async_set_cache(key=spend_key, value=50.5) spend = await provider_budget._get_current_provider_spend("openai") assert spend == 50.5 @@ -534,7 +530,7 @@ async def test_get_current_provider_budget_reset_at(): """ cleanup_redis() provider_budget = RouterBudgetLimiting( - router_cache=DualCache( + dual_cache=DualCache( redis_cache=RedisCache( host=os.getenv("REDIS_HOST"), port=int(os.getenv("REDIS_PORT")), diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py index 73a9c4bd58..c745f9dd96 100644 --- a/tests/proxy_unit_tests/test_key_generate_prisma.py +++ b/tests/proxy_unit_tests/test_key_generate_prisma.py @@ -1684,124 +1684,109 @@ def test_call_with_key_over_budget_no_cache(prisma_client): print(vars(e)) -def test_call_with_key_over_model_budget(prisma_client): +@pytest.mark.asyncio() +@pytest.mark.parametrize( + "request_model,should_pass", + [ + ("openai/gpt-4o-mini", False), + ("gpt-4o-mini", False), + ("gpt-4o", True), + ], +) +async def 
test_call_with_key_over_model_budget( + prisma_client, request_model, should_pass +): # 12. Make a call with a key over budget, expect to fail setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + verbose_proxy_logger.setLevel(logging.DEBUG) + + # init model max budget limiter + from litellm.proxy.hooks.model_max_budget_limiter import ( + _PROXY_VirtualKeyModelMaxBudgetLimiter, + ) + + model_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter( + dual_cache=DualCache() + ) + litellm.callbacks.append(model_budget_limiter) + try: - async def test(): - await litellm.proxy.proxy_server.prisma_client.connect() + # set budget for chatgpt-v-2 to 0.000001, expect the next request to fail + model_max_budget = { + "gpt-4o-mini": { + "budget_limit": "0.000001", + "time_period": "1d", + }, + "gpt-4o": { + "budget_limit": "200", + "time_period": "30d", + }, + } - # set budget for chatgpt-v-2 to 0.000001, expect the next request to fail - request = GenerateKeyRequest( - max_budget=1000, - model_max_budget={ - "chatgpt-v-2": 0.000001, - }, - metadata={"user_api_key": 0.0001}, + request = GenerateKeyRequest( + max_budget=100000, # the key itself has a very high budget + model_max_budget=model_max_budget, + ) + key = await generate_key_fn(request) + print(key) + + generated_key = key.key + user_id = key.user_id + bearer_token = "Bearer " + generated_key + + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + async def return_body(): + request_str = f'{{"model": "{request_model}"}}' # Added extra curly braces to escape JSON + return request_str.encode() + + request.body = return_body + + # use generated key to auth in + result = await user_api_key_auth(request=request, api_key=bearer_token) + print("result from user auth with new key", result) + + # update spend using track_cost callback, make 2nd request, it should fail + await litellm.acompletion( + model=request_model, + messages=[{"role": "user", "content": "Hello, how are you?"}], + metadata={ + "user_api_key": hash_token(generated_key), + "user_api_key_model_max_budget": model_max_budget, + }, + ) + + await asyncio.sleep(2) + + # use generated key to auth in + result = await user_api_key_auth(request=request, api_key=bearer_token) + if should_pass is True: + print( + f"Passed request for model={request_model}, model_max_budget={model_max_budget}" ) - key = await generate_key_fn(request) - print(key) - - generated_key = key.key - user_id = key.user_id - bearer_token = "Bearer " + generated_key - - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - async def return_body(): - return b'{"model": "chatgpt-v-2"}' - - request.body = return_body - - # use generated key to auth in - result = await user_api_key_auth(request=request, api_key=bearer_token) - print("result from user auth with new key", result) - - # update spend using track_cost callback, make 2nd request, it should fail - from litellm import Choices, Message, ModelResponse, Usage - from litellm.caching.caching import Cache - from litellm.proxy.proxy_server import ( - _PROXY_track_cost_callback as track_cost_callback, - ) - - litellm.cache = Cache() - import time - import uuid - - request_id = f"chatcmpl-{uuid.uuid4()}" - - resp = ModelResponse( - id=request_id, - choices=[ - Choices( - finish_reason=None, - index=0, - message=Message( - content=" Sure! 
Here is a short poem about the sky:\n\nA canvas of blue, a", - role="assistant", - ), - ) - ], - model="gpt-35-turbo", # azure always has model written like this - usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), - ) - await track_cost_callback( - kwargs={ - "model": "chatgpt-v-2", - "stream": False, - "litellm_params": { - "metadata": { - "user_api_key": hash_token(generated_key), - "user_api_key_user_id": user_id, - } - }, - "response_cost": 0.00002, - }, - completion_response=resp, - start_time=datetime.now(), - end_time=datetime.now(), - ) - await update_spend( - prisma_client=prisma_client, - db_writer_client=None, - proxy_logging_obj=proxy_logging_obj, - ) - # test spend_log was written and we can read it - spend_logs = await view_spend_logs( - request_id=request_id, - user_api_key_dict=UserAPIKeyAuth(api_key=generated_key), - ) - - print("read spend logs", spend_logs) - assert len(spend_logs) == 1 - - spend_log = spend_logs[0] - - assert spend_log.request_id == request_id - assert spend_log.spend == float("2e-05") - assert spend_log.model == "chatgpt-v-2" - assert ( - spend_log.cache_key - == "c891d64397a472e6deb31b87a5ac4d3ed5b2dcc069bc87e2afe91e6d64e95a1e" - ) - - # use generated key to auth in - result = await user_api_key_auth(request=request, api_key=bearer_token) - print("result from user auth with new key", result) - pytest.fail("This should have failed!. They key crossed it's budget") - - asyncio.run(test()) + return + print("result from user auth with new key", result) + pytest.fail("This should have failed!. They key crossed it's budget") except Exception as e: # print(f"Error - {str(e)}") + print( + f"Failed request for model={request_model}, model_max_budget={model_max_budget}" + ) + assert ( + should_pass is False + ), f"This should have failed!. They key crossed it's budget for model={request_model}. {e}" traceback.print_exc() error_detail = e.message - assert "Budget has been exceeded!" 
in error_detail + assert f"exceeded budget for model={request_model}" in error_detail assert isinstance(e, ProxyException) assert e.type == ProxyErrorTypes.budget_exceeded print(vars(e)) + finally: + litellm.callbacks.remove(model_budget_limiter) @pytest.mark.asyncio() diff --git a/tests/proxy_unit_tests/test_unit_test_max_model_budget_limiter.py b/tests/proxy_unit_tests/test_unit_test_max_model_budget_limiter.py new file mode 100644 index 0000000000..52a8b03909 --- /dev/null +++ b/tests/proxy_unit_tests/test_unit_test_max_model_budget_limiter.py @@ -0,0 +1,127 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock + +from pydantic.main import Model + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system-path +from datetime import datetime as dt_object +import time +import pytest +import litellm + +import json +from litellm.types.utils import GenericBudgetInfo +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock, patch +import pytest +from litellm.caching.caching import DualCache +from litellm.proxy.hooks.model_max_budget_limiter import ( + _PROXY_VirtualKeyModelMaxBudgetLimiter, +) +from litellm.proxy._types import UserAPIKeyAuth +import litellm + + +# Test class setup +@pytest.fixture +def budget_limiter(): + dual_cache = DualCache() + return _PROXY_VirtualKeyModelMaxBudgetLimiter(dual_cache=dual_cache) + + +# Test _get_model_without_custom_llm_provider +def test_get_model_without_custom_llm_provider(budget_limiter): + # Test with custom provider + assert ( + budget_limiter._get_model_without_custom_llm_provider("openai/gpt-4") == "gpt-4" + ) + + # Test without custom provider + assert budget_limiter._get_model_without_custom_llm_provider("gpt-4") == "gpt-4" + + +# Test _get_request_model_budget_config +def test_get_request_model_budget_config(budget_limiter): + internal_budget = { + "gpt-4": GenericBudgetInfo(budget_limit=100.0, time_period="1d"), + "claude-3": GenericBudgetInfo(budget_limit=50.0, time_period="1d"), + } + + # Test direct model match + config = budget_limiter._get_request_model_budget_config( + model="gpt-4", internal_model_max_budget=internal_budget + ) + assert config.budget_limit == 100.0 + + # Test model with provider + config = budget_limiter._get_request_model_budget_config( + model="openai/gpt-4", internal_model_max_budget=internal_budget + ) + assert config.budget_limit == 100.0 + + # Test non-existent model + config = budget_limiter._get_request_model_budget_config( + model="non-existent", internal_model_max_budget=internal_budget + ) + assert config is None + + +# Test is_key_within_model_budget +@pytest.mark.asyncio +async def test_is_key_within_model_budget(budget_limiter): + # Mock user API key dict + user_api_key = UserAPIKeyAuth( + token="test-key", + key_alias="test-alias", + model_max_budget={"gpt-4": {"budget_limit": 100.0, "time_period": "1d"}}, + ) + + # Test when model is within budget + with patch.object( + budget_limiter, "_get_virtual_key_spend_for_model", return_value=50.0 + ): + assert ( + await budget_limiter.is_key_within_model_budget(user_api_key, "gpt-4") + is True + ) + + # Test when model exceeds budget + with patch.object( + budget_limiter, "_get_virtual_key_spend_for_model", return_value=150.0 + ): + with pytest.raises(litellm.BudgetExceededError): + await budget_limiter.is_key_within_model_budget(user_api_key, "gpt-4") + + # Test model not in budget config + assert ( + await 
budget_limiter.is_key_within_model_budget(user_api_key, "non-existent") + is True + ) + + +# Test _get_virtual_key_spend_for_model +@pytest.mark.asyncio +async def test_get_virtual_key_spend_for_model(budget_limiter): + budget_config = GenericBudgetInfo(budget_limit=100.0, time_period="1d") + + # Mock cache get + with patch.object(budget_limiter.dual_cache, "async_get_cache", return_value=50.0): + spend = await budget_limiter._get_virtual_key_spend_for_model( + user_api_key_hash="test-key", model="gpt-4", key_budget_config=budget_config + ) + assert spend == 50.0 + + # Test with provider prefix + spend = await budget_limiter._get_virtual_key_spend_for_model( + user_api_key_hash="test-key", + model="openai/gpt-4", + key_budget_config=budget_config, + ) + assert spend == 50.0