fix(dynamo_db.py): add cost tracking support for key + user

This commit is contained in:
Krrish Dholakia 2024-01-11 23:56:41 +05:30
parent 9b3d78c4f3
commit f94a37a836
4 changed files with 163 additions and 107 deletions

View file

@ -1,5 +1,7 @@
from typing import Any, Literal, List from typing import Any, Literal, List
class CustomDB:
class CustomDB:
""" """
Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow
""" """
@ -7,37 +9,45 @@ class CustomDB:
def __init__(self) -> None: def __init__(self) -> None:
pass pass
def get_data(self, key: str, value: str, table_name: Literal["user", "key", "config"]): def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
""" """
Check if key valid Check if key valid
""" """
pass pass
def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]): def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
""" """
For new key / user logic For new key / user logic
""" """
pass pass
def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]): def update_data(
self, key: str, value: Any, table_name: Literal["user", "key", "config"]
):
""" """
For cost tracking logic For cost tracking logic
""" """
pass pass
def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]): def delete_data(
self, keys: List[str], table_name: Literal["user", "key", "config"]
):
""" """
For /key/delete endpoint s For /key/delete endpoint s
""" """
def connect(self, ): def connect(
self,
):
"""
For connecting to db and creating / updating any tables
""" """
For connecting to db and creating / updating any tables
"""
pass pass
def disconnect(self, ): def disconnect(
self,
):
""" """
For closing connection on server shutdown For closing connection on server shutdown
""" """
pass pass

View file

@ -104,20 +104,22 @@ class DynamoDBWrapper(CustomDB):
await table.put_item(item=value) await table.put_item(item=value)
async def get_data( async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
self, key: str, value: str, table_name: Literal["user", "key", "config"]
):
async with ClientSession() as session: async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None table = None
key_name = None
if table_name == DBTableNames.user.name: if table_name == DBTableNames.user.name:
table = client.table(DBTableNames.user.value) table = client.table(DBTableNames.user.value)
key_name = "user_id"
elif table_name == DBTableNames.key.name: elif table_name == DBTableNames.key.name:
table = client.table(DBTableNames.key.value) table = client.table(DBTableNames.key.value)
key_name = "token"
elif table_name == DBTableNames.config.name: elif table_name == DBTableNames.config.name:
table = client.table(DBTableNames.config.value) table = client.table(DBTableNames.config.value)
key_name = "param_name"
response = await table.get_item({key: value}) response = await table.get_item({key_name: key})
new_response: Any = None new_response: Any = None
if table_name == DBTableNames.user.name: if table_name == DBTableNames.user.name:
@ -139,46 +141,41 @@ class DynamoDBWrapper(CustomDB):
return new_response return new_response
async def update_data( async def update_data(
self, key: str, value: Any, table_name: Literal["user", "key", "config"] self, key: str, value: dict, table_name: Literal["user", "key", "config"]
): ):
async with ClientSession() as session: async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None table = None
key_name = None key_name = None
data_obj: Optional[ try:
Union[LiteLLM_Config, LiteLLM_UserTable, LiteLLM_VerificationToken] if table_name == DBTableNames.user.name:
] = None table = client.table(DBTableNames.user.value)
if table_name == DBTableNames.user.name: key_name = "user_id"
table = client.table(DBTableNames.user.value)
key_name = "user_id"
data_obj = LiteLLM_UserTable(user_id=key, **value)
elif table_name == DBTableNames.key.name: elif table_name == DBTableNames.key.name:
table = client.table(DBTableNames.key.value) table = client.table(DBTableNames.key.value)
key_name = "token" key_name = "token"
data_obj = LiteLLM_VerificationToken(token=key, **value)
elif table_name == DBTableNames.config.name: elif table_name == DBTableNames.config.name:
table = client.table(DBTableNames.config.value) table = client.table(DBTableNames.config.value)
key_name = "param_name" key_name = "param_name"
data_obj = LiteLLM_Config(param_name=key, **value) else:
raise Exception(
f"Invalid table name. Needs to be one of - {DBTableNames.user.name}, {DBTableNames.key.name}, {DBTableNames.config.name}"
)
except Exception as e:
raise Exception(f"Error connecting to table - {str(e)}")
if data_obj is None:
raise Exception(
f"invalid table name passed in - {table_name}. Unable to load valid data object - {data_obj}."
)
# Initialize an empty UpdateExpression # Initialize an empty UpdateExpression
actions: List = [] actions: List = []
for field in data_obj.fields_set(): for k, v in value.items():
field_value = getattr(data_obj, field)
# Convert datetime object to ISO8601 string # Convert datetime object to ISO8601 string
if isinstance(field_value, datetime): if isinstance(v, datetime):
field_value = field_value.isoformat() v = v.isoformat()
# Accumulate updates # Accumulate updates
actions.append((F(field), Value(value=field_value))) actions.append((F(k), Value(value=v)))
update_expression = UpdateExpression(set_updates=actions) update_expression = UpdateExpression(set_updates=actions)
# Perform the update in DynamoDB # Perform the update in DynamoDB

View file

@ -244,12 +244,15 @@ async def user_api_key_auth(
if valid_token is None: if valid_token is None:
## check db ## check db
verbose_proxy_logger.debug(f"api key: {api_key}") verbose_proxy_logger.debug(f"api key: {api_key}")
if prisma_client is not None: if prisma_client is not None:
valid_token = await prisma_client.get_data( valid_token = await prisma_client.get_data(
token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc) token=api_key,
expires=datetime.utcnow().replace(tzinfo=timezone.utc),
)
elif custom_db_client is not None:
valid_token = await custom_db_client.get_data(
key=api_key, table_name="key"
) )
elif custom_db_client is not None:
valid_token = await custom_db_client.get_data(key="token", value=api_key, table_name="key")
# Token exists, now check expiration. # Token exists, now check expiration.
if valid_token.expires is not None: if valid_token.expires is not None:
expiry_time = datetime.fromisoformat(valid_token.expires) expiry_time = datetime.fromisoformat(valid_token.expires)
@ -297,7 +300,7 @@ async def user_api_key_auth(
This makes the user row data accessible to pre-api call hooks. This makes the user row data accessible to pre-api call hooks.
""" """
if prisma_client is not None: if prisma_client is not None:
asyncio.create_task( asyncio.create_task(
_cache_user_row( _cache_user_row(
user_id=valid_token.user_id, user_id=valid_token.user_id,
@ -386,8 +389,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
def cost_tracking(): def cost_tracking():
global prisma_client global prisma_client, custom_db_client
if prisma_client is not None: if prisma_client is not None or custom_db_client is not None:
if isinstance(litellm.success_callback, list): if isinstance(litellm.success_callback, list):
verbose_proxy_logger.debug("setting litellm success callback to track cost") verbose_proxy_logger.debug("setting litellm success callback to track cost")
if (track_cost_callback) not in litellm.success_callback: # type: ignore if (track_cost_callback) not in litellm.success_callback: # type: ignore
@ -400,7 +403,7 @@ async def track_cost_callback(
start_time=None, start_time=None,
end_time=None, # start/end time for completion end_time=None, # start/end time for completion
): ):
global prisma_client global prisma_client, custom_db_client
try: try:
# check if it has collected an entire stream response # check if it has collected an entire stream response
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
@ -419,10 +422,10 @@ async def track_cost_callback(
user_id = kwargs["litellm_params"]["metadata"].get( user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None "user_api_key_user_id", None
) )
if user_api_key and prisma_client: if user_api_key and (
await update_prisma_database( prisma_client is not None or custom_db_client is not None
token=user_api_key, response_cost=response_cost ):
) await update_database(token=user_api_key, response_cost=response_cost)
elif kwargs["stream"] == False: # for non streaming responses elif kwargs["stream"] == False: # for non streaming responses
response_cost = litellm.completion_cost( response_cost = litellm.completion_cost(
completion_response=completion_response completion_response=completion_response
@ -433,15 +436,17 @@ async def track_cost_callback(
user_id = kwargs["litellm_params"]["metadata"].get( user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None "user_api_key_user_id", None
) )
if user_api_key and prisma_client: if user_api_key and (
await update_prisma_database( prisma_client is not None or custom_db_client is not None
):
await update_database(
token=user_api_key, response_cost=response_cost, user_id=user_id token=user_api_key, response_cost=response_cost, user_id=user_id
) )
except Exception as e: except Exception as e:
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}") verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
async def update_prisma_database(token, response_cost, user_id=None): async def update_database(token, response_cost, user_id=None):
try: try:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"Enters prisma db call, token: {token}; user_id: {user_id}" f"Enters prisma db call, token: {token}; user_id: {user_id}"
@ -451,7 +456,12 @@ async def update_prisma_database(token, response_cost, user_id=None):
async def _update_user_db(): async def _update_user_db():
if user_id is None: if user_id is None:
return return
existing_spend_obj = await prisma_client.get_data(user_id=user_id) if prisma_client is not None:
existing_spend_obj = await prisma_client.get_data(user_id=user_id)
elif custom_db_client is not None:
existing_spend_obj = await custom_db_client.get_data(
key=user_id, table_name="user"
)
if existing_spend_obj is None: if existing_spend_obj is None:
existing_spend = 0 existing_spend = 0
else: else:
@ -462,23 +472,49 @@ async def update_prisma_database(token, response_cost, user_id=None):
verbose_proxy_logger.debug(f"new cost: {new_spend}") verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given user id # Update the cost column for the given user id
await prisma_client.update_data(user_id=user_id, data={"spend": new_spend}) if prisma_client is not None:
await prisma_client.update_data(
user_id=user_id, data={"spend": new_spend}
)
elif custom_db_client is not None:
await custom_db_client.update_data(
key=user_id, value={"spend": new_spend}, table_name="user"
)
### UPDATE KEY SPEND ### ### UPDATE KEY SPEND ###
async def _update_key_db(): async def _update_key_db():
# Fetch the existing cost for the given token if prisma_client is not None:
existing_spend_obj = await prisma_client.get_data(token=token) # Fetch the existing cost for the given token
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}") existing_spend_obj = await prisma_client.get_data(token=token)
if existing_spend_obj is None: verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
existing_spend = 0 if existing_spend_obj is None:
else: existing_spend = 0
existing_spend = existing_spend_obj.spend else:
# Calculate the new cost by adding the existing cost and response_cost existing_spend = existing_spend_obj.spend
new_spend = existing_spend + response_cost # Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
verbose_proxy_logger.debug(f"new cost: {new_spend}") verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given token # Update the cost column for the given token
await prisma_client.update_data(token=token, data={"spend": new_spend}) await prisma_client.update_data(token=token, data={"spend": new_spend})
elif custom_db_client is not None:
# Fetch the existing cost for the given token
existing_spend_obj = await custom_db_client.get_data(
key=token, table_name="key"
)
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given token
await custom_db_client.update_data(
key=token, value={"spend": new_spend}, table_name="key"
)
tasks = [] tasks = []
tasks.append(_update_user_db()) tasks.append(_update_user_db())
@ -802,8 +838,6 @@ class ProxyConfig:
verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!") verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!")
database_url = litellm.get_secret(database_url) database_url = litellm.get_secret(database_url)
verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}") verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}")
## COST TRACKING ##
cost_tracking()
### MASTER KEY ### ### MASTER KEY ###
master_key = general_settings.get( master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None) "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
@ -821,7 +855,11 @@ class ProxyConfig:
database_type = general_settings.get("database_type", None) database_type = general_settings.get("database_type", None)
if database_type is not None and database_type == "dynamo_db": if database_type is not None and database_type == "dynamo_db":
database_args = general_settings.get("database_args", None) database_args = general_settings.get("database_args", None)
custom_db_client = DBClient(custom_db_args=database_args, custom_db_type=database_type) custom_db_client = DBClient(
custom_db_args=database_args, custom_db_type=database_type
)
## COST TRACKING ##
cost_tracking()
### BACKGROUND HEALTH CHECKS ### ### BACKGROUND HEALTH CHECKS ###
# Enable background health checks # Enable background health checks
use_background_health_checks = general_settings.get( use_background_health_checks = general_settings.get(
@ -930,9 +968,10 @@ async def generate_key_helper_fn(
try: try:
# Create a new verification token (you may want to enhance this logic based on your needs) # Create a new verification token (you may want to enhance this logic based on your needs)
user_data = { user_data = {
"max_budget": max_budget, "max_budget": max_budget,
"user_email": user_email, "user_email": user_email,
"user_id": user_id "user_id": user_id,
"spend": spend,
} }
key_data = { key_data = {
"token": token, "token": token,
@ -945,17 +984,15 @@ async def generate_key_helper_fn(
"max_parallel_requests": max_parallel_requests, "max_parallel_requests": max_parallel_requests,
"metadata": metadata_json, "metadata": metadata_json,
} }
if prisma_client is not None: if prisma_client is not None:
verification_token_data = dict(key_data) verification_token_data = dict(key_data)
verification_token_data.update(user_data) verification_token_data.update(user_data)
verbose_proxy_logger.debug("PrismaClient: Before Insert Data") verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
await prisma_client.insert_data( await prisma_client.insert_data(data=verification_token_data)
data=verification_token_data elif custom_db_client is not None:
)
elif custom_db_client is not None:
## CREATE USER (If necessary) ## CREATE USER (If necessary)
await custom_db_client.insert_data(value=user_data, table_name="user") await custom_db_client.insert_data(value=user_data, table_name="user")
## CREATE KEY ## CREATE KEY
await custom_db_client.insert_data(value=key_data, table_name="key") await custom_db_client.insert_data(value=key_data, table_name="key")
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
@ -1230,26 +1267,23 @@ async def startup_event():
verbose_proxy_logger.debug(f"prisma client - {prisma_client}") verbose_proxy_logger.debug(f"prisma client - {prisma_client}")
if prisma_client is not None: if prisma_client is not None:
await prisma_client.connect() await prisma_client.connect()
if custom_db_client is not None: if custom_db_client is not None:
await custom_db_client.connect() await custom_db_client.connect()
if prisma_client is not None and master_key is not None: if prisma_client is not None and master_key is not None:
# add master key to db # add master key to db
print(f'prisma_client: {prisma_client}')
await generate_key_helper_fn( await generate_key_helper_fn(
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
) )
if custom_db_client is not None and master_key is not None: if custom_db_client is not None and master_key is not None:
# add master key to db # add master key to db
print(f'custom_db_client: {custom_db_client}')
await generate_key_helper_fn( await generate_key_helper_fn(
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
) )
#### API ENDPOINTS #### #### API ENDPOINTS ####
@router.get( @router.get(
"/v1/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"] "/v1/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]

View file

@ -255,10 +255,10 @@ class PrismaClient:
print_verbose( print_verbose(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'" "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
) )
## init logging object ## init logging object
self.proxy_logging_obj = proxy_logging_obj self.proxy_logging_obj = proxy_logging_obj
if os.getenv("DATABASE_URL", None) is None: # setup hasn't taken place if os.getenv("DATABASE_URL", None) is None: # setup hasn't taken place
os.environ["DATABASE_URL"] = database_url os.environ["DATABASE_URL"] = database_url
# Save the current working directory # Save the current working directory
original_dir = os.getcwd() original_dir = os.getcwd()
@ -334,7 +334,7 @@ class PrismaClient:
self.proxy_logging_obj.failure_handler(original_exception=e) self.proxy_logging_obj.failure_handler(original_exception=e)
) )
raise e raise e
@backoff.on_exception( @backoff.on_exception(
backoff.expo, backoff.expo,
Exception, # base exception to catch for the backoff Exception, # base exception to catch for the backoff
@ -582,47 +582,60 @@ class PrismaClient:
) )
raise e raise e
class DBClient: class DBClient:
""" """
Routes requests for CustomAuth Routes requests for CustomAuth
[TODO] route b/w customauth and prisma [TODO] route b/w customauth and prisma
""" """
def __init__(self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict) -> None:
def __init__(
self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict
) -> None:
if custom_db_type == "dynamo_db": if custom_db_type == "dynamo_db":
self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args)) self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args))
async def get_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]): async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
""" """
Check if key valid Check if key valid
""" """
return await self.db.get_data(key=key, value=value, table_name=table_name) return await self.db.get_data(key=key, table_name=table_name)
async def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]): async def insert_data(
self, value: Any, table_name: Literal["user", "key", "config"]
):
""" """
For new key / user logic For new key / user logic
""" """
return await self.db.insert_data(value=value, table_name=table_name) return await self.db.insert_data(value=value, table_name=table_name)
async def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]): async def update_data(
self, key: str, value: Any, table_name: Literal["user", "key", "config"]
):
""" """
For cost tracking logic For cost tracking logic
key - hash_key value \n
value - dict with updated values
""" """
return await self.db.update_data(key=key, value=value, table_name=table_name) return await self.db.update_data(key=key, value=value, table_name=table_name)
async def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]): async def delete_data(
self, keys: List[str], table_name: Literal["user", "key", "config"]
):
""" """
For /key/delete endpoints For /key/delete endpoints
""" """
return await self.db.delete_data(keys=keys, table_name=table_name) return await self.db.delete_data(keys=keys, table_name=table_name)
async def connect(self): async def connect(self):
""" """
For connecting to db and creating / updating any tables For connecting to db and creating / updating any tables
""" """
return await self.db.connect() return await self.db.connect()
async def disconnect(self): async def disconnect(self):
""" """
For closing connection on server shutdown For closing connection on server shutdown
""" """
@ -669,7 +682,9 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
### HELPER FUNCTIONS ### ### HELPER FUNCTIONS ###
async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]): async def _cache_user_row(
user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]
):
""" """
Check if a user_id exists in cache, Check if a user_id exists in cache,
if not retrieve it. if not retrieve it.
@ -681,7 +696,7 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient
if isinstance(db, PrismaClient): if isinstance(db, PrismaClient):
user_row = await db.get_data(user_id=user_id) user_row = await db.get_data(user_id=user_id)
elif isinstance(db, DBClient): elif isinstance(db, DBClient):
user_row = await db.get_data(key="user_id", value=user_id, table_name="user") user_row = await db.get_data(key=user_id, table_name="user")
if user_row is not None: if user_row is not None:
print_verbose(f"User Row: {user_row}, type = {type(user_row)}") print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
if hasattr(user_row, "model_dump_json") and callable( if hasattr(user_row, "model_dump_json") and callable(