From 8979b74d49b5daefc37bcafc2032561d22017074 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 22 Jan 2024 12:13:19 -0800 Subject: [PATCH] (feat) working budgets per key --- litellm/proxy/proxy_server.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9a2196519..dc965f5c9 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -306,6 +306,7 @@ async def user_api_key_auth( # 1. If token can call model # 2. If user_id for this token is in budget # 3. If token is expired + # 4. If token spend is under Budget for the token # Check 1. If token can call model litellm.model_alias_map = valid_token.aliases @@ -406,6 +407,13 @@ async def user_api_key_auth( detail=f"Authentication Error - Expired Key. Key Expiry time {expiry_time} and current time {current_time}", ) + # Check 4. Token Spend is under budget + if valid_token.spend is not None and valid_token.max_budget is not None: + if valid_token.spend > valid_token.max_budget: + raise Exception( + f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Token: {valid_token.max_budget}" + ) + # Token passed all checks # Add token to cache user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60) @@ -668,7 +676,9 @@ async def update_database( if prisma_client is not None: # Fetch the existing cost for the given token existing_spend_obj = await prisma_client.get_data(token=token) - verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}") + verbose_proxy_logger.debug( + f"_update_key_db: existing spend: {existing_spend_obj}" + ) if existing_spend_obj is None: existing_spend = 0 else: @@ -679,12 +689,18 @@ async def update_database( verbose_proxy_logger.debug(f"new cost: {new_spend}") # Update the cost column for the given token await prisma_client.update_data(token=token, data={"spend": new_spend}) + + valid_token = user_api_key_cache.get_cache(key=token) + valid_token.spend = new_spend + user_api_key_cache.set_cache(key=token, value=valid_token) elif custom_db_client is not None: # Fetch the existing cost for the given token existing_spend_obj = await custom_db_client.get_data( key=token, table_name="key" ) - verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}") + verbose_proxy_logger.debug( + f"_update_key_db existing spend: {existing_spend_obj}" + ) if existing_spend_obj is None: existing_spend = 0 else: @@ -698,6 +714,10 @@ async def update_database( key=token, value={"spend": new_spend}, table_name="key" ) + valid_token = user_api_key_cache.get_cache(key=token) + valid_token.spend = new_spend + user_api_key_cache.set_cache(key=token, value=valid_token) + async def _insert_spend_log_to_db(): # Helper to generate payload to log verbose_proxy_logger.debug("inserting spend log to db")