forked from phoenix/litellm-mirror
Merge pull request #2561 from BerriAI/litellm_batch_writing_db
fix(proxy/utils.py): move to batch writing db updates
This commit is contained in:
commit
c4dbd0407e
9 changed files with 352 additions and 101 deletions
|
@ -96,6 +96,8 @@ from litellm.proxy.utils import (
|
|||
_is_user_proxy_admin,
|
||||
_is_projected_spend_over_limit,
|
||||
_get_projected_spend_over_limit,
|
||||
update_spend,
|
||||
monitor_spend_list,
|
||||
)
|
||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||
|
@ -277,6 +279,7 @@ litellm_proxy_admin_name = "default_user_id"
|
|||
ui_access_mode: Literal["admin", "all"] = "all"
|
||||
proxy_budget_rescheduler_min_time = 597
|
||||
proxy_budget_rescheduler_max_time = 605
|
||||
proxy_batch_write_at = 60 # in seconds
|
||||
litellm_master_key_hash = None
|
||||
### INITIALIZE GLOBAL LOGGING OBJECT ###
|
||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
|
||||
|
@ -995,10 +998,8 @@ async def _PROXY_track_cost_callback(
|
|||
)
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
user_id = user_id or kwargs["litellm_params"]["metadata"].get(
|
||||
"user_api_key_user_id", None
|
||||
)
|
||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
|
||||
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
|
||||
if kwargs.get("response_cost", None) is not None:
|
||||
response_cost = kwargs["response_cost"]
|
||||
|
@ -1012,9 +1013,6 @@ async def _PROXY_track_cost_callback(
|
|||
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
f"response_cost {response_cost}, for user_id {user_id}"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}"
|
||||
)
|
||||
|
@ -1024,6 +1022,7 @@ async def _PROXY_track_cost_callback(
|
|||
token=user_api_key,
|
||||
response_cost=response_cost,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
team_id=team_id,
|
||||
kwargs=kwargs,
|
||||
completion_response=completion_response,
|
||||
|
@ -1065,6 +1064,7 @@ async def update_database(
|
|||
token,
|
||||
response_cost,
|
||||
user_id=None,
|
||||
end_user_id=None,
|
||||
team_id=None,
|
||||
kwargs=None,
|
||||
completion_response=None,
|
||||
|
@ -1075,6 +1075,10 @@ async def update_database(
|
|||
verbose_proxy_logger.info(
|
||||
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
|
||||
)
|
||||
if isinstance(token, str) and token.startswith("sk-"):
|
||||
hashed_token = hash_token(token=token)
|
||||
else:
|
||||
hashed_token = token
|
||||
|
||||
### UPDATE USER SPEND ###
|
||||
async def _update_user_db():
|
||||
|
@ -1083,11 +1087,6 @@ async def update_database(
|
|||
- Update litellm-proxy-budget row (global proxy spend)
|
||||
"""
|
||||
## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db
|
||||
end_user_id = None
|
||||
if isinstance(token, str) and token.startswith("sk-"):
|
||||
hashed_token = hash_token(token=token)
|
||||
else:
|
||||
hashed_token = token
|
||||
existing_token_obj = await user_api_key_cache.async_get_cache(
|
||||
key=hashed_token
|
||||
)
|
||||
|
@ -1096,54 +1095,25 @@ async def update_database(
|
|||
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||
if existing_user_obj is not None and isinstance(existing_user_obj, dict):
|
||||
existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
|
||||
if existing_token_obj.user_id != user_id: # an end-user id was passed in
|
||||
end_user_id = user_id
|
||||
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name]
|
||||
data_list = []
|
||||
try:
|
||||
if prisma_client is not None: # update
|
||||
user_ids = [user_id, litellm_proxy_budget_name]
|
||||
## do a group update for the user-id of the key + global proxy budget
|
||||
await prisma_client.db.litellm_usertable.update_many(
|
||||
where={"user_id": {"in": user_ids}},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
user_ids = [user_id]
|
||||
if (
|
||||
litellm.max_budget > 0
|
||||
): # track global proxy budget, if user set max budget
|
||||
user_ids.append(litellm_proxy_budget_name)
|
||||
### KEY CHANGE ###
|
||||
for _id in user_ids:
|
||||
prisma_client.user_list_transactons[_id] = (
|
||||
response_cost
|
||||
+ prisma_client.user_list_transactons.get(_id, 0)
|
||||
)
|
||||
if end_user_id is not None:
|
||||
if existing_user_obj is None:
|
||||
# if user does not exist in LiteLLM_UserTable, create a new user
|
||||
existing_spend = 0
|
||||
max_user_budget = None
|
||||
if litellm.max_user_budget is not None:
|
||||
max_user_budget = litellm.max_user_budget
|
||||
existing_user_obj = LiteLLM_UserTable(
|
||||
user_id=end_user_id,
|
||||
spend=0,
|
||||
max_budget=max_user_budget,
|
||||
user_email=None,
|
||||
)
|
||||
|
||||
else:
|
||||
existing_user_obj.spend = (
|
||||
existing_user_obj.spend + response_cost
|
||||
)
|
||||
|
||||
user_object_json = {**existing_user_obj.json(exclude_none=True)}
|
||||
|
||||
user_object_json["model_max_budget"] = json.dumps(
|
||||
user_object_json["model_max_budget"]
|
||||
prisma_client.end_user_list_transactons[end_user_id] = (
|
||||
response_cost
|
||||
+ prisma_client.user_list_transactons.get(end_user_id, 0)
|
||||
)
|
||||
user_object_json["model_spend"] = json.dumps(
|
||||
user_object_json["model_spend"]
|
||||
)
|
||||
|
||||
await prisma_client.db.litellm_usertable.upsert(
|
||||
where={"user_id": end_user_id},
|
||||
data={
|
||||
"create": user_object_json,
|
||||
"update": {"spend": {"increment": response_cost}},
|
||||
},
|
||||
)
|
||||
|
||||
elif custom_db_client is not None:
|
||||
for id in user_ids:
|
||||
if id is None:
|
||||
|
@ -1205,6 +1175,7 @@ async def update_database(
|
|||
value={"spend": new_spend},
|
||||
table_name="user",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.info(
|
||||
"\033[91m"
|
||||
|
@ -1215,12 +1186,12 @@ async def update_database(
|
|||
async def _update_key_db():
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
f"adding spend to key db. Response cost: {response_cost}. Token: {token}."
|
||||
f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}."
|
||||
)
|
||||
if prisma_client is not None:
|
||||
await prisma_client.db.litellm_verificationtoken.update(
|
||||
where={"token": token},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
prisma_client.key_list_transactons[hashed_token] = (
|
||||
response_cost
|
||||
+ prisma_client.key_list_transactons.get(hashed_token, 0)
|
||||
)
|
||||
elif custom_db_client is not None:
|
||||
# Fetch the existing cost for the given token
|
||||
|
@ -1257,7 +1228,6 @@ async def update_database(
|
|||
async def _insert_spend_log_to_db():
|
||||
try:
|
||||
# Helper to generate payload to log
|
||||
verbose_proxy_logger.debug("inserting spend log to db")
|
||||
payload = get_logging_payload(
|
||||
kwargs=kwargs,
|
||||
response_obj=completion_response,
|
||||
|
@ -1268,16 +1238,13 @@ async def update_database(
|
|||
payload["spend"] = response_cost
|
||||
if prisma_client is not None:
|
||||
await prisma_client.insert_data(data=payload, table_name="spend")
|
||||
elif custom_db_client is not None:
|
||||
await custom_db_client.insert_data(payload, table_name="spend")
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.info(
|
||||
verbose_proxy_logger.debug(
|
||||
f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
raise e
|
||||
|
||||
### UPDATE KEY SPEND ###
|
||||
### UPDATE TEAM SPEND ###
|
||||
async def _update_team_db():
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
|
@ -1289,9 +1256,9 @@ async def update_database(
|
|||
)
|
||||
return
|
||||
if prisma_client is not None:
|
||||
await prisma_client.db.litellm_teamtable.update(
|
||||
where={"team_id": team_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
prisma_client.team_list_transactons[team_id] = (
|
||||
response_cost
|
||||
+ prisma_client.team_list_transactons.get(team_id, 0)
|
||||
)
|
||||
elif custom_db_client is not None:
|
||||
# Fetch the existing cost for the given token
|
||||
|
@ -1327,7 +1294,8 @@ async def update_database(
|
|||
asyncio.create_task(_update_user_db())
|
||||
asyncio.create_task(_update_key_db())
|
||||
asyncio.create_task(_update_team_db())
|
||||
asyncio.create_task(_insert_spend_log_to_db())
|
||||
# asyncio.create_task(_insert_spend_log_to_db())
|
||||
await _insert_spend_log_to_db()
|
||||
|
||||
verbose_proxy_logger.debug("Runs spend update on all tables")
|
||||
except Exception as e:
|
||||
|
@ -1646,7 +1614,7 @@ class ProxyConfig:
|
|||
"""
|
||||
Load config values into proxy global state
|
||||
"""
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at
|
||||
|
||||
# Load existing config
|
||||
config = await self.get_config(config_file_path=config_file_path)
|
||||
|
@ -2010,6 +1978,10 @@ class ProxyConfig:
|
|||
proxy_budget_rescheduler_max_time = general_settings.get(
|
||||
"proxy_budget_rescheduler_max_time", proxy_budget_rescheduler_max_time
|
||||
)
|
||||
## BATCH WRITER ##
|
||||
proxy_batch_write_at = general_settings.get(
|
||||
"proxy_batch_write_at", proxy_batch_write_at
|
||||
)
|
||||
### BACKGROUND HEALTH CHECKS ###
|
||||
# Enable background health checks
|
||||
use_background_health_checks = general_settings.get(
|
||||
|
@ -2238,7 +2210,6 @@ async def generate_key_helper_fn(
|
|||
saved_token["expires"] = saved_token["expires"].isoformat()
|
||||
if prisma_client is not None:
|
||||
## CREATE USER (If necessary)
|
||||
verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
|
||||
if query_type == "insert_data":
|
||||
user_row = await prisma_client.insert_data(
|
||||
data=user_data, table_name="user"
|
||||
|
@ -2576,7 +2547,6 @@ async def startup_event():
|
|||
# add master key to db
|
||||
if os.getenv("PROXY_ADMIN_ID", None) is not None:
|
||||
litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID")
|
||||
|
||||
asyncio.create_task(
|
||||
generate_key_helper_fn(
|
||||
duration=None,
|
||||
|
@ -2638,9 +2608,18 @@ async def startup_event():
|
|||
interval = random.randint(
|
||||
proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time
|
||||
) # random interval, so multiple workers avoid resetting budget at the same time
|
||||
batch_writing_interval = random.randint(
|
||||
proxy_batch_write_at - 3, proxy_batch_write_at + 3
|
||||
) # random interval, so multiple workers avoid batch writing at the same time
|
||||
scheduler.add_job(
|
||||
reset_budget, "interval", seconds=interval, args=[prisma_client]
|
||||
)
|
||||
scheduler.add_job(
|
||||
update_spend,
|
||||
"interval",
|
||||
seconds=batch_writing_interval,
|
||||
args=[prisma_client],
|
||||
)
|
||||
scheduler.start()
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue