fix(proxy_server.py): fix key caching logic

Krrish Dholakia 2024-03-13 19:10:24 -07:00
parent acc672a78f
commit 1b807fa3f5
5 changed files with 214 additions and 75 deletions

litellm/proxy/proxy_server.py

@@ -651,7 +651,7 @@ async def user_api_key_auth(
                     )
                 )
-                if valid_token.spend > valid_token.max_budget:
+                if valid_token.spend >= valid_token.max_budget:
                     raise Exception(
                         f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Token: {valid_token.max_budget}"
                     )
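
The only functional change in the hunk above is the comparison operator: a key whose spend has reached exactly its max_budget is now rejected instead of being allowed one more request. A standalone sketch of that boundary case (names invented, not repo code):

def budget_exceeded(spend: float, max_budget: float) -> bool:
    # ">" let a key that had spent exactly its max_budget keep working;
    # ">=" blocks the request once the budget is fully consumed.
    return spend >= max_budget

assert budget_exceeded(10.0, 10.0) is True    # exactly at the limit -> now rejected
assert budget_exceeded(9.99, 10.0) is False   # still under budget -> allowed
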
@@ -678,14 +678,17 @@ async def user_api_key_auth(
                             ]
                         },
                     )
-                    if len(model_spend) > 0:
+                    if (
+                        len(model_spend) > 0
+                        and max_budget_per_model.get(current_model, None) is not None
+                    ):
                         if (
-                            model_spend[0]["model"] == model
+                            model_spend[0]["model"] == current_model
                             and model_spend[0]["_sum"]["spend"]
-                            >= max_budget_per_model["model"]
+                            >= max_budget_per_model[current_model]
                         ):
                             current_model_spend = model_spend[0]["_sum"]["spend"]
-                            current_model_budget = max_budget_per_model["model"]
+                            current_model_budget = max_budget_per_model[current_model]
                             raise Exception(
                                 f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}"
                             )
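
The per-model check above now keys everything off current_model (the model actually requested) and only runs when a budget is configured for that model. A small self-contained sketch of the guard, with toy data standing in for the Prisma group_by result:

max_budget_per_model = {"gpt-4": 50.0}  # per-model budgets configured on the key
current_model = "gpt-3.5-turbo"         # model being requested
model_spend = [{"model": current_model, "_sum": {"spend": 12.0}}]

model_budget = max_budget_per_model.get(current_model, None)
if len(model_spend) > 0 and model_budget is not None:
    if (
        model_spend[0]["model"] == current_model
        and model_spend[0]["_sum"]["spend"] >= model_budget
    ):
        raise Exception("ExceededModelBudget")
# Indexing max_budget_per_model["model"] (the old hard-coded key) or [current_model] without
# the .get() guard would raise KeyError whenever the requested model has no configured budget.
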
@@ -742,15 +745,7 @@ async def user_api_key_auth(
             This makes the user row data accessible to pre-api call hooks.
             """
-            if prisma_client is not None:
-                asyncio.create_task(
-                    _cache_user_row(
-                        user_id=valid_token.user_id,
-                        cache=user_api_key_cache,
-                        db=prisma_client,
-                    )
-                )
-            elif custom_db_client is not None:
+            if custom_db_client is not None:
                 asyncio.create_task(
                     _cache_user_row(
                         user_id=valid_token.user_id,
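
One reading of this hunk, given the commit title: when Prisma is the backing store, the user row is no longer re-fetched and re-cached in the background on every auth call, so a stale database read can no longer overwrite spend that update_cache has just written. A toy illustration of that kind of overwrite (invented helpers, not repo code):

cache = {}

def update_cache_spend(user_id: str, cost: float) -> None:
    row = cache.get(user_id, {"spend": 0.0})
    row["spend"] += cost
    cache[user_id] = row                    # cache now holds the freshest spend

def recache_user_row(user_id: str, db_row: dict) -> None:
    cache[user_id] = db_row                 # unconditionally overwrites the cached row

update_cache_spend("user-1", 0.25)          # in-memory spend: 0.25
recache_user_row("user-1", {"spend": 0.0})  # async DB write hasn't landed yet, stale row wins
assert cache["user-1"]["spend"] == 0.0      # the tracked spend is silently lost
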
@@ -1023,7 +1018,9 @@ async def _PROXY_track_cost_callback(
                         end_time=end_time,
                     )
-                    await update_cache(token=user_api_key, response_cost=response_cost)
+                    await update_cache(
+                        token=user_api_key, user_id=user_id, response_cost=response_cost
+                    )
                 else:
                     raise Exception("User API key missing from custom callback.")
             else:
@@ -1072,15 +1069,54 @@ async def update_database(
             - Update that user's row
             - Update litellm-proxy-budget row (global proxy spend)
             """
-            user_ids = [user_id, litellm_proxy_budget_name]
+            ## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db
+            end_user_id = None
+            if isinstance(token, str) and token.startswith("sk-"):
+                hashed_token = hash_token(token=token)
+            else:
+                hashed_token = token
+            existing_token_obj = await user_api_key_cache.async_get_cache(
+                key=hashed_token
+            )
+            existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
+            if existing_token_obj.user_id != user_id: # an end-user id was passed in
+                end_user_id = user_id
+            user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name]
             data_list = []
             try:
                 if prisma_client is not None: # update
-                    user_ids = [user_id, litellm_proxy_budget_name]
-                    await prisma_client.db.litellm_usertable.update(
+                    ## do a group update for the user-id of the key + global proxy budget
+                    await prisma_client.db.litellm_usertable.update_many(
                         where={"user_id": {"in": user_ids}},
                         data={"spend": {"increment": response_cost}},
                     )
+                    if end_user_id is not None:
+                        if existing_user_obj is None:
+                            # if user does not exist in LiteLLM_UserTable, create a new user
+                            existing_spend = 0
+                            max_user_budget = None
+                            if litellm.max_user_budget is not None:
+                                max_user_budget = litellm.max_user_budget
+                            existing_user_obj = LiteLLM_UserTable(
+                                user_id=end_user_id,
+                                spend=0,
+                                max_budget=max_user_budget,
+                                user_email=None,
+                            )
+                        else:
+                            existing_user_obj.spend = (
+                                existing_user_obj.spend + response_cost
+                            )
+                        await prisma_client.db.litellm_usertable.upsert(
+                            where={"user_id": end_user_id},
+                            data={
+                                "create": {**existing_user_obj.json(exclude_none=True)},
+                                "update": {"spend": {"increment": response_cost}},
+                            },
+                        )
                 elif custom_db_client is not None:
                     for id in user_ids:
                         if id is None:
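
The new bookkeeping above separates the key owner's id (taken from the cached token object) from an end-user id passed in by the cost callback: the owner and the global litellm-proxy-budget row get a grouped spend increment, while an end user, who may not exist in LiteLLM_UserTable yet, gets an upsert. A compact sketch of just that id-splitting decision (function name and sample ids are illustrative):

def split_user_ids(token_owner_id: str, passed_user_id: str, proxy_budget_row: str):
    # rows that always receive the grouped spend increment
    user_ids = [token_owner_id, proxy_budget_row]
    # a differing id means the caller passed an end user rather than the key owner
    end_user_id = passed_user_id if passed_user_id != token_owner_id else None
    return user_ids, end_user_id

user_ids, end_user_id = split_user_ids(
    token_owner_id="key-owner-1",
    passed_user_id="end-user-42",
    proxy_budget_row="litellm-proxy-budget",
)
assert user_ids == ["key-owner-1", "litellm-proxy-budget"]
assert end_user_id == "end-user-42"  # handled with an upsert instead of a plain update
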
@@ -1261,13 +1297,11 @@ async def update_database(
                 )
                 raise e

-        tasks = []
-        tasks.append(_update_user_db())
-        tasks.append(_update_key_db())
-        tasks.append(_update_team_db())
-        tasks.append(_insert_spend_log_to_db())
+        asyncio.create_task(_update_user_db())
+        asyncio.create_task(_update_key_db())
+        asyncio.create_task(_update_team_db())
+        asyncio.create_task(_insert_spend_log_to_db())

-        await asyncio.gather(*tasks)
         verbose_proxy_logger.info("Successfully updated spend in all 3 tables")
     except Exception as e:
         verbose_proxy_logger.debug(
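
This hunk swaps a blocking asyncio.gather for fire-and-forget asyncio.create_task calls, so the request path no longer waits for the spend-tracking writes to finish. A self-contained sketch of the difference (stdlib asyncio only; the sleep stands in for a database write):

import asyncio

async def slow_db_write(name: str) -> None:
    await asyncio.sleep(0.1)       # stand-in for a spend-tracking DB write
    print(f"{name} write finished")

async def blocking_style() -> None:
    # old behaviour: the caller only returns once every write has completed
    await asyncio.gather(slow_db_write("user"), slow_db_write("key"))

async def fire_and_forget_style() -> None:
    # new behaviour: writes are scheduled and the caller returns immediately;
    # exceptions raised inside the tasks are no longer propagated to this caller
    asyncio.create_task(slow_db_write("user"))
    asyncio.create_task(slow_db_write("key"))

async def main() -> None:
    await blocking_style()         # pauses ~0.1s, both prints appear before returning
    await fire_and_forget_style()  # returns immediately
    await asyncio.sleep(0.2)       # keep the loop alive so the background tasks can finish

asyncio.run(main())

Note that with this change the "Successfully updated spend in all 3 tables" log line fires when the tasks are scheduled, not when the writes have actually completed.
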
@@ -1277,6 +1311,7 @@ async def update_database(
 async def update_cache(
     token,
+    user_id,
     response_cost,
 ):
     """
@@ -1284,63 +1319,131 @@ async def update_cache(
     Put any alerting logic in here.
     """
-    ### UPDATE KEY SPEND ###
-    # Fetch the existing cost for the given token
-    existing_spend_obj = await user_api_key_cache.async_get_cache(key=token)
-    verbose_proxy_logger.debug(f"_update_key_db: existing spend: {existing_spend_obj}")
-    if existing_spend_obj is None:
-        existing_spend = 0
-    else:
-        existing_spend = existing_spend_obj.spend
-    # Calculate the new cost by adding the existing cost and response_cost
-    new_spend = existing_spend + response_cost
-    ## CHECK IF USER PROJECTED SPEND > SOFT LIMIT
-    soft_budget_cooldown = existing_spend_obj.soft_budget_cooldown
-    if (
-        existing_spend_obj.soft_budget_cooldown == False
-        and existing_spend_obj.litellm_budget_table is not None
-        and (
-            _is_projected_spend_over_limit(
-                current_spend=new_spend,
-                soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget,
-            )
-            == True
-        )
-    ):
-        key_alias = existing_spend_obj.key_alias
-        projected_spend, projected_exceeded_date = _get_projected_spend_over_limit(
-            current_spend=new_spend,
-            soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget,
-        )
-        soft_limit = existing_spend_obj.litellm_budget_table.soft_budget
-        user_info = {
-            "key_alias": key_alias,
-            "projected_spend": projected_spend,
-            "projected_exceeded_date": projected_exceeded_date,
-        }
-        # alert user
-        asyncio.create_task(
-            proxy_logging_obj.budget_alerts(
-                type="projected_limit_exceeded",
-                user_info=user_info,
-                user_max_budget=soft_limit,
-                user_current_spend=new_spend,
-            )
-        )
-        # set cooldown on alert
-        soft_budget_cooldown = True
-    if existing_spend_obj is None:
-        existing_team_spend = 0
-    else:
-        existing_team_spend = existing_spend_obj.team_spend
-    # Calculate the new cost by adding the existing cost and response_cost
-    existing_spend_obj.team_spend = existing_team_spend + response_cost
-    # Update the cost column for the given token
-    existing_spend_obj.spend = new_spend
-    user_api_key_cache.set_cache(key=token, value=existing_spend_obj)
+    ### UPDATE KEY SPEND ###
+    async def _update_key_cache():
+        # Fetch the existing cost for the given token
+        if isinstance(token, str) and token.startswith("sk-"):
+            hashed_token = hash_token(token=token)
+        else:
+            hashed_token = token
+        existing_spend_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
+        verbose_proxy_logger.debug(
+            f"_update_key_db: existing spend: {existing_spend_obj}"
+        )
+        if existing_spend_obj is None:
+            existing_spend = 0
+        else:
+            existing_spend = existing_spend_obj.spend
+        # Calculate the new cost by adding the existing cost and response_cost
+        new_spend = existing_spend + response_cost
+        ## CHECK IF USER PROJECTED SPEND > SOFT LIMIT
+        soft_budget_cooldown = existing_spend_obj.soft_budget_cooldown
+        if (
+            existing_spend_obj.soft_budget_cooldown == False
+            and existing_spend_obj.litellm_budget_table is not None
+            and (
+                _is_projected_spend_over_limit(
+                    current_spend=new_spend,
+                    soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget,
+                )
+                == True
+            )
+        ):
+            key_alias = existing_spend_obj.key_alias
+            projected_spend, projected_exceeded_date = _get_projected_spend_over_limit(
+                current_spend=new_spend,
+                soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget,
+            )
+            soft_limit = existing_spend_obj.litellm_budget_table.soft_budget
+            user_info = {
+                "key_alias": key_alias,
+                "projected_spend": projected_spend,
+                "projected_exceeded_date": projected_exceeded_date,
+            }
+            # alert user
+            asyncio.create_task(
+                proxy_logging_obj.budget_alerts(
+                    type="projected_limit_exceeded",
+                    user_info=user_info,
+                    user_max_budget=soft_limit,
+                    user_current_spend=new_spend,
+                )
+            )
+            # set cooldown on alert
+            soft_budget_cooldown = True
+        if (
+            existing_spend_obj is not None
+            and getattr(existing_spend_obj, "team_spend", None) is not None
+        ):
+            existing_team_spend = existing_spend_obj.team_spend
+            # Calculate the new cost by adding the existing cost and response_cost
+            existing_spend_obj.team_spend = existing_team_spend + response_cost
+        # Update the cost column for the given token
+        existing_spend_obj.spend = new_spend
+        user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)
+
+    async def _update_user_cache():
+        ## UPDATE CACHE FOR USER ID + GLOBAL PROXY
+        end_user_id = None
+        if isinstance(token, str) and token.startswith("sk-"):
+            hashed_token = hash_token(token=token)
+        else:
+            hashed_token = token
+        existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
+        existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
+        if existing_token_obj.user_id != user_id: # an end-user id was passed in
+            end_user_id = user_id
+        user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id]
+        try:
+            for _id in user_ids:
+                # Fetch the existing cost for the given user
+                existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
+                if existing_spend_obj is None:
+                    # if user does not exist in LiteLLM_UserTable, create a new user
+                    existing_spend = 0
+                    max_user_budget = None
+                    if litellm.max_user_budget is not None:
+                        max_user_budget = litellm.max_user_budget
+                    existing_spend_obj = LiteLLM_UserTable(
+                        user_id=_id,
+                        spend=0,
+                        max_budget=max_user_budget,
+                        user_email=None,
+                    )
+                verbose_proxy_logger.debug(
+                    f"_update_user_db: existing spend: {existing_spend_obj}"
+                )
+                if existing_spend_obj is None:
+                    existing_spend = 0
+                else:
+                    if isinstance(existing_spend_obj, dict):
+                        existing_spend = existing_spend_obj["spend"]
+                    else:
+                        existing_spend = existing_spend_obj.spend
+                # Calculate the new cost by adding the existing cost and response_cost
+                new_spend = existing_spend + response_cost
+                # Update the cost column for the given user
+                if isinstance(existing_spend_obj, dict):
+                    existing_spend_obj["spend"] = new_spend
+                    user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
+                else:
+                    existing_spend_obj.spend = new_spend
+                    user_api_key_cache.set_cache(
+                        key=_id, value=existing_spend_obj.json()
+                    )
+        except Exception as e:
+            verbose_proxy_logger.debug(
+                f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
+            )
+
+    asyncio.create_task(_update_key_cache())
+    asyncio.create_task(_update_user_cache())


 def run_ollama_serve():
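
The rewritten update_cache splits the work into _update_key_cache and _update_user_cache and, in both helpers, hashes raw "sk-" keys before touching the cache, so reads and writes agree on a single cache key regardless of which form the caller passed in. A minimal sketch of that normalization rule (hashlib used as a stand-in for the proxy's hash_token helper, whose exact implementation may differ):

import hashlib

def hash_token(token: str) -> str:
    # stand-in for the proxy's hash_token helper (assumed sha256 hex digest here)
    return hashlib.sha256(token.encode()).hexdigest()

def cache_key_for(token: str) -> str:
    # mirrors the `isinstance(token, str) and token.startswith("sk-")` check in the diff:
    # raw keys are hashed, already-hashed tokens are used as-is
    return hash_token(token) if token.startswith("sk-") else token

raw_key = "sk-example-key"
assert cache_key_for(raw_key) == cache_key_for(hash_token(raw_key))  # both forms hit one entry

Without this normalization, spend written under the raw key and looked up under the hashed key would land in two different cache entries, which appears to be the key-caching mismatch the commit title refers to.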