diff --git a/litellm/proxy/db/base_client.py b/litellm/proxy/db/base_client.py index 9eb7e5ddbf..07f0ecdc47 100644 --- a/litellm/proxy/db/base_client.py +++ b/litellm/proxy/db/base_client.py @@ -1,5 +1,7 @@ from typing import Any, Literal, List -class CustomDB: + + +class CustomDB: """ Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow """ @@ -7,37 +9,45 @@ class CustomDB: def __init__(self) -> None: pass - def get_data(self, key: str, value: str, table_name: Literal["user", "key", "config"]): + def get_data(self, key: str, table_name: Literal["user", "key", "config"]): """ Check if key valid """ pass - - def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]): + + def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]): """ For new key / user logic """ pass - def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]): + def update_data( + self, key: str, value: Any, table_name: Literal["user", "key", "config"] + ): """ For cost tracking logic """ pass - def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]): + def delete_data( + self, keys: List[str], table_name: Literal["user", "key", "config"] + ): """ For /key/delete endpoint s """ - - def connect(self, ): + + def connect( + self, + ): + """ + For connecting to db and creating / updating any tables """ - For connecting to db and creating / updating any tables - """ pass - def disconnect(self, ): + def disconnect( + self, + ): """ For closing connection on server shutdown """ - pass \ No newline at end of file + pass diff --git a/litellm/proxy/db/dynamo_db.py b/litellm/proxy/db/dynamo_db.py index 660eee9100..17e5bd3b8c 100644 --- a/litellm/proxy/db/dynamo_db.py +++ b/litellm/proxy/db/dynamo_db.py @@ -104,20 +104,22 @@ class DynamoDBWrapper(CustomDB): await table.put_item(item=value) - async def get_data( - self, key: str, value: str, table_name: Literal["user", "key", "config"] - ): + async def get_data(self, key: str, table_name: Literal["user", "key", "config"]): async with ClientSession() as session: client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) table = None + key_name = None if table_name == DBTableNames.user.name: table = client.table(DBTableNames.user.value) + key_name = "user_id" elif table_name == DBTableNames.key.name: table = client.table(DBTableNames.key.value) + key_name = "token" elif table_name == DBTableNames.config.name: table = client.table(DBTableNames.config.value) + key_name = "param_name" - response = await table.get_item({key: value}) + response = await table.get_item({key_name: key}) new_response: Any = None if table_name == DBTableNames.user.name: @@ -139,46 +141,41 @@ class DynamoDBWrapper(CustomDB): return new_response async def update_data( - self, key: str, value: Any, table_name: Literal["user", "key", "config"] + self, key: str, value: dict, table_name: Literal["user", "key", "config"] ): async with ClientSession() as session: client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) table = None key_name = None - data_obj: Optional[ - Union[LiteLLM_Config, LiteLLM_UserTable, LiteLLM_VerificationToken] - ] = None - if table_name == DBTableNames.user.name: - table = client.table(DBTableNames.user.value) - key_name = "user_id" - data_obj = LiteLLM_UserTable(user_id=key, **value) + try: + if table_name == DBTableNames.user.name: + table = client.table(DBTableNames.user.value) + key_name = "user_id" - elif table_name == DBTableNames.key.name: - table = client.table(DBTableNames.key.value) - key_name = "token" - data_obj = LiteLLM_VerificationToken(token=key, **value) + elif table_name == DBTableNames.key.name: + table = client.table(DBTableNames.key.value) + key_name = "token" - elif table_name == DBTableNames.config.name: - table = client.table(DBTableNames.config.value) - key_name = "param_name" - data_obj = LiteLLM_Config(param_name=key, **value) + elif table_name == DBTableNames.config.name: + table = client.table(DBTableNames.config.value) + key_name = "param_name" + else: + raise Exception( + f"Invalid table name. Needs to be one of - {DBTableNames.user.name}, {DBTableNames.key.name}, {DBTableNames.config.name}" + ) + except Exception as e: + raise Exception(f"Error connecting to table - {str(e)}") - if data_obj is None: - raise Exception( - f"invalid table name passed in - {table_name}. Unable to load valid data object - {data_obj}." - ) # Initialize an empty UpdateExpression actions: List = [] - for field in data_obj.fields_set(): - field_value = getattr(data_obj, field) - + for k, v in value.items(): # Convert datetime object to ISO8601 string - if isinstance(field_value, datetime): - field_value = field_value.isoformat() + if isinstance(v, datetime): + v = v.isoformat() # Accumulate updates - actions.append((F(field), Value(value=field_value))) + actions.append((F(k), Value(value=v))) update_expression = UpdateExpression(set_updates=actions) # Perform the update in DynamoDB diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 00116c1ff1..25dde7cf0d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -244,12 +244,15 @@ async def user_api_key_auth( if valid_token is None: ## check db verbose_proxy_logger.debug(f"api key: {api_key}") - if prisma_client is not None: + if prisma_client is not None: valid_token = await prisma_client.get_data( - token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc) + token=api_key, + expires=datetime.utcnow().replace(tzinfo=timezone.utc), + ) + elif custom_db_client is not None: + valid_token = await custom_db_client.get_data( + key=api_key, table_name="key" ) - elif custom_db_client is not None: - valid_token = await custom_db_client.get_data(key="token", value=api_key, table_name="key") # Token exists, now check expiration. if valid_token.expires is not None: expiry_time = datetime.fromisoformat(valid_token.expires) @@ -297,7 +300,7 @@ async def user_api_key_auth( This makes the user row data accessible to pre-api call hooks. """ - if prisma_client is not None: + if prisma_client is not None: asyncio.create_task( _cache_user_row( user_id=valid_token.user_id, @@ -386,8 +389,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False): def cost_tracking(): - global prisma_client - if prisma_client is not None: + global prisma_client, custom_db_client + if prisma_client is not None or custom_db_client is not None: if isinstance(litellm.success_callback, list): verbose_proxy_logger.debug("setting litellm success callback to track cost") if (track_cost_callback) not in litellm.success_callback: # type: ignore @@ -400,7 +403,7 @@ async def track_cost_callback( start_time=None, end_time=None, # start/end time for completion ): - global prisma_client + global prisma_client, custom_db_client try: # check if it has collected an entire stream response verbose_proxy_logger.debug( @@ -419,10 +422,10 @@ async def track_cost_callback( user_id = kwargs["litellm_params"]["metadata"].get( "user_api_key_user_id", None ) - if user_api_key and prisma_client: - await update_prisma_database( - token=user_api_key, response_cost=response_cost - ) + if user_api_key and ( + prisma_client is not None or custom_db_client is not None + ): + await update_database(token=user_api_key, response_cost=response_cost) elif kwargs["stream"] == False: # for non streaming responses response_cost = litellm.completion_cost( completion_response=completion_response @@ -433,15 +436,17 @@ async def track_cost_callback( user_id = kwargs["litellm_params"]["metadata"].get( "user_api_key_user_id", None ) - if user_api_key and prisma_client: - await update_prisma_database( + if user_api_key and ( + prisma_client is not None or custom_db_client is not None + ): + await update_database( token=user_api_key, response_cost=response_cost, user_id=user_id ) except Exception as e: verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}") -async def update_prisma_database(token, response_cost, user_id=None): +async def update_database(token, response_cost, user_id=None): try: verbose_proxy_logger.debug( f"Enters prisma db call, token: {token}; user_id: {user_id}" @@ -451,7 +456,12 @@ async def update_prisma_database(token, response_cost, user_id=None): async def _update_user_db(): if user_id is None: return - existing_spend_obj = await prisma_client.get_data(user_id=user_id) + if prisma_client is not None: + existing_spend_obj = await prisma_client.get_data(user_id=user_id) + elif custom_db_client is not None: + existing_spend_obj = await custom_db_client.get_data( + key=user_id, table_name="user" + ) if existing_spend_obj is None: existing_spend = 0 else: @@ -462,23 +472,49 @@ async def update_prisma_database(token, response_cost, user_id=None): verbose_proxy_logger.debug(f"new cost: {new_spend}") # Update the cost column for the given user id - await prisma_client.update_data(user_id=user_id, data={"spend": new_spend}) + if prisma_client is not None: + await prisma_client.update_data( + user_id=user_id, data={"spend": new_spend} + ) + elif custom_db_client is not None: + await custom_db_client.update_data( + key=user_id, value={"spend": new_spend}, table_name="user" + ) ### UPDATE KEY SPEND ### async def _update_key_db(): - # Fetch the existing cost for the given token - existing_spend_obj = await prisma_client.get_data(token=token) - verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}") - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost + if prisma_client is not None: + # Fetch the existing cost for the given token + existing_spend_obj = await prisma_client.get_data(token=token) + verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}") + if existing_spend_obj is None: + existing_spend = 0 + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost - verbose_proxy_logger.debug(f"new cost: {new_spend}") - # Update the cost column for the given token - await prisma_client.update_data(token=token, data={"spend": new_spend}) + verbose_proxy_logger.debug(f"new cost: {new_spend}") + # Update the cost column for the given token + await prisma_client.update_data(token=token, data={"spend": new_spend}) + elif custom_db_client is not None: + # Fetch the existing cost for the given token + existing_spend_obj = await custom_db_client.get_data( + key=token, table_name="key" + ) + verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}") + if existing_spend_obj is None: + existing_spend = 0 + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost + + verbose_proxy_logger.debug(f"new cost: {new_spend}") + # Update the cost column for the given token + await custom_db_client.update_data( + key=token, value={"spend": new_spend}, table_name="key" + ) tasks = [] tasks.append(_update_user_db()) @@ -802,8 +838,6 @@ class ProxyConfig: verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!") database_url = litellm.get_secret(database_url) verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}") - ## COST TRACKING ## - cost_tracking() ### MASTER KEY ### master_key = general_settings.get( "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None) @@ -821,7 +855,11 @@ class ProxyConfig: database_type = general_settings.get("database_type", None) if database_type is not None and database_type == "dynamo_db": database_args = general_settings.get("database_args", None) - custom_db_client = DBClient(custom_db_args=database_args, custom_db_type=database_type) + custom_db_client = DBClient( + custom_db_args=database_args, custom_db_type=database_type + ) + ## COST TRACKING ## + cost_tracking() ### BACKGROUND HEALTH CHECKS ### # Enable background health checks use_background_health_checks = general_settings.get( @@ -930,9 +968,10 @@ async def generate_key_helper_fn( try: # Create a new verification token (you may want to enhance this logic based on your needs) user_data = { - "max_budget": max_budget, - "user_email": user_email, - "user_id": user_id + "max_budget": max_budget, + "user_email": user_email, + "user_id": user_id, + "spend": spend, } key_data = { "token": token, @@ -945,17 +984,15 @@ async def generate_key_helper_fn( "max_parallel_requests": max_parallel_requests, "metadata": metadata_json, } - if prisma_client is not None: - verification_token_data = dict(key_data) + if prisma_client is not None: + verification_token_data = dict(key_data) verification_token_data.update(user_data) verbose_proxy_logger.debug("PrismaClient: Before Insert Data") - await prisma_client.insert_data( - data=verification_token_data - ) - elif custom_db_client is not None: + await prisma_client.insert_data(data=verification_token_data) + elif custom_db_client is not None: ## CREATE USER (If necessary) await custom_db_client.insert_data(value=user_data, table_name="user") - ## CREATE KEY + ## CREATE KEY await custom_db_client.insert_data(value=key_data, table_name="key") except Exception as e: traceback.print_exc() @@ -1230,26 +1267,23 @@ async def startup_event(): verbose_proxy_logger.debug(f"prisma client - {prisma_client}") if prisma_client is not None: await prisma_client.connect() - - if custom_db_client is not None: - await custom_db_client.connect() + + if custom_db_client is not None: + await custom_db_client.connect() if prisma_client is not None and master_key is not None: # add master key to db - print(f'prisma_client: {prisma_client}') await generate_key_helper_fn( duration=None, models=[], aliases={}, config={}, spend=0, token=master_key ) if custom_db_client is not None and master_key is not None: # add master key to db - print(f'custom_db_client: {custom_db_client}') await generate_key_helper_fn( duration=None, models=[], aliases={}, config={}, spend=0, token=master_key ) - #### API ENDPOINTS #### @router.get( "/v1/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"] diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 4668340a2d..3e07b21486 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -255,10 +255,10 @@ class PrismaClient: print_verbose( "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'" ) - ## init logging object + ## init logging object self.proxy_logging_obj = proxy_logging_obj - if os.getenv("DATABASE_URL", None) is None: # setup hasn't taken place + if os.getenv("DATABASE_URL", None) is None: # setup hasn't taken place os.environ["DATABASE_URL"] = database_url # Save the current working directory original_dir = os.getcwd() @@ -334,7 +334,7 @@ class PrismaClient: self.proxy_logging_obj.failure_handler(original_exception=e) ) raise e - + @backoff.on_exception( backoff.expo, Exception, # base exception to catch for the backoff @@ -582,47 +582,60 @@ class PrismaClient: ) raise e + class DBClient: """ - Routes requests for CustomAuth + Routes requests for CustomAuth [TODO] route b/w customauth and prisma """ - def __init__(self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict) -> None: + + def __init__( + self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict + ) -> None: if custom_db_type == "dynamo_db": self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args)) - async def get_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]): + async def get_data(self, key: str, table_name: Literal["user", "key", "config"]): """ Check if key valid """ - return await self.db.get_data(key=key, value=value, table_name=table_name) - - async def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]): + return await self.db.get_data(key=key, table_name=table_name) + + async def insert_data( + self, value: Any, table_name: Literal["user", "key", "config"] + ): """ For new key / user logic """ return await self.db.insert_data(value=value, table_name=table_name) - async def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]): + async def update_data( + self, key: str, value: Any, table_name: Literal["user", "key", "config"] + ): """ For cost tracking logic + + key - hash_key value \n + value - dict with updated values """ return await self.db.update_data(key=key, value=value, table_name=table_name) - async def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]): + async def delete_data( + self, keys: List[str], table_name: Literal["user", "key", "config"] + ): """ For /key/delete endpoints """ return await self.db.delete_data(keys=keys, table_name=table_name) - + async def connect(self): """ - For connecting to db and creating / updating any tables - """ + For connecting to db and creating / updating any tables + """ return await self.db.connect() - async def disconnect(self): + async def disconnect(self): """ For closing connection on server shutdown """ @@ -669,7 +682,9 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: ### HELPER FUNCTIONS ### -async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]): +async def _cache_user_row( + user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient] +): """ Check if a user_id exists in cache, if not retrieve it. @@ -681,7 +696,7 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient if isinstance(db, PrismaClient): user_row = await db.get_data(user_id=user_id) elif isinstance(db, DBClient): - user_row = await db.get_data(key="user_id", value=user_id, table_name="user") + user_row = await db.get_data(key=user_id, table_name="user") if user_row is not None: print_verbose(f"User Row: {user_row}, type = {type(user_row)}") if hasattr(user_row, "model_dump_json") and callable(