fix(dynamo_db.py): add cost tracking support for key + user

This commit is contained in:
Krrish Dholakia 2024-01-11 23:56:41 +05:30
parent 9b3d78c4f3
commit f94a37a836
4 changed files with 163 additions and 107 deletions

View file

@ -244,12 +244,15 @@ async def user_api_key_auth(
if valid_token is None:
## check db
verbose_proxy_logger.debug(f"api key: {api_key}")
if prisma_client is not None:
if prisma_client is not None:
valid_token = await prisma_client.get_data(
token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc)
token=api_key,
expires=datetime.utcnow().replace(tzinfo=timezone.utc),
)
elif custom_db_client is not None:
valid_token = await custom_db_client.get_data(
key=api_key, table_name="key"
)
elif custom_db_client is not None:
valid_token = await custom_db_client.get_data(key="token", value=api_key, table_name="key")
# Token exists, now check expiration.
if valid_token.expires is not None:
expiry_time = datetime.fromisoformat(valid_token.expires)
@ -297,7 +300,7 @@ async def user_api_key_auth(
This makes the user row data accessible to pre-api call hooks.
"""
if prisma_client is not None:
if prisma_client is not None:
asyncio.create_task(
_cache_user_row(
user_id=valid_token.user_id,
@ -386,8 +389,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
def cost_tracking():
global prisma_client
if prisma_client is not None:
global prisma_client, custom_db_client
if prisma_client is not None or custom_db_client is not None:
if isinstance(litellm.success_callback, list):
verbose_proxy_logger.debug("setting litellm success callback to track cost")
if (track_cost_callback) not in litellm.success_callback: # type: ignore
@ -400,7 +403,7 @@ async def track_cost_callback(
start_time=None,
end_time=None, # start/end time for completion
):
global prisma_client
global prisma_client, custom_db_client
try:
# check if it has collected an entire stream response
verbose_proxy_logger.debug(
@ -419,10 +422,10 @@ async def track_cost_callback(
user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None
)
if user_api_key and prisma_client:
await update_prisma_database(
token=user_api_key, response_cost=response_cost
)
if user_api_key and (
prisma_client is not None or custom_db_client is not None
):
await update_database(token=user_api_key, response_cost=response_cost)
elif kwargs["stream"] == False: # for non streaming responses
response_cost = litellm.completion_cost(
completion_response=completion_response
@ -433,15 +436,17 @@ async def track_cost_callback(
user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None
)
if user_api_key and prisma_client:
await update_prisma_database(
if user_api_key and (
prisma_client is not None or custom_db_client is not None
):
await update_database(
token=user_api_key, response_cost=response_cost, user_id=user_id
)
except Exception as e:
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
async def update_prisma_database(token, response_cost, user_id=None):
async def update_database(token, response_cost, user_id=None):
try:
verbose_proxy_logger.debug(
f"Enters prisma db call, token: {token}; user_id: {user_id}"
@ -451,7 +456,12 @@ async def update_prisma_database(token, response_cost, user_id=None):
async def _update_user_db():
if user_id is None:
return
existing_spend_obj = await prisma_client.get_data(user_id=user_id)
if prisma_client is not None:
existing_spend_obj = await prisma_client.get_data(user_id=user_id)
elif custom_db_client is not None:
existing_spend_obj = await custom_db_client.get_data(
key=user_id, table_name="user"
)
if existing_spend_obj is None:
existing_spend = 0
else:
@ -462,23 +472,49 @@ async def update_prisma_database(token, response_cost, user_id=None):
verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given user id
await prisma_client.update_data(user_id=user_id, data={"spend": new_spend})
if prisma_client is not None:
await prisma_client.update_data(
user_id=user_id, data={"spend": new_spend}
)
elif custom_db_client is not None:
await custom_db_client.update_data(
key=user_id, value={"spend": new_spend}, table_name="user"
)
### UPDATE KEY SPEND ###
async def _update_key_db():
# Fetch the existing cost for the given token
existing_spend_obj = await prisma_client.get_data(token=token)
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
if prisma_client is not None:
# Fetch the existing cost for the given token
existing_spend_obj = await prisma_client.get_data(token=token)
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given token
await prisma_client.update_data(token=token, data={"spend": new_spend})
verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given token
await prisma_client.update_data(token=token, data={"spend": new_spend})
elif custom_db_client is not None:
# Fetch the existing cost for the given token
existing_spend_obj = await custom_db_client.get_data(
key=token, table_name="key"
)
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given token
await custom_db_client.update_data(
key=token, value={"spend": new_spend}, table_name="key"
)
tasks = []
tasks.append(_update_user_db())
@ -802,8 +838,6 @@ class ProxyConfig:
verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!")
database_url = litellm.get_secret(database_url)
verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}")
## COST TRACKING ##
cost_tracking()
### MASTER KEY ###
master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
@ -821,7 +855,11 @@ class ProxyConfig:
database_type = general_settings.get("database_type", None)
if database_type is not None and database_type == "dynamo_db":
database_args = general_settings.get("database_args", None)
custom_db_client = DBClient(custom_db_args=database_args, custom_db_type=database_type)
custom_db_client = DBClient(
custom_db_args=database_args, custom_db_type=database_type
)
## COST TRACKING ##
cost_tracking()
### BACKGROUND HEALTH CHECKS ###
# Enable background health checks
use_background_health_checks = general_settings.get(
@ -930,9 +968,10 @@ async def generate_key_helper_fn(
try:
# Create a new verification token (you may want to enhance this logic based on your needs)
user_data = {
"max_budget": max_budget,
"user_email": user_email,
"user_id": user_id
"max_budget": max_budget,
"user_email": user_email,
"user_id": user_id,
"spend": spend,
}
key_data = {
"token": token,
@ -945,17 +984,15 @@ async def generate_key_helper_fn(
"max_parallel_requests": max_parallel_requests,
"metadata": metadata_json,
}
if prisma_client is not None:
verification_token_data = dict(key_data)
if prisma_client is not None:
verification_token_data = dict(key_data)
verification_token_data.update(user_data)
verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
await prisma_client.insert_data(
data=verification_token_data
)
elif custom_db_client is not None:
await prisma_client.insert_data(data=verification_token_data)
elif custom_db_client is not None:
## CREATE USER (If necessary)
await custom_db_client.insert_data(value=user_data, table_name="user")
## CREATE KEY
## CREATE KEY
await custom_db_client.insert_data(value=key_data, table_name="key")
except Exception as e:
traceback.print_exc()
@ -1230,26 +1267,23 @@ async def startup_event():
verbose_proxy_logger.debug(f"prisma client - {prisma_client}")
if prisma_client is not None:
await prisma_client.connect()
if custom_db_client is not None:
await custom_db_client.connect()
if custom_db_client is not None:
await custom_db_client.connect()
if prisma_client is not None and master_key is not None:
# add master key to db
print(f'prisma_client: {prisma_client}')
await generate_key_helper_fn(
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
)
if custom_db_client is not None and master_key is not None:
# add master key to db
print(f'custom_db_client: {custom_db_client}')
await generate_key_helper_fn(
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
)
#### API ENDPOINTS ####
@router.get(
"/v1/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]