forked from phoenix/litellm-mirror
fix(proxy_server.py): move to using UPDATE + SET for track_cost_callback
parent d82be720d2
commit cf090acb25

1 changed file with 197 additions and 178 deletions
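The change this commit makes, in short: instead of reading a row's current spend, adding the response cost in Python, and writing the total back, the proxy now issues a single UPDATE and lets the database do the addition. A minimal sketch of the before/after pattern, assuming the Prisma-style async client used in the hunks below (table, field, and wrapper-method names are taken from the diff; this is not the exact proxy code):

async def add_user_spend_old(prisma_client, user_id, response_cost):
    # Before: read-modify-write. The fetched value can be stale under concurrent requests.
    existing = await prisma_client.get_data(user_id=user_id)
    existing.spend = existing.spend + response_cost
    await prisma_client.update_data(
        data_list=[existing], query_type="update_many", table_name="user"
    )

async def add_user_spend_new(prisma_client, user_ids, response_cost):
    # After: one atomic UPDATE ... SET spend = spend + cost, performed by the database.
    await prisma_client.db.litellm_usertable.update(
        where={"user_id": {"in": user_ids}},
        data={"spend": {"increment": response_cost}},
    )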
@@ -387,6 +387,7 @@ async def user_api_key_auth(
        user_api_key_cache.set_cache(
            key=hash_token(master_key), value=_user_api_key_obj
        )

        return _user_api_key_obj

    if isinstance(
@@ -1007,6 +1008,8 @@ async def _PROXY_track_cost_callback(
                    start_time=start_time,
                    end_time=end_time,
                )

                await update_cache(token=user_api_key, response_cost=response_cost)
            else:
                raise Exception("User API key missing from custom callback.")
        else:
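In this callback, the added line slots in right after the database write: spend is persisted first, then the in-memory key cache is refreshed so later auth checks see it. A rough sketch of that ordering (argument lists abbreviated; the real update_database call also passes user, team, request metadata, and the start/end times shown above):

# Sketch of the callback tail only, not the full argument lists.
await update_database(token=user_api_key, response_cost=response_cost)  # durable write first
await update_cache(token=user_api_key, response_cost=response_cost)     # then refresh the cached spend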
@@ -1049,10 +1052,6 @@ async def update_database(
        f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
    )

    ### [TODO] STEP 1: GET KEY + USER SPEND ### (key, user)

    ### [TODO] STEP 2: UPDATE SPEND ### (key, user, spend logs)

    ### UPDATE USER SPEND ###
    async def _update_user_db():
        """
@@ -1062,72 +1061,73 @@ async def update_database(
        user_ids = [user_id, litellm_proxy_budget_name]
        data_list = []
        try:
            for id in user_ids:
                if id is None:
                    continue
                if prisma_client is not None:
                    existing_spend_obj = await prisma_client.get_data(user_id=id)
                elif (
                    custom_db_client is not None and id != litellm_proxy_budget_name
                ):
                    existing_spend_obj = await custom_db_client.get_data(
                        key=id, table_name="user"
                    )
                verbose_proxy_logger.debug(
                    f"Updating existing_spend_obj: {existing_spend_obj}"
            if prisma_client is not None:  # update
                user_ids = [user_id, litellm_proxy_budget_name]
                await prisma_client.db.litellm_usertable.update(
                    where={"user_id": {"in": user_ids}},
                    data={"spend": {"increment": response_cost}},
                )
                if existing_spend_obj is None:
                    # if user does not exist in LiteLLM_UserTable, create a new user
                    existing_spend = 0
                    max_user_budget = None
                    if litellm.max_user_budget is not None:
                        max_user_budget = litellm.max_user_budget
                    existing_spend_obj = LiteLLM_UserTable(
                        user_id=id,
                        spend=0,
                        max_budget=max_user_budget,
                        user_email=None,
            elif custom_db_client is not None:
                for id in user_ids:
                    if id is None:
                        continue
                    if (
                        custom_db_client is not None
                        and id != litellm_proxy_budget_name
                    ):
                        existing_spend_obj = await custom_db_client.get_data(
                            key=id, table_name="user"
                        )
                    verbose_proxy_logger.debug(
                        f"Updating existing_spend_obj: {existing_spend_obj}"
                    )
                else:
                    existing_spend = existing_spend_obj.spend

                # Calculate the new cost by adding the existing cost and response_cost
                existing_spend_obj.spend = existing_spend + response_cost

                # track cost per model, for the given user
                spend_per_model = existing_spend_obj.model_spend or {}
                current_model = kwargs.get("model")

                if current_model is not None and spend_per_model is not None:
                    if spend_per_model.get(current_model) is None:
                        spend_per_model[current_model] = response_cost
                    if existing_spend_obj is None:
                        # if user does not exist in LiteLLM_UserTable, create a new user
                        existing_spend = 0
                        max_user_budget = None
                        if litellm.max_user_budget is not None:
                            max_user_budget = litellm.max_user_budget
                        existing_spend_obj = LiteLLM_UserTable(
                            user_id=id,
                            spend=0,
                            max_budget=max_user_budget,
                            user_email=None,
                        )
                    else:
                        spend_per_model[current_model] += response_cost
                    existing_spend_obj.model_spend = spend_per_model
                        existing_spend = existing_spend_obj.spend

                valid_token = user_api_key_cache.get_cache(key=id)
                if valid_token is not None and isinstance(valid_token, dict):
                    user_api_key_cache.set_cache(
                        key=id, value=existing_spend_obj.json()
                    # Calculate the new cost by adding the existing cost and response_cost
                    existing_spend_obj.spend = existing_spend + response_cost

                    # track cost per model, for the given user
                    spend_per_model = existing_spend_obj.model_spend or {}
                    current_model = kwargs.get("model")

                    if current_model is not None and spend_per_model is not None:
                        if spend_per_model.get(current_model) is None:
                            spend_per_model[current_model] = response_cost
                        else:
                            spend_per_model[current_model] += response_cost
                        existing_spend_obj.model_spend = spend_per_model

                    valid_token = user_api_key_cache.get_cache(key=id)
                    if valid_token is not None and isinstance(valid_token, dict):
                        user_api_key_cache.set_cache(
                            key=id, value=existing_spend_obj.json()
                        )

                    verbose_proxy_logger.debug(
                        f"user - new cost: {existing_spend_obj.spend}, user_id: {id}"
                    )
                    data_list.append(existing_spend_obj)

                verbose_proxy_logger.debug(
                    f"user - new cost: {existing_spend_obj.spend}, user_id: {id}"
                )
                data_list.append(existing_spend_obj)

                if custom_db_client is not None and user_id is not None:
                    new_spend = data_list[0].spend
                    await custom_db_client.update_data(
                        key=user_id, value={"spend": new_spend}, table_name="user"
                    )
            # Update the cost column for the given user id
            if prisma_client is not None:
                await prisma_client.update_data(
                    data_list=data_list,
                    query_type="update_many",
                    table_name="user",
                )
            if custom_db_client is not None and user_id is not None:
                new_spend = data_list[0].spend
                await custom_db_client.update_data(
                    key=user_id,
                    value={"spend": new_spend},
                    table_name="user",
                )
        except Exception as e:
            verbose_proxy_logger.info(
                f"Update User DB call failed to execute {str(e)}"
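Why the read-modify-write pattern being removed above needed replacing: when two requests for the same user land concurrently, both can read the same starting spend, and the second write silently overwrites the first. A self-contained toy illustration of that lost update (plain asyncio with a dict standing in for the database; not proxy code):

import asyncio

store = {"spend": 0.0}

async def read_modify_write(cost: float) -> None:
    current = store["spend"]          # read
    await asyncio.sleep(0)            # yield, as an awaited DB call would
    store["spend"] = current + cost   # write back a now-stale total

async def main() -> None:
    await asyncio.gather(read_modify_write(1.0), read_modify_write(1.0))
    print(store["spend"])             # 1.0 -- one update is lost; an atomic increment would give 2.0

asyncio.run(main())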
@@ -1140,82 +1140,10 @@ async def update_database(
            f"adding spend to key db. Response cost: {response_cost}. Token: {token}."
        )
        if prisma_client is not None:
            # Fetch the existing cost for the given token
            existing_spend_obj = await prisma_client.get_data(token=token)
            verbose_proxy_logger.debug(
                f"_update_key_db: existing spend: {existing_spend_obj}"
            await prisma_client.db.litellm_verificationtoken.update(
                where={"token": token},
                data={"spend": {"increment": response_cost}},
            )
            if existing_spend_obj is None:
                existing_spend = 0
            else:
                existing_spend = existing_spend_obj.spend
            # Calculate the new cost by adding the existing cost and response_cost
            new_spend = existing_spend + response_cost

            ## CHECK IF USER PROJECTED SPEND > SOFT LIMIT
            soft_budget_cooldown = existing_spend_obj.soft_budget_cooldown
            if (
                existing_spend_obj.soft_budget_cooldown == False
                and existing_spend_obj.litellm_budget_table is not None
                and (
                    _is_projected_spend_over_limit(
                        current_spend=new_spend,
                        soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget,
                    )
                    == True
                )
            ):
                key_alias = existing_spend_obj.key_alias
                projected_spend, projected_exceeded_date = (
                    _get_projected_spend_over_limit(
                        current_spend=new_spend,
                        soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget,
                    )
                )
                soft_limit = existing_spend_obj.litellm_budget_table.soft_budget
                user_info = {
                    "key_alias": key_alias,
                    "projected_spend": projected_spend,
                    "projected_exceeded_date": projected_exceeded_date,
                }
                # alert user
                asyncio.create_task(
                    proxy_logging_obj.budget_alerts(
                        type="projected_limit_exceeded",
                        user_info=user_info,
                        user_max_budget=soft_limit,
                        user_current_spend=new_spend,
                    )
                )
                # set cooldown on alert
                soft_budget_cooldown = True
            # track cost per model, for the given key
            spend_per_model = existing_spend_obj.model_spend or {}
            current_model = kwargs.get("model")
            if current_model is not None and spend_per_model is not None:
                if spend_per_model.get(current_model) is None:
                    spend_per_model[current_model] = response_cost
                else:
                    spend_per_model[current_model] += response_cost

            verbose_proxy_logger.debug(
                f"new cost: {new_spend}, new spend per model: {spend_per_model}"
            )
            # Update the cost column for the given token
            await prisma_client.update_data(
                token=token,
                data={
                    "spend": new_spend,
                    "model_spend": spend_per_model,
                    "soft_budget_cooldown": soft_budget_cooldown,
                },
            )

            valid_token = user_api_key_cache.get_cache(key=token)
            if valid_token is not None:
                valid_token.spend = new_spend
                valid_token.model_spend = spend_per_model
                user_api_key_cache.set_cache(key=token, value=valid_token)
        elif custom_db_client is not None:
            # Fetch the existing cost for the given token
            existing_spend_obj = await custom_db_client.get_data(
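The per-key path gets the same treatment, but on a single row keyed by the hashed token, so the projected-spend / soft-budget bookkeeping deleted above no longer has to sit inside the database write (it reappears later in update_cache). A minimal sketch of the new per-key write, mirroring the hunk; the SQL shown in the comment is an assumption for illustration, not copied from the proxy:

async def add_key_spend(prisma_client, token: str, response_cost: float) -> None:
    # One atomic write on the verification-token row; no prior SELECT needed.
    # Roughly: UPDATE "LiteLLM_VerificationToken" SET spend = spend + <cost> WHERE token = <token>
    await prisma_client.db.litellm_verificationtoken.update(
        where={"token": token},
        data={"spend": {"increment": response_cost}},
    )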
@@ -1246,6 +1174,7 @@ async def update_database(
            verbose_proxy_logger.info(
                f"Update Key DB Call failed to execute - {str(e)}"
            )
            raise e

    ### UPDATE SPEND LOGS ###
    async def _insert_spend_log_to_db():

@@ -1269,6 +1198,7 @@ async def update_database(
            verbose_proxy_logger.info(
                f"Update Spend Logs DB failed to execute - {str(e)}"
            )
            raise e

    ### UPDATE KEY SPEND ###
    async def _update_team_db():
@@ -1282,41 +1212,10 @@ async def update_database(
            )
            return
        if prisma_client is not None:
            # Fetch the existing cost for the given token
            existing_spend_obj = await prisma_client.get_data(
                team_id=team_id, table_name="team"
            await prisma_client.db.litellm_teamtable.update(
                where={"team_id": team_id},
                data={"spend": {"increment": response_cost}},
            )
            verbose_proxy_logger.debug(
                f"_update_team_db: existing spend: {existing_spend_obj}"
            )
            if existing_spend_obj is None:
                # the team does not exist in the db - return
                verbose_proxy_logger.debug(
                    "team_id does not exist in db, not tracking spend for team"
                )
                return
            else:
                existing_spend = existing_spend_obj.spend
            # Calculate the new cost by adding the existing cost and response_cost
            new_spend = existing_spend + response_cost
            spend_per_model = getattr(existing_spend_obj, "model_spend", {})
            # track cost per model, for the given team
            spend_per_model = existing_spend_obj.model_spend or {}
            current_model = kwargs.get("model")
            if current_model is not None and spend_per_model is not None:
                if spend_per_model.get(current_model) is None:
                    spend_per_model[current_model] = response_cost
                else:
                    spend_per_model[current_model] += response_cost

            verbose_proxy_logger.debug(f"new cost: {new_spend}")
            # Update the cost column for the given token
            await prisma_client.update_data(
                team_id=team_id,
                data={"spend": new_spend, "model_spend": spend_per_model},
                table_name="team",
            )

        elif custom_db_client is not None:
            # Fetch the existing cost for the given token
            existing_spend_obj = await custom_db_client.get_data(
@@ -1346,17 +1245,88 @@ async def update_database(
            verbose_proxy_logger.info(
                f"Update Team DB failed to execute - {str(e)}"
            )
            raise e

        asyncio.create_task(_update_user_db())
        asyncio.create_task(_update_key_db())
        asyncio.create_task(_update_team_db())
        asyncio.create_task(_insert_spend_log_to_db())
        tasks = []
        tasks.append(_update_user_db())
        tasks.append(_update_key_db())
        tasks.append(_update_team_db())
        tasks.append(_insert_spend_log_to_db())

        await asyncio.gather(*tasks)
        verbose_proxy_logger.info("Successfully updated spend in all 3 tables")
    except Exception as e:
        verbose_proxy_logger.debug(
            f"Error updating Prisma database: {traceback.format_exc()}"
        )
        pass
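The four fire-and-forget create_task calls above are replaced by collecting the coroutines and awaiting asyncio.gather, so update_database only returns (and logs success) after every write has finished, and a failure propagates instead of dying silently in a background task. A small self-contained sketch of the difference, using toy coroutines rather than the proxy helpers:

import asyncio

async def write(name: str) -> None:
    await asyncio.sleep(0)  # stand-in for a DB write
    print(f"{name} done")

async def fire_and_forget() -> None:
    # Old shape (shown for contrast): tasks are scheduled but never awaited here,
    # so errors surface, if at all, long after this function has returned.
    asyncio.create_task(write("user"))
    asyncio.create_task(write("key"))

async def awaited() -> None:
    # New shape: gather waits for every write and re-raises the first exception.
    await asyncio.gather(write("user"), write("key"), write("team"), write("spend_logs"))
    print("Successfully updated spend in all tables")

asyncio.run(awaited())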
async def update_cache(
    token,
    response_cost,
):
    """
    Use this to update the cache with new user spend.

    Put any alerting logic in here.
    """
    ### UPDATE KEY SPEND ###
    # Fetch the existing cost for the given token
    existing_spend_obj = await user_api_key_cache.async_get_cache(key=token)
    verbose_proxy_logger.debug(f"_update_key_db: existing spend: {existing_spend_obj}")
    if existing_spend_obj is None:
        existing_spend = 0
    else:
        existing_spend = existing_spend_obj.spend
    # Calculate the new cost by adding the existing cost and response_cost
    new_spend = existing_spend + response_cost

    ## CHECK IF USER PROJECTED SPEND > SOFT LIMIT
    soft_budget_cooldown = existing_spend_obj.soft_budget_cooldown
    if (
        existing_spend_obj.soft_budget_cooldown == False
        and existing_spend_obj.litellm_budget_table is not None
        and (
            _is_projected_spend_over_limit(
                current_spend=new_spend,
                soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget,
            )
            == True
        )
    ):
        key_alias = existing_spend_obj.key_alias
        projected_spend, projected_exceeded_date = _get_projected_spend_over_limit(
            current_spend=new_spend,
            soft_budget_limit=existing_spend_obj.litellm_budget_table.soft_budget,
        )
        soft_limit = existing_spend_obj.litellm_budget_table.soft_budget
        user_info = {
            "key_alias": key_alias,
            "projected_spend": projected_spend,
            "projected_exceeded_date": projected_exceeded_date,
        }
        # alert user
        asyncio.create_task(
            proxy_logging_obj.budget_alerts(
                type="projected_limit_exceeded",
                user_info=user_info,
                user_max_budget=soft_limit,
                user_current_spend=new_spend,
            )
        )
        # set cooldown on alert
        soft_budget_cooldown = True

    if existing_spend_obj is None:
        existing_team_spend = 0
    else:
        existing_team_spend = existing_spend_obj.team_spend
    # Calculate the new cost by adding the existing cost and response_cost
    existing_spend_obj.team_spend = existing_team_spend + response_cost

    # Update the cost column for the given token
    existing_spend_obj.spend = new_spend
    user_api_key_cache.set_cache(key=token, value=existing_spend_obj)

def run_ollama_serve():
@@ -7238,6 +7208,55 @@ async def get_routes():
    return {"routes": routes}


## TEST ENDPOINT
# @router.post("/update_database", dependencies=[Depends(user_api_key_auth)])
# async def update_database_endpoint(
#     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
# ):
#     """
#     Test endpoint. DO NOT MERGE IN PROD.
#
#     Used for isolating and testing our prisma db update logic in high-traffic.
#     """
#     try:
#         request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}"
#         resp = litellm.ModelResponse(
#             id=request_id,
#             choices=[
#                 litellm.Choices(
#                     finish_reason=None,
#                     index=0,
#                     message=litellm.Message(
#                         content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
#                         role="assistant",
#                     ),
#                 )
#             ],
#             model="gpt-35-turbo",  # azure always has model written like this
#             usage=litellm.Usage(
#                 prompt_tokens=210, completion_tokens=200, total_tokens=410
#             ),
#         )
#         await _PROXY_track_cost_callback(
#             kwargs={
#                 "model": "chatgpt-v-2",
#                 "stream": False,
#                 "litellm_params": {
#                     "metadata": {
#                         "user_api_key": user_api_key_dict.token,
#                         "user_api_key_user_id": user_api_key_dict.user_id,
#                     }
#                 },
#                 "response_cost": 0.00002,
#             },
#             completion_response=resp,
#             start_time=datetime.now(),
#             end_time=datetime.now(),
#         )
#     except Exception as e:
#         raise e


def _has_user_setup_sso():
    """
    Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID, and UI username environment variables.