mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(proxy_server.py): fix key caching logic
This commit is contained in:
parent
8f6af575e7
commit
3232feb123
5 changed files with 214 additions and 75 deletions
|
@ -742,6 +742,39 @@ class DualCache(BaseCache):
|
|||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
async def async_get_cache(self, key, local_only: bool = False, **kwargs):
|
||||
# Try to fetch from in-memory cache first
|
||||
try:
|
||||
print_verbose(
|
||||
f"async get cache: cache key: {key}; local_only: {local_only}"
|
||||
)
|
||||
result = None
|
||||
if self.in_memory_cache is not None:
|
||||
in_memory_result = await self.in_memory_cache.async_get_cache(
|
||||
key, **kwargs
|
||||
)
|
||||
|
||||
print_verbose(f"in_memory_result: {in_memory_result}")
|
||||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
|
||||
if result is None and self.redis_cache is not None and local_only == False:
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = await self.redis_cache.async_get_cache(key, **kwargs)
|
||||
|
||||
if redis_result is not None:
|
||||
# Update in-memory cache with the value from Redis
|
||||
await self.in_memory_cache.async_set_cache(
|
||||
key, redis_result, **kwargs
|
||||
)
|
||||
|
||||
result = redis_result
|
||||
|
||||
print_verbose(f"get cache: cache result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
def flush_cache(self):
|
||||
if self.in_memory_cache is not None:
|
||||
self.in_memory_cache.flush_cache()
|
||||
|
|
|
@ -535,6 +535,8 @@ class LiteLLM_VerificationToken(LiteLLMBase):
|
|||
permissions: Dict = {}
|
||||
model_spend: Dict = {}
|
||||
model_max_budget: Dict = {}
|
||||
soft_budget_cooldown: bool = False
|
||||
litellm_budget_table: Optional[dict] = None
|
||||
|
||||
# hidden params used for parallel request limiting, not required to create a token
|
||||
user_id_rate_limits: Optional[dict] = None
|
||||
|
|
|
@ -651,7 +651,7 @@ async def user_api_key_auth(
|
|||
)
|
||||
)
|
||||
|
||||
if valid_token.spend > valid_token.max_budget:
|
||||
if valid_token.spend >= valid_token.max_budget:
|
||||
raise Exception(
|
||||
f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Token: {valid_token.max_budget}"
|
||||
)
|
||||
|
@ -678,14 +678,17 @@ async def user_api_key_auth(
|
|||
]
|
||||
},
|
||||
)
|
||||
if len(model_spend) > 0:
|
||||
if (
|
||||
model_spend[0]["model"] == model
|
||||
len(model_spend) > 0
|
||||
and max_budget_per_model.get(current_model, None) is not None
|
||||
):
|
||||
if (
|
||||
model_spend[0]["model"] == current_model
|
||||
and model_spend[0]["_sum"]["spend"]
|
||||
>= max_budget_per_model["model"]
|
||||
>= max_budget_per_model[current_model]
|
||||
):
|
||||
current_model_spend = model_spend[0]["_sum"]["spend"]
|
||||
current_model_budget = max_budget_per_model["model"]
|
||||
current_model_budget = max_budget_per_model[current_model]
|
||||
raise Exception(
|
||||
f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}"
|
||||
)
|
||||
|
@ -742,15 +745,7 @@ async def user_api_key_auth(
|
|||
|
||||
This makes the user row data accessible to pre-api call hooks.
|
||||
"""
|
||||
if prisma_client is not None:
|
||||
asyncio.create_task(
|
||||
_cache_user_row(
|
||||
user_id=valid_token.user_id,
|
||||
cache=user_api_key_cache,
|
||||
db=prisma_client,
|
||||
)
|
||||
)
|
||||
elif custom_db_client is not None:
|
||||
if custom_db_client is not None:
|
||||
asyncio.create_task(
|
||||
_cache_user_row(
|
||||
user_id=valid_token.user_id,
|
||||
|
@ -1023,7 +1018,9 @@ async def _PROXY_track_cost_callback(
|
|||
end_time=end_time,
|
||||
)
|
||||
|
||||
await update_cache(token=user_api_key, response_cost=response_cost)
|
||||
await update_cache(
|
||||
token=user_api_key, user_id=user_id, response_cost=response_cost
|
||||
)
|
||||
else:
|
||||
raise Exception("User API key missing from custom callback.")
|
||||
else:
|
||||
|
@ -1072,15 +1069,54 @@ async def update_database(
|
|||
- Update that user's row
|
||||
- Update litellm-proxy-budget row (global proxy spend)
|
||||
"""
|
||||
user_ids = [user_id, litellm_proxy_budget_name]
|
||||
## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db
|
||||
end_user_id = None
|
||||
if isinstance(token, str) and token.startswith("sk-"):
|
||||
hashed_token = hash_token(token=token)
|
||||
else:
|
||||
hashed_token = token
|
||||
existing_token_obj = await user_api_key_cache.async_get_cache(
|
||||
key=hashed_token
|
||||
)
|
||||
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||
if existing_token_obj.user_id != user_id: # an end-user id was passed in
|
||||
end_user_id = user_id
|
||||
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name]
|
||||
data_list = []
|
||||
try:
|
||||
if prisma_client is not None: # update
|
||||
user_ids = [user_id, litellm_proxy_budget_name]
|
||||
await prisma_client.db.litellm_usertable.update(
|
||||
## do a group update for the user-id of the key + global proxy budget
|
||||
await prisma_client.db.litellm_usertable.update_many(
|
||||
where={"user_id": {"in": user_ids}},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
if end_user_id is not None:
|
||||
if existing_user_obj is None:
|
||||
# if user does not exist in LiteLLM_UserTable, create a new user
|
||||
existing_spend = 0
|
||||
max_user_budget = None
|
||||
if litellm.max_user_budget is not None:
|
||||
max_user_budget = litellm.max_user_budget
|
||||
existing_user_obj = LiteLLM_UserTable(
|
||||
user_id=end_user_id,
|
||||
spend=0,
|
||||
max_budget=max_user_budget,
|
||||
user_email=None,
|
||||
)
|
||||
else:
|
||||
existing_user_obj.spend = (
|
||||
existing_user_obj.spend + response_cost
|
||||
)
|
||||
|
||||
await prisma_client.db.litellm_usertable.upsert(
|
||||
where={"user_id": end_user_id},
|
||||
data={
|
||||
"create": {**existing_user_obj.json(exclude_none=True)},
|
||||
"update": {"spend": {"increment": response_cost}},
|
||||
},
|
||||
)
|
||||
|
||||
elif custom_db_client is not None:
|
||||
for id in user_ids:
|
||||
if id is None:
|
||||
|
@ -1261,13 +1297,11 @@ async def update_database(
|
|||
)
|
||||
raise e
|
||||
|
||||
tasks = []
|
||||
tasks.append(_update_user_db())
|
||||
tasks.append(_update_key_db())
|
||||
tasks.append(_update_team_db())
|
||||
tasks.append(_insert_spend_log_to_db())
|
||||
asyncio.create_task(_update_user_db())
|
||||
asyncio.create_task(_update_key_db())
|
||||
asyncio.create_task(_update_team_db())
|
||||
asyncio.create_task(_insert_spend_log_to_db())
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
verbose_proxy_logger.info("Successfully updated spend in all 3 tables")
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
|
@ -1277,6 +1311,7 @@ async def update_database(
|
|||
|
||||
async def update_cache(
|
||||
token,
|
||||
user_id,
|
||||
response_cost,
|
||||
):
|
||||
"""
|
||||
|
@ -1284,10 +1319,18 @@ async def update_cache(
|
|||
|
||||
Put any alerting logic in here.
|
||||
"""
|
||||
|
||||
### UPDATE KEY SPEND ###
|
||||
async def _update_key_cache():
|
||||
# Fetch the existing cost for the given token
|
||||
existing_spend_obj = await user_api_key_cache.async_get_cache(key=token)
|
||||
verbose_proxy_logger.debug(f"_update_key_db: existing spend: {existing_spend_obj}")
|
||||
if isinstance(token, str) and token.startswith("sk-"):
|
||||
hashed_token = hash_token(token=token)
|
||||
else:
|
||||
hashed_token = token
|
||||
existing_spend_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
|
||||
verbose_proxy_logger.debug(
|
||||
f"_update_key_db: existing spend: {existing_spend_obj}"
|
||||
)
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
else:
|
||||
|
@ -1331,16 +1374,76 @@ async def update_cache(
|
|||
# set cooldown on alert
|
||||
soft_budget_cooldown = True
|
||||
|
||||
if existing_spend_obj is None:
|
||||
existing_team_spend = 0
|
||||
else:
|
||||
if (
|
||||
existing_spend_obj is not None
|
||||
and getattr(existing_spend_obj, "team_spend", None) is not None
|
||||
):
|
||||
existing_team_spend = existing_spend_obj.team_spend
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
existing_spend_obj.team_spend = existing_team_spend + response_cost
|
||||
|
||||
# Update the cost column for the given token
|
||||
existing_spend_obj.spend = new_spend
|
||||
user_api_key_cache.set_cache(key=token, value=existing_spend_obj)
|
||||
user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)
|
||||
|
||||
async def _update_user_cache():
|
||||
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
|
||||
end_user_id = None
|
||||
if isinstance(token, str) and token.startswith("sk-"):
|
||||
hashed_token = hash_token(token=token)
|
||||
else:
|
||||
hashed_token = token
|
||||
existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
|
||||
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||
if existing_token_obj.user_id != user_id: # an end-user id was passed in
|
||||
end_user_id = user_id
|
||||
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id]
|
||||
|
||||
try:
|
||||
for _id in user_ids:
|
||||
# Fetch the existing cost for the given user
|
||||
existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
|
||||
if existing_spend_obj is None:
|
||||
# if user does not exist in LiteLLM_UserTable, create a new user
|
||||
existing_spend = 0
|
||||
max_user_budget = None
|
||||
if litellm.max_user_budget is not None:
|
||||
max_user_budget = litellm.max_user_budget
|
||||
existing_spend_obj = LiteLLM_UserTable(
|
||||
user_id=_id,
|
||||
spend=0,
|
||||
max_budget=max_user_budget,
|
||||
user_email=None,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"_update_user_db: existing spend: {existing_spend_obj}"
|
||||
)
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
else:
|
||||
if isinstance(existing_spend_obj, dict):
|
||||
existing_spend = existing_spend_obj["spend"]
|
||||
else:
|
||||
existing_spend = existing_spend_obj.spend
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_spend = existing_spend + response_cost
|
||||
|
||||
# Update the cost column for the given user
|
||||
if isinstance(existing_spend_obj, dict):
|
||||
existing_spend_obj["spend"] = new_spend
|
||||
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
|
||||
else:
|
||||
existing_spend_obj.spend = new_spend
|
||||
user_api_key_cache.set_cache(
|
||||
key=_id, value=existing_spend_obj.json()
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
asyncio.create_task(_update_key_cache())
|
||||
asyncio.create_task(_update_user_cache())
|
||||
|
||||
|
||||
def run_ollama_serve():
|
||||
|
|
|
@ -1596,7 +1596,6 @@ async def _cache_user_row(
|
|||
Check if a user_id exists in cache,
|
||||
if not retrieve it.
|
||||
"""
|
||||
print_verbose(f"Prisma: _cache_user_row, user_id: {user_id}")
|
||||
cache_key = f"{user_id}_user_api_key_user_id"
|
||||
response = cache.get_cache(key=cache_key)
|
||||
if response is None: # Cache miss
|
||||
|
|
|
@ -318,7 +318,7 @@ def test_call_with_user_over_budget(prisma_client):
|
|||
|
||||
|
||||
def test_call_with_end_user_over_budget(prisma_client):
|
||||
# Test if a user passed to /chat/completions is tracked & fails whe they cross their budget
|
||||
# Test if a user passed to /chat/completions is tracked & fails when they cross their budget
|
||||
# we only check this when litellm.max_user_budget is set
|
||||
import random
|
||||
|
||||
|
@ -339,6 +339,8 @@ def test_call_with_end_user_over_budget(prisma_client):
|
|||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/chat/completions")
|
||||
|
||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||
|
||||
async def return_body():
|
||||
return_string = f'{{"model": "gemini-pro-vision", "user": "{user}"}}'
|
||||
# return string as bytes
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue