forked from phoenix/litellm-mirror
Merge pull request #2506 from BerriAI/litellm_update_db_perf_improvements
fix(proxy_server.py): move to using UPDATE + SET for track_cost_callback
This commit is contained in:
commit
d8eff53ebe
5 changed files with 372 additions and 200 deletions
|
@ -742,6 +742,39 @@ class DualCache(BaseCache):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def async_get_cache(self, key, local_only: bool = False, **kwargs):
|
||||||
|
# Try to fetch from in-memory cache first
|
||||||
|
try:
|
||||||
|
print_verbose(
|
||||||
|
f"async get cache: cache key: {key}; local_only: {local_only}"
|
||||||
|
)
|
||||||
|
result = None
|
||||||
|
if self.in_memory_cache is not None:
|
||||||
|
in_memory_result = await self.in_memory_cache.async_get_cache(
|
||||||
|
key, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
print_verbose(f"in_memory_result: {in_memory_result}")
|
||||||
|
if in_memory_result is not None:
|
||||||
|
result = in_memory_result
|
||||||
|
|
||||||
|
if result is None and self.redis_cache is not None and local_only == False:
|
||||||
|
# If not found in in-memory cache, try fetching from Redis
|
||||||
|
redis_result = await self.redis_cache.async_get_cache(key, **kwargs)
|
||||||
|
|
||||||
|
if redis_result is not None:
|
||||||
|
# Update in-memory cache with the value from Redis
|
||||||
|
await self.in_memory_cache.async_set_cache(
|
||||||
|
key, redis_result, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
result = redis_result
|
||||||
|
|
||||||
|
print_verbose(f"get cache: cache result: {result}")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
def flush_cache(self):
|
def flush_cache(self):
|
||||||
if self.in_memory_cache is not None:
|
if self.in_memory_cache is not None:
|
||||||
self.in_memory_cache.flush_cache()
|
self.in_memory_cache.flush_cache()
|
||||||
|
|
|
@ -535,6 +535,8 @@ class LiteLLM_VerificationToken(LiteLLMBase):
|
||||||
permissions: Dict = {}
|
permissions: Dict = {}
|
||||||
model_spend: Dict = {}
|
model_spend: Dict = {}
|
||||||
model_max_budget: Dict = {}
|
model_max_budget: Dict = {}
|
||||||
|
soft_budget_cooldown: bool = False
|
||||||
|
litellm_budget_table: Optional[dict] = None
|
||||||
|
|
||||||
# hidden params used for parallel request limiting, not required to create a token
|
# hidden params used for parallel request limiting, not required to create a token
|
||||||
user_id_rate_limits: Optional[dict] = None
|
user_id_rate_limits: Optional[dict] = None
|
||||||
|
|
|
@ -395,6 +395,7 @@ async def user_api_key_auth(
|
||||||
user_api_key_cache.set_cache(
|
user_api_key_cache.set_cache(
|
||||||
key=hash_token(master_key), value=_user_api_key_obj
|
key=hash_token(master_key), value=_user_api_key_obj
|
||||||
)
|
)
|
||||||
|
|
||||||
return _user_api_key_obj
|
return _user_api_key_obj
|
||||||
|
|
||||||
if isinstance(
|
if isinstance(
|
||||||
|
@ -658,30 +659,47 @@ async def user_api_key_auth(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if valid_token.spend > valid_token.max_budget:
|
if valid_token.spend >= valid_token.max_budget:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Token: {valid_token.max_budget}"
|
f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Token: {valid_token.max_budget}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check 5. Token Model Spend is under Model budget
|
# Check 5. Token Model Spend is under Model budget
|
||||||
max_budget_per_model = valid_token.model_max_budget
|
max_budget_per_model = valid_token.model_max_budget
|
||||||
spend_per_model = valid_token.model_spend
|
|
||||||
|
|
||||||
if max_budget_per_model is not None and spend_per_model is not None:
|
|
||||||
current_model = request_data.get("model")
|
|
||||||
if current_model is not None:
|
|
||||||
current_model_spend = spend_per_model.get(current_model, None)
|
|
||||||
current_model_budget = max_budget_per_model.get(current_model, None)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
current_model_spend is not None
|
max_budget_per_model is not None
|
||||||
and current_model_budget is not None
|
and isinstance(max_budget_per_model, dict)
|
||||||
|
and len(max_budget_per_model) > 0
|
||||||
):
|
):
|
||||||
if current_model_spend > current_model_budget:
|
current_model = request_data.get("model")
|
||||||
|
## GET THE SPEND FOR THIS MODEL
|
||||||
|
twenty_eight_days_ago = datetime.now() - timedelta(days=28)
|
||||||
|
model_spend = await prisma_client.db.litellm_spendlogs.group_by(
|
||||||
|
by=["model"],
|
||||||
|
sum={"spend": True},
|
||||||
|
where={
|
||||||
|
"AND": [
|
||||||
|
{"api_key": valid_token.token},
|
||||||
|
{"startTime": {"gt": twenty_eight_days_ago}},
|
||||||
|
{"model": current_model},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
len(model_spend) > 0
|
||||||
|
and max_budget_per_model.get(current_model, None) is not None
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
model_spend[0]["model"] == current_model
|
||||||
|
and model_spend[0]["_sum"]["spend"]
|
||||||
|
>= max_budget_per_model[current_model]
|
||||||
|
):
|
||||||
|
current_model_spend = model_spend[0]["_sum"]["spend"]
|
||||||
|
current_model_budget = max_budget_per_model[current_model]
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}"
|
f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check 6. Token spend is under Team budget
|
# Check 6. Token spend is under Team budget
|
||||||
if (
|
if (
|
||||||
valid_token.spend is not None
|
valid_token.spend is not None
|
||||||
|
@ -735,15 +753,7 @@ async def user_api_key_auth(
|
||||||
|
|
||||||
This makes the user row data accessible to pre-api call hooks.
|
This makes the user row data accessible to pre-api call hooks.
|
||||||
"""
|
"""
|
||||||
if prisma_client is not None:
|
if custom_db_client is not None:
|
||||||
asyncio.create_task(
|
|
||||||
_cache_user_row(
|
|
||||||
user_id=valid_token.user_id,
|
|
||||||
cache=user_api_key_cache,
|
|
||||||
db=prisma_client,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif custom_db_client is not None:
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_cache_user_row(
|
_cache_user_row(
|
||||||
user_id=valid_token.user_id,
|
user_id=valid_token.user_id,
|
||||||
|
@ -1015,6 +1025,10 @@ async def _PROXY_track_cost_callback(
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await update_cache(
|
||||||
|
token=user_api_key, user_id=user_id, response_cost=response_cost
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception("User API key missing from custom callback.")
|
raise Exception("User API key missing from custom callback.")
|
||||||
else:
|
else:
|
||||||
|
@ -1057,26 +1071,67 @@ async def update_database(
|
||||||
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
|
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
### [TODO] STEP 1: GET KEY + USER SPEND ### (key, user)
|
|
||||||
|
|
||||||
### [TODO] STEP 2: UPDATE SPEND ### (key, user, spend logs)
|
|
||||||
|
|
||||||
### UPDATE USER SPEND ###
|
### UPDATE USER SPEND ###
|
||||||
async def _update_user_db():
|
async def _update_user_db():
|
||||||
"""
|
"""
|
||||||
- Update that user's row
|
- Update that user's row
|
||||||
- Update litellm-proxy-budget row (global proxy spend)
|
- Update litellm-proxy-budget row (global proxy spend)
|
||||||
"""
|
"""
|
||||||
user_ids = [user_id, litellm_proxy_budget_name]
|
## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db
|
||||||
|
end_user_id = None
|
||||||
|
if isinstance(token, str) and token.startswith("sk-"):
|
||||||
|
hashed_token = hash_token(token=token)
|
||||||
|
else:
|
||||||
|
hashed_token = token
|
||||||
|
existing_token_obj = await user_api_key_cache.async_get_cache(
|
||||||
|
key=hashed_token
|
||||||
|
)
|
||||||
|
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||||
|
if existing_token_obj.user_id != user_id: # an end-user id was passed in
|
||||||
|
end_user_id = user_id
|
||||||
|
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name]
|
||||||
data_list = []
|
data_list = []
|
||||||
try:
|
try:
|
||||||
|
if prisma_client is not None: # update
|
||||||
|
user_ids = [user_id, litellm_proxy_budget_name]
|
||||||
|
## do a group update for the user-id of the key + global proxy budget
|
||||||
|
await prisma_client.db.litellm_usertable.update_many(
|
||||||
|
where={"user_id": {"in": user_ids}},
|
||||||
|
data={"spend": {"increment": response_cost}},
|
||||||
|
)
|
||||||
|
if end_user_id is not None:
|
||||||
|
if existing_user_obj is None:
|
||||||
|
# if user does not exist in LiteLLM_UserTable, create a new user
|
||||||
|
existing_spend = 0
|
||||||
|
max_user_budget = None
|
||||||
|
if litellm.max_user_budget is not None:
|
||||||
|
max_user_budget = litellm.max_user_budget
|
||||||
|
existing_user_obj = LiteLLM_UserTable(
|
||||||
|
user_id=end_user_id,
|
||||||
|
spend=0,
|
||||||
|
max_budget=max_user_budget,
|
||||||
|
user_email=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
existing_user_obj.spend = (
|
||||||
|
existing_user_obj.spend + response_cost
|
||||||
|
)
|
||||||
|
|
||||||
|
await prisma_client.db.litellm_usertable.upsert(
|
||||||
|
where={"user_id": end_user_id},
|
||||||
|
data={
|
||||||
|
"create": {**existing_user_obj.json(exclude_none=True)},
|
||||||
|
"update": {"spend": {"increment": response_cost}},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
elif custom_db_client is not None:
|
||||||
for id in user_ids:
|
for id in user_ids:
|
||||||
if id is None:
|
if id is None:
|
||||||
continue
|
continue
|
||||||
if prisma_client is not None:
|
if (
|
||||||
existing_spend_obj = await prisma_client.get_data(user_id=id)
|
custom_db_client is not None
|
||||||
elif (
|
and id != litellm_proxy_budget_name
|
||||||
custom_db_client is not None and id != litellm_proxy_budget_name
|
|
||||||
):
|
):
|
||||||
existing_spend_obj = await custom_db_client.get_data(
|
existing_spend_obj = await custom_db_client.get_data(
|
||||||
key=id, table_name="user"
|
key=id, table_name="user"
|
||||||
|
@ -1127,13 +1182,8 @@ async def update_database(
|
||||||
if custom_db_client is not None and user_id is not None:
|
if custom_db_client is not None and user_id is not None:
|
||||||
new_spend = data_list[0].spend
|
new_spend = data_list[0].spend
|
||||||
await custom_db_client.update_data(
|
await custom_db_client.update_data(
|
||||||
key=user_id, value={"spend": new_spend}, table_name="user"
|
key=user_id,
|
||||||
)
|
value={"spend": new_spend},
|
||||||
# Update the cost column for the given user id
|
|
||||||
if prisma_client is not None:
|
|
||||||
await prisma_client.update_data(
|
|
||||||
data_list=data_list,
|
|
||||||
query_type="update_many",
|
|
||||||
table_name="user",
|
table_name="user",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -1148,82 +1198,10 @@ async def update_database(
|
||||||
f"adding spend to key db. Response cost: {response_cost}. Token: {token}."
|
f"adding spend to key db. Response cost: {response_cost}. Token: {token}."
|
||||||
)
|
)
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
# Fetch the existing cost for the given token
|
await prisma_client.db.litellm_verificationtoken.update(
|
||||||
existing_spend_obj = await prisma_client.get_data(token=token)
|
where={"token": token},
|
||||||
verbose_proxy_logger.debug(
|
data={"spend": {"increment": response_cost}},
|
||||||
f"_update_key_db: existing spend: {existing_spend_obj}"
|
|
||||||
)
|
)
|
||||||
if existing_spend_obj is None:
|
|
||||||
existing_spend = 0
|
|
||||||
else:
|
|
||||||
existing_spend = existing_spend_obj.spend
|
|
||||||
# Calculate the new cost by adding the existing cost and response_cost
|
|
||||||
new_spend = existing_spend + response_cost
|
|
||||||
|
|
||||||
## CHECK IF USER PROJECTED SPEND > SOFT LIMIT
|
|
||||||
soft_budget_cooldown = existing_spend_obj.soft_budget_cooldown
|
|
||||||
if (
|
|
||||||
existing_spend_obj.soft_budget_cooldown == False
|
|
||||||
and existing_spend_obj.litellm_budget_table is not None
|
|
||||||
and (
|
|
||||||
_is_projected_spend_over_limit(
|
|
||||||
current_spend=new_spend,
|
|
||||||
soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget,
|
|
||||||
)
|
|
||||||
== True
|
|
||||||
)
|
|
||||||
):
|
|
||||||
key_alias = existing_spend_obj.key_alias
|
|
||||||
projected_spend, projected_exceeded_date = (
|
|
||||||
_get_projected_spend_over_limit(
|
|
||||||
current_spend=new_spend,
|
|
||||||
soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
soft_limit = existing_spend_obj.litellm_budget_table.soft_budget
|
|
||||||
user_info = {
|
|
||||||
"key_alias": key_alias,
|
|
||||||
"projected_spend": projected_spend,
|
|
||||||
"projected_exceeded_date": projected_exceeded_date,
|
|
||||||
}
|
|
||||||
# alert user
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.budget_alerts(
|
|
||||||
type="projected_limit_exceeded",
|
|
||||||
user_info=user_info,
|
|
||||||
user_max_budget=soft_limit,
|
|
||||||
user_current_spend=new_spend,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# set cooldown on alert
|
|
||||||
soft_budget_cooldown = True
|
|
||||||
# track cost per model, for the given key
|
|
||||||
spend_per_model = existing_spend_obj.model_spend or {}
|
|
||||||
current_model = kwargs.get("model")
|
|
||||||
if current_model is not None and spend_per_model is not None:
|
|
||||||
if spend_per_model.get(current_model) is None:
|
|
||||||
spend_per_model[current_model] = response_cost
|
|
||||||
else:
|
|
||||||
spend_per_model[current_model] += response_cost
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"new cost: {new_spend}, new spend per model: {spend_per_model}"
|
|
||||||
)
|
|
||||||
# Update the cost column for the given token
|
|
||||||
await prisma_client.update_data(
|
|
||||||
token=token,
|
|
||||||
data={
|
|
||||||
"spend": new_spend,
|
|
||||||
"model_spend": spend_per_model,
|
|
||||||
"soft_budget_cooldown": soft_budget_cooldown,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_token = user_api_key_cache.get_cache(key=token)
|
|
||||||
if valid_token is not None:
|
|
||||||
valid_token.spend = new_spend
|
|
||||||
valid_token.model_spend = spend_per_model
|
|
||||||
user_api_key_cache.set_cache(key=token, value=valid_token)
|
|
||||||
elif custom_db_client is not None:
|
elif custom_db_client is not None:
|
||||||
# Fetch the existing cost for the given token
|
# Fetch the existing cost for the given token
|
||||||
existing_spend_obj = await custom_db_client.get_data(
|
existing_spend_obj = await custom_db_client.get_data(
|
||||||
|
@ -1254,6 +1232,7 @@ async def update_database(
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
f"Update Key DB Call failed to execute - {str(e)}"
|
f"Update Key DB Call failed to execute - {str(e)}"
|
||||||
)
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
### UPDATE SPEND LOGS ###
|
### UPDATE SPEND LOGS ###
|
||||||
async def _insert_spend_log_to_db():
|
async def _insert_spend_log_to_db():
|
||||||
|
@ -1277,6 +1256,7 @@ async def update_database(
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
f"Update Spend Logs DB failed to execute - {str(e)}"
|
f"Update Spend Logs DB failed to execute - {str(e)}"
|
||||||
)
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
### UPDATE KEY SPEND ###
|
### UPDATE KEY SPEND ###
|
||||||
async def _update_team_db():
|
async def _update_team_db():
|
||||||
|
@ -1290,41 +1270,10 @@ async def update_database(
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
# Fetch the existing cost for the given token
|
await prisma_client.db.litellm_teamtable.update(
|
||||||
existing_spend_obj = await prisma_client.get_data(
|
where={"team_id": team_id},
|
||||||
team_id=team_id, table_name="team"
|
data={"spend": {"increment": response_cost}},
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"_update_team_db: existing spend: {existing_spend_obj}"
|
|
||||||
)
|
|
||||||
if existing_spend_obj is None:
|
|
||||||
# the team does not exist in the db - return
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
"team_id does not exist in db, not tracking spend for team"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
existing_spend = existing_spend_obj.spend
|
|
||||||
# Calculate the new cost by adding the existing cost and response_cost
|
|
||||||
new_spend = existing_spend + response_cost
|
|
||||||
spend_per_model = getattr(existing_spend_obj, "model_spend", {})
|
|
||||||
# track cost per model, for the given team
|
|
||||||
spend_per_model = existing_spend_obj.model_spend or {}
|
|
||||||
current_model = kwargs.get("model")
|
|
||||||
if current_model is not None and spend_per_model is not None:
|
|
||||||
if spend_per_model.get(current_model) is None:
|
|
||||||
spend_per_model[current_model] = response_cost
|
|
||||||
else:
|
|
||||||
spend_per_model[current_model] += response_cost
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
|
||||||
# Update the cost column for the given token
|
|
||||||
await prisma_client.update_data(
|
|
||||||
team_id=team_id,
|
|
||||||
data={"spend": new_spend, "model_spend": spend_per_model},
|
|
||||||
table_name="team",
|
|
||||||
)
|
|
||||||
|
|
||||||
elif custom_db_client is not None:
|
elif custom_db_client is not None:
|
||||||
# Fetch the existing cost for the given token
|
# Fetch the existing cost for the given token
|
||||||
existing_spend_obj = await custom_db_client.get_data(
|
existing_spend_obj = await custom_db_client.get_data(
|
||||||
|
@ -1354,17 +1303,155 @@ async def update_database(
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
f"Update Team DB failed to execute - {str(e)}"
|
f"Update Team DB failed to execute - {str(e)}"
|
||||||
)
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
asyncio.create_task(_update_user_db())
|
asyncio.create_task(_update_user_db())
|
||||||
asyncio.create_task(_update_key_db())
|
asyncio.create_task(_update_key_db())
|
||||||
asyncio.create_task(_update_team_db())
|
asyncio.create_task(_update_team_db())
|
||||||
asyncio.create_task(_insert_spend_log_to_db())
|
asyncio.create_task(_insert_spend_log_to_db())
|
||||||
|
|
||||||
verbose_proxy_logger.info("Successfully updated spend in all 3 tables")
|
verbose_proxy_logger.info("Successfully updated spend in all 3 tables")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"Error updating Prisma database: {traceback.format_exc()}"
|
f"Error updating Prisma database: {traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
pass
|
|
||||||
|
|
||||||
|
async def update_cache(
|
||||||
|
token,
|
||||||
|
user_id,
|
||||||
|
response_cost,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Use this to update the cache with new user spend.
|
||||||
|
|
||||||
|
Put any alerting logic in here.
|
||||||
|
"""
|
||||||
|
|
||||||
|
### UPDATE KEY SPEND ###
|
||||||
|
async def _update_key_cache():
|
||||||
|
# Fetch the existing cost for the given token
|
||||||
|
if isinstance(token, str) and token.startswith("sk-"):
|
||||||
|
hashed_token = hash_token(token=token)
|
||||||
|
else:
|
||||||
|
hashed_token = token
|
||||||
|
existing_spend_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"_update_key_db: existing spend: {existing_spend_obj}"
|
||||||
|
)
|
||||||
|
if existing_spend_obj is None:
|
||||||
|
existing_spend = 0
|
||||||
|
else:
|
||||||
|
existing_spend = existing_spend_obj.spend
|
||||||
|
# Calculate the new cost by adding the existing cost and response_cost
|
||||||
|
new_spend = existing_spend + response_cost
|
||||||
|
|
||||||
|
## CHECK IF USER PROJECTED SPEND > SOFT LIMIT
|
||||||
|
soft_budget_cooldown = existing_spend_obj.soft_budget_cooldown
|
||||||
|
if (
|
||||||
|
existing_spend_obj.soft_budget_cooldown == False
|
||||||
|
and existing_spend_obj.litellm_budget_table is not None
|
||||||
|
and (
|
||||||
|
_is_projected_spend_over_limit(
|
||||||
|
current_spend=new_spend,
|
||||||
|
soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget,
|
||||||
|
)
|
||||||
|
== True
|
||||||
|
)
|
||||||
|
):
|
||||||
|
key_alias = existing_spend_obj.key_alias
|
||||||
|
projected_spend, projected_exceeded_date = _get_projected_spend_over_limit(
|
||||||
|
current_spend=new_spend,
|
||||||
|
soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget,
|
||||||
|
)
|
||||||
|
soft_limit = existing_spend_obj.litellm_budget_table.soft_budget
|
||||||
|
user_info = {
|
||||||
|
"key_alias": key_alias,
|
||||||
|
"projected_spend": projected_spend,
|
||||||
|
"projected_exceeded_date": projected_exceeded_date,
|
||||||
|
}
|
||||||
|
# alert user
|
||||||
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.budget_alerts(
|
||||||
|
type="projected_limit_exceeded",
|
||||||
|
user_info=user_info,
|
||||||
|
user_max_budget=soft_limit,
|
||||||
|
user_current_spend=new_spend,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# set cooldown on alert
|
||||||
|
soft_budget_cooldown = True
|
||||||
|
|
||||||
|
if (
|
||||||
|
existing_spend_obj is not None
|
||||||
|
and getattr(existing_spend_obj, "team_spend", None) is not None
|
||||||
|
):
|
||||||
|
existing_team_spend = existing_spend_obj.team_spend
|
||||||
|
# Calculate the new cost by adding the existing cost and response_cost
|
||||||
|
existing_spend_obj.team_spend = existing_team_spend + response_cost
|
||||||
|
|
||||||
|
# Update the cost column for the given token
|
||||||
|
existing_spend_obj.spend = new_spend
|
||||||
|
user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)
|
||||||
|
|
||||||
|
async def _update_user_cache():
|
||||||
|
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
|
||||||
|
end_user_id = None
|
||||||
|
if isinstance(token, str) and token.startswith("sk-"):
|
||||||
|
hashed_token = hash_token(token=token)
|
||||||
|
else:
|
||||||
|
hashed_token = token
|
||||||
|
existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
|
||||||
|
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||||
|
if existing_token_obj.user_id != user_id: # an end-user id was passed in
|
||||||
|
end_user_id = user_id
|
||||||
|
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id]
|
||||||
|
|
||||||
|
try:
|
||||||
|
for _id in user_ids:
|
||||||
|
# Fetch the existing cost for the given user
|
||||||
|
existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
|
||||||
|
if existing_spend_obj is None:
|
||||||
|
# if user does not exist in LiteLLM_UserTable, create a new user
|
||||||
|
existing_spend = 0
|
||||||
|
max_user_budget = None
|
||||||
|
if litellm.max_user_budget is not None:
|
||||||
|
max_user_budget = litellm.max_user_budget
|
||||||
|
existing_spend_obj = LiteLLM_UserTable(
|
||||||
|
user_id=_id,
|
||||||
|
spend=0,
|
||||||
|
max_budget=max_user_budget,
|
||||||
|
user_email=None,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"_update_user_db: existing spend: {existing_spend_obj}"
|
||||||
|
)
|
||||||
|
if existing_spend_obj is None:
|
||||||
|
existing_spend = 0
|
||||||
|
else:
|
||||||
|
if isinstance(existing_spend_obj, dict):
|
||||||
|
existing_spend = existing_spend_obj["spend"]
|
||||||
|
else:
|
||||||
|
existing_spend = existing_spend_obj.spend
|
||||||
|
# Calculate the new cost by adding the existing cost and response_cost
|
||||||
|
new_spend = existing_spend + response_cost
|
||||||
|
|
||||||
|
# Update the cost column for the given user
|
||||||
|
if isinstance(existing_spend_obj, dict):
|
||||||
|
existing_spend_obj["spend"] = new_spend
|
||||||
|
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
|
||||||
|
else:
|
||||||
|
existing_spend_obj.spend = new_spend
|
||||||
|
user_api_key_cache.set_cache(
|
||||||
|
key=_id, value=existing_spend_obj.json()
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.create_task(_update_key_cache())
|
||||||
|
asyncio.create_task(_update_user_cache())
|
||||||
|
|
||||||
|
|
||||||
def run_ollama_serve():
|
def run_ollama_serve():
|
||||||
|
@ -7242,6 +7329,55 @@ async def get_routes():
|
||||||
return {"routes": routes}
|
return {"routes": routes}
|
||||||
|
|
||||||
|
|
||||||
|
## TEST ENDPOINT
|
||||||
|
# @router.post("/update_database", dependencies=[Depends(user_api_key_auth)])
|
||||||
|
# async def update_database_endpoint(
|
||||||
|
# user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
# ):
|
||||||
|
# """
|
||||||
|
# Test endpoint. DO NOT MERGE IN PROD.
|
||||||
|
|
||||||
|
# Used for isolating and testing our prisma db update logic in high-traffic.
|
||||||
|
# """
|
||||||
|
# try:
|
||||||
|
# request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}"
|
||||||
|
# resp = litellm.ModelResponse(
|
||||||
|
# id=request_id,
|
||||||
|
# choices=[
|
||||||
|
# litellm.Choices(
|
||||||
|
# finish_reason=None,
|
||||||
|
# index=0,
|
||||||
|
# message=litellm.Message(
|
||||||
|
# content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
|
||||||
|
# role="assistant",
|
||||||
|
# ),
|
||||||
|
# )
|
||||||
|
# ],
|
||||||
|
# model="gpt-35-turbo", # azure always has model written like this
|
||||||
|
# usage=litellm.Usage(
|
||||||
|
# prompt_tokens=210, completion_tokens=200, total_tokens=410
|
||||||
|
# ),
|
||||||
|
# )
|
||||||
|
# await _PROXY_track_cost_callback(
|
||||||
|
# kwargs={
|
||||||
|
# "model": "chatgpt-v-2",
|
||||||
|
# "stream": False,
|
||||||
|
# "litellm_params": {
|
||||||
|
# "metadata": {
|
||||||
|
# "user_api_key": user_api_key_dict.token,
|
||||||
|
# "user_api_key_user_id": user_api_key_dict.user_id,
|
||||||
|
# }
|
||||||
|
# },
|
||||||
|
# "response_cost": 0.00002,
|
||||||
|
# },
|
||||||
|
# completion_response=resp,
|
||||||
|
# start_time=datetime.now(),
|
||||||
|
# end_time=datetime.now(),
|
||||||
|
# )
|
||||||
|
# except Exception as e:
|
||||||
|
# raise e
|
||||||
|
|
||||||
|
|
||||||
def _has_user_setup_sso():
|
def _has_user_setup_sso():
|
||||||
"""
|
"""
|
||||||
Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID, and UI username environment variables.
|
Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID, and UI username environment variables.
|
||||||
|
|
|
@ -1596,7 +1596,6 @@ async def _cache_user_row(
|
||||||
Check if a user_id exists in cache,
|
Check if a user_id exists in cache,
|
||||||
if not retrieve it.
|
if not retrieve it.
|
||||||
"""
|
"""
|
||||||
print_verbose(f"Prisma: _cache_user_row, user_id: {user_id}")
|
|
||||||
cache_key = f"{user_id}_user_api_key_user_id"
|
cache_key = f"{user_id}_user_api_key_user_id"
|
||||||
response = cache.get_cache(key=cache_key)
|
response = cache.get_cache(key=cache_key)
|
||||||
if response is None: # Cache miss
|
if response is None: # Cache miss
|
||||||
|
|
|
@ -318,7 +318,7 @@ def test_call_with_user_over_budget(prisma_client):
|
||||||
|
|
||||||
|
|
||||||
def test_call_with_end_user_over_budget(prisma_client):
|
def test_call_with_end_user_over_budget(prisma_client):
|
||||||
# Test if a user passed to /chat/completions is tracked & fails whe they cross their budget
|
# Test if a user passed to /chat/completions is tracked & fails when they cross their budget
|
||||||
# we only check this when litellm.max_user_budget is set
|
# we only check this when litellm.max_user_budget is set
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
@ -339,6 +339,8 @@ def test_call_with_end_user_over_budget(prisma_client):
|
||||||
request = Request(scope={"type": "http"})
|
request = Request(scope={"type": "http"})
|
||||||
request._url = URL(url="/chat/completions")
|
request._url = URL(url="/chat/completions")
|
||||||
|
|
||||||
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
|
||||||
async def return_body():
|
async def return_body():
|
||||||
return_string = f'{{"model": "gemini-pro-vision", "user": "{user}"}}'
|
return_string = f'{{"model": "gemini-pro-vision", "user": "{user}"}}'
|
||||||
# return string as bytes
|
# return string as bytes
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue