forked from phoenix/litellm-mirror
fix(proxy/utils.py): batch writing updates to db
This commit is contained in:
parent
077b9c6234
commit
8fefe625d9
3 changed files with 141 additions and 217 deletions
|
@ -996,10 +996,8 @@ async def _PROXY_track_cost_callback(
|
|||
)
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
user_id = user_id or kwargs["litellm_params"]["metadata"].get(
|
||||
"user_api_key_user_id", None
|
||||
)
|
||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
|
||||
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
|
||||
if kwargs.get("response_cost", None) is not None:
|
||||
response_cost = kwargs["response_cost"]
|
||||
|
@ -1013,9 +1011,6 @@ async def _PROXY_track_cost_callback(
|
|||
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
f"response_cost {response_cost}, for user_id {user_id}"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}"
|
||||
)
|
||||
|
@ -1025,6 +1020,7 @@ async def _PROXY_track_cost_callback(
|
|||
token=user_api_key,
|
||||
response_cost=response_cost,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
team_id=team_id,
|
||||
kwargs=kwargs,
|
||||
completion_response=completion_response,
|
||||
|
@ -1066,6 +1062,7 @@ async def update_database(
|
|||
token,
|
||||
response_cost,
|
||||
user_id=None,
|
||||
end_user_id=None,
|
||||
team_id=None,
|
||||
kwargs=None,
|
||||
completion_response=None,
|
||||
|
@ -1076,6 +1073,10 @@ async def update_database(
|
|||
verbose_proxy_logger.info(
|
||||
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
|
||||
)
|
||||
if isinstance(token, str) and token.startswith("sk-"):
|
||||
hashed_token = hash_token(token=token)
|
||||
else:
|
||||
hashed_token = token
|
||||
|
||||
### UPDATE USER SPEND ###
|
||||
async def _update_user_db():
|
||||
|
@ -1084,11 +1085,6 @@ async def update_database(
|
|||
- Update litellm-proxy-budget row (global proxy spend)
|
||||
"""
|
||||
## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db
|
||||
end_user_id = None
|
||||
if isinstance(token, str) and token.startswith("sk-"):
|
||||
hashed_token = hash_token(token=token)
|
||||
else:
|
||||
hashed_token = token
|
||||
existing_token_obj = await user_api_key_cache.async_get_cache(
|
||||
key=hashed_token
|
||||
)
|
||||
|
@ -1097,119 +1093,24 @@ async def update_database(
|
|||
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||
if existing_user_obj is not None and isinstance(existing_user_obj, dict):
|
||||
existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
|
||||
if existing_token_obj.user_id != user_id: # an end-user id was passed in
|
||||
end_user_id = user_id
|
||||
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name]
|
||||
data_list = []
|
||||
try:
|
||||
if prisma_client is not None: # update
|
||||
user_ids = [user_id, litellm_proxy_budget_name]
|
||||
user_ids = [user_id]
|
||||
if (
|
||||
litellm.max_budget > 0
|
||||
): # track global proxy budget, if user set max budget
|
||||
user_ids.append(litellm_proxy_budget_name)
|
||||
### KEY CHANGE ###
|
||||
for _id in user_ids:
|
||||
prisma_client.user_list_transactons.append((_id, response_cost))
|
||||
######
|
||||
## do a group update for the user-id of the key + global proxy budget
|
||||
# await prisma_client.db.litellm_usertable.update_many(
|
||||
# where={"user_id": {"in": user_ids}},
|
||||
# data={"spend": {"increment": response_cost}},
|
||||
# )
|
||||
# if end_user_id is not None:
|
||||
# if existing_user_obj is None:
|
||||
# # if user does not exist in LiteLLM_UserTable, create a new user
|
||||
# existing_spend = 0
|
||||
# max_user_budget = None
|
||||
# if litellm.max_user_budget is not None:
|
||||
# max_user_budget = litellm.max_user_budget
|
||||
# existing_user_obj = LiteLLM_UserTable(
|
||||
# user_id=end_user_id,
|
||||
# spend=0,
|
||||
# max_budget=max_user_budget,
|
||||
# user_email=None,
|
||||
# )
|
||||
|
||||
# else:
|
||||
# existing_user_obj.spend = (
|
||||
# existing_user_obj.spend + response_cost
|
||||
# )
|
||||
|
||||
# user_object_json = {**existing_user_obj.json(exclude_none=True)}
|
||||
|
||||
# user_object_json["model_max_budget"] = json.dumps(
|
||||
# user_object_json["model_max_budget"]
|
||||
# )
|
||||
# user_object_json["model_spend"] = json.dumps(
|
||||
# user_object_json["model_spend"]
|
||||
# )
|
||||
|
||||
# await prisma_client.db.litellm_usertable.upsert(
|
||||
# where={"user_id": end_user_id},
|
||||
# data={
|
||||
# "create": user_object_json,
|
||||
# "update": {"spend": {"increment": response_cost}},
|
||||
# },
|
||||
# )
|
||||
|
||||
# elif custom_db_client is not None:
|
||||
# for id in user_ids:
|
||||
# if id is None:
|
||||
# continue
|
||||
# if (
|
||||
# custom_db_client is not None
|
||||
# and id != litellm_proxy_budget_name
|
||||
# ):
|
||||
# existing_spend_obj = await custom_db_client.get_data(
|
||||
# key=id, table_name="user"
|
||||
# )
|
||||
# verbose_proxy_logger.debug(
|
||||
# f"Updating existing_spend_obj: {existing_spend_obj}"
|
||||
# )
|
||||
# if existing_spend_obj is None:
|
||||
# # if user does not exist in LiteLLM_UserTable, create a new user
|
||||
# existing_spend = 0
|
||||
# max_user_budget = None
|
||||
# if litellm.max_user_budget is not None:
|
||||
# max_user_budget = litellm.max_user_budget
|
||||
# existing_spend_obj = LiteLLM_UserTable(
|
||||
# user_id=id,
|
||||
# spend=0,
|
||||
# max_budget=max_user_budget,
|
||||
# user_email=None,
|
||||
# )
|
||||
# else:
|
||||
# existing_spend = existing_spend_obj.spend
|
||||
|
||||
# # Calculate the new cost by adding the existing cost and response_cost
|
||||
# existing_spend_obj.spend = existing_spend + response_cost
|
||||
|
||||
# # track cost per model, for the given user
|
||||
# spend_per_model = existing_spend_obj.model_spend or {}
|
||||
# current_model = kwargs.get("model")
|
||||
|
||||
# if current_model is not None and spend_per_model is not None:
|
||||
# if spend_per_model.get(current_model) is None:
|
||||
# spend_per_model[current_model] = response_cost
|
||||
# else:
|
||||
# spend_per_model[current_model] += response_cost
|
||||
# existing_spend_obj.model_spend = spend_per_model
|
||||
|
||||
# valid_token = user_api_key_cache.get_cache(key=id)
|
||||
# if valid_token is not None and isinstance(valid_token, dict):
|
||||
# user_api_key_cache.set_cache(
|
||||
# key=id, value=existing_spend_obj.json()
|
||||
# )
|
||||
|
||||
# verbose_proxy_logger.debug(
|
||||
# f"user - new cost: {existing_spend_obj.spend}, user_id: {id}"
|
||||
# )
|
||||
# data_list.append(existing_spend_obj)
|
||||
|
||||
# if custom_db_client is not None and user_id is not None:
|
||||
# new_spend = data_list[0].spend
|
||||
# await custom_db_client.update_data(
|
||||
# key=user_id,
|
||||
# value={"spend": new_spend},
|
||||
# table_name="user",
|
||||
# )
|
||||
prisma_client.user_list_transactons[_id] = (
|
||||
response_cost
|
||||
+ prisma_client.user_list_transactons.get(_id, 0)
|
||||
)
|
||||
if end_user_id is not None:
|
||||
prisma_client.end_user_list_transactons[end_user_id] = (
|
||||
response_cost
|
||||
+ prisma_client.user_list_transactons.get(end_user_id, 0)
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.info(
|
||||
"\033[91m"
|
||||
|
@ -1220,38 +1121,13 @@ async def update_database(
|
|||
async def _update_key_db():
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
f"adding spend to key db. Response cost: {response_cost}. Token: {token}."
|
||||
f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}."
|
||||
)
|
||||
if prisma_client is not None:
|
||||
await prisma_client.db.litellm_verificationtoken.update(
|
||||
where={"token": token},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
prisma_client.key_list_transactons[hashed_token] = (
|
||||
response_cost
|
||||
+ prisma_client.key_list_transactons.get(hashed_token, 0)
|
||||
)
|
||||
elif custom_db_client is not None:
|
||||
# Fetch the existing cost for the given token
|
||||
existing_spend_obj = await custom_db_client.get_data(
|
||||
key=token, table_name="key"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"_update_key_db existing spend: {existing_spend_obj}"
|
||||
)
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
else:
|
||||
existing_spend = existing_spend_obj.spend
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_spend = existing_spend + response_cost
|
||||
|
||||
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
||||
# Update the cost column for the given token
|
||||
await custom_db_client.update_data(
|
||||
key=token, value={"spend": new_spend}, table_name="key"
|
||||
)
|
||||
|
||||
valid_token = user_api_key_cache.get_cache(key=token)
|
||||
if valid_token is not None:
|
||||
valid_token.spend = new_spend
|
||||
user_api_key_cache.set_cache(key=token, value=valid_token)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.info(
|
||||
f"Update Key DB Call failed to execute - {str(e)}\n{traceback.format_exc()}"
|
||||
|
@ -1273,16 +1149,13 @@ async def update_database(
|
|||
payload["spend"] = response_cost
|
||||
if prisma_client is not None:
|
||||
await prisma_client.insert_data(data=payload, table_name="spend")
|
||||
elif custom_db_client is not None:
|
||||
await custom_db_client.insert_data(payload, table_name="spend")
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.info(
|
||||
f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
raise e
|
||||
|
||||
### UPDATE KEY SPEND ###
|
||||
### UPDATE TEAM SPEND ###
|
||||
async def _update_team_db():
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
|
@ -1294,46 +1167,19 @@ async def update_database(
|
|||
)
|
||||
return
|
||||
if prisma_client is not None:
|
||||
await prisma_client.db.litellm_teamtable.update(
|
||||
where={"team_id": team_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
prisma_client.team_list_transactons[team_id] = (
|
||||
response_cost
|
||||
+ prisma_client.team_list_transactons.get(team_id, 0)
|
||||
)
|
||||
elif custom_db_client is not None:
|
||||
# Fetch the existing cost for the given token
|
||||
existing_spend_obj = await custom_db_client.get_data(
|
||||
key=token, table_name="key"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"_update_key_db existing spend: {existing_spend_obj}"
|
||||
)
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
else:
|
||||
existing_spend = existing_spend_obj.spend
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_spend = existing_spend + response_cost
|
||||
|
||||
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
||||
# Update the cost column for the given token
|
||||
await custom_db_client.update_data(
|
||||
key=token, value={"spend": new_spend}, table_name="key"
|
||||
)
|
||||
|
||||
valid_token = user_api_key_cache.get_cache(key=token)
|
||||
if valid_token is not None:
|
||||
valid_token.spend = new_spend
|
||||
user_api_key_cache.set_cache(key=token, value=valid_token)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.info(
|
||||
f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
raise e
|
||||
|
||||
# asyncio.create_task(_update_user_db())
|
||||
await _update_user_db()
|
||||
|
||||
# asyncio.create_task(_update_key_db())
|
||||
# asyncio.create_task(_update_team_db())
|
||||
asyncio.create_task(_update_user_db())
|
||||
asyncio.create_task(_update_key_db())
|
||||
asyncio.create_task(_update_team_db())
|
||||
# asyncio.create_task(_insert_spend_log_to_db())
|
||||
|
||||
verbose_proxy_logger.debug("Runs spend update on all tables")
|
||||
|
@ -2237,7 +2083,6 @@ async def generate_key_helper_fn(
|
|||
saved_token["expires"] = saved_token["expires"].isoformat()
|
||||
if prisma_client is not None:
|
||||
## CREATE USER (If necessary)
|
||||
verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
|
||||
if query_type == "insert_data":
|
||||
user_row = await prisma_client.insert_data(
|
||||
data=user_data, table_name="user"
|
||||
|
@ -2575,7 +2420,6 @@ async def startup_event():
|
|||
# add master key to db
|
||||
if os.getenv("PROXY_ADMIN_ID", None) is not None:
|
||||
litellm_proxy_admin_name = os.getenv("PROXY_ADMIN_ID")
|
||||
|
||||
asyncio.create_task(
|
||||
generate_key_helper_fn(
|
||||
duration=None,
|
||||
|
@ -2640,10 +2484,7 @@ async def startup_event():
|
|||
scheduler.add_job(
|
||||
reset_budget, "interval", seconds=interval, args=[prisma_client]
|
||||
)
|
||||
# scheduler.add_job(
|
||||
# monitor_spend_list, "interval", seconds=10, args=[prisma_client]
|
||||
# )
|
||||
scheduler.add_job(update_spend, "interval", seconds=60, args=[prisma_client])
|
||||
scheduler.add_job(update_spend, "interval", seconds=10, args=[prisma_client])
|
||||
scheduler.start()
|
||||
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ from dotenv import load_dotenv
|
|||
|
||||
litellm_client = AsyncOpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")
|
||||
|
||||
|
||||
async def litellm_completion():
|
||||
# Your existing code for litellm_completion goes here
|
||||
try:
|
||||
|
@ -18,6 +19,7 @@ async def litellm_completion():
|
|||
"content": f"{text}. Who was alexander the great? {uuid.uuid4()}",
|
||||
}
|
||||
],
|
||||
user="my-new-end-user-1",
|
||||
)
|
||||
return response
|
||||
|
||||
|
@ -29,9 +31,9 @@ async def litellm_completion():
|
|||
|
||||
|
||||
async def main():
|
||||
for i in range(6):
|
||||
for i in range(3):
|
||||
start = time.time()
|
||||
n = 20 # Number of concurrent tasks
|
||||
n = 10 # Number of concurrent tasks
|
||||
tasks = [litellm_completion() for _ in range(n)]
|
||||
|
||||
chat_completions = await asyncio.gather(*tasks)
|
||||
|
|
|
@ -7,6 +7,10 @@ from litellm.proxy._types import (
|
|||
LiteLLM_VerificationToken,
|
||||
LiteLLM_VerificationTokenView,
|
||||
LiteLLM_SpendLogs,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLM_EndUserTable,
|
||||
LiteLLM_TeamTable,
|
||||
Member,
|
||||
)
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy.hooks.parallel_request_limiter import (
|
||||
|
@ -472,9 +476,10 @@ def on_backoff(details):
|
|||
|
||||
|
||||
class PrismaClient:
|
||||
user_list_transactons: List = []
|
||||
key_list_transactons: List = []
|
||||
team_list_transactons: List = []
|
||||
user_list_transactons: dict = {}
|
||||
end_user_list_transactons: dict = {}
|
||||
key_list_transactons: dict = {}
|
||||
team_list_transactons: dict = {}
|
||||
spend_log_transactons: List = []
|
||||
|
||||
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
||||
|
@ -1855,34 +1860,62 @@ async def update_spend(
|
|||
Triggered every minute.
|
||||
|
||||
Requires:
|
||||
user_id_list: list,
|
||||
user_id_list: dict,
|
||||
keys_list: list,
|
||||
team_list: list,
|
||||
spend_logs: list,
|
||||
"""
|
||||
verbose_proxy_logger.debug(
|
||||
f"ENTERS UPDATE SPEND - len(prisma_client.user_list_transactons.keys()): {len(prisma_client.user_list_transactons.keys())}"
|
||||
)
|
||||
n_retry_times = 3
|
||||
### UPDATE USER TABLE ###
|
||||
if len(prisma_client.user_list_transactons) > 0:
|
||||
if len(prisma_client.user_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
try:
|
||||
remaining_transactions = list(prisma_client.user_list_transactons)
|
||||
while remaining_transactions:
|
||||
batch_size = min(5000, len(remaining_transactions))
|
||||
batch_transactions = remaining_transactions[:batch_size]
|
||||
async with prisma_client.db.tx(timeout=60000) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for user_id_tuple in batch_transactions:
|
||||
user_id, response_cost = user_id_tuple
|
||||
if user_id != "litellm-proxy-budget":
|
||||
batcher.litellm_usertable.update(
|
||||
where={"user_id": user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
|
||||
remaining_transactions = remaining_transactions[batch_size:]
|
||||
|
||||
async with prisma_client.db.tx(timeout=6000) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
user_id,
|
||||
response_cost,
|
||||
) in prisma_client.user_list_transactons.items():
|
||||
batcher.litellm_usertable.update_many(
|
||||
where={"user_id": user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.user_list_transactons = (
|
||||
[]
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
except httpx.ReadTimeout:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
raise # Re-raise the last exception
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
### UPDATE END-USER TABLE ###
|
||||
if len(prisma_client.end_user_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
try:
|
||||
async with prisma_client.db.tx(timeout=6000) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
end_user_id,
|
||||
response_cost,
|
||||
) in prisma_client.end_user_list_transactons.items():
|
||||
max_user_budget = None
|
||||
if litellm.max_user_budget is not None:
|
||||
max_user_budget = litellm.max_user_budget
|
||||
new_user_obj = LiteLLM_EndUserTable(
|
||||
user_id=end_user_id, spend=response_cost, blocked=False
|
||||
)
|
||||
batcher.litellm_endusertable.update_many(
|
||||
where={"user_id": end_user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.end_user_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
except httpx.ReadTimeout:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
|
@ -1893,7 +1926,55 @@ async def update_spend(
|
|||
raise e
|
||||
|
||||
### UPDATE KEY TABLE ###
|
||||
if len(prisma_client.key_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
try:
|
||||
async with prisma_client.db.tx(timeout=6000) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
token,
|
||||
response_cost,
|
||||
) in prisma_client.key_list_transactons.items():
|
||||
batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"token": token},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.key_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
except httpx.ReadTimeout:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
raise # Re-raise the last exception
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
### UPDATE TEAM TABLE ###
|
||||
if len(prisma_client.team_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
try:
|
||||
async with prisma_client.db.tx(timeout=6000) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
team_id,
|
||||
response_cost,
|
||||
) in prisma_client.team_list_transactons.items():
|
||||
batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"team_id": team_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.team_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
except httpx.ReadTimeout:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
raise # Re-raise the last exception
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
### UPDATE SPEND LOGS TABLE ###
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue