diff --git a/litellm/__init__.py b/litellm/__init__.py index d67ecb718..2e23191a1 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -62,6 +62,9 @@ cache: Optional[ model_alias_map: Dict[str, str] = {} model_group_alias_map: Dict[str, str] = {} max_budget: float = 0.0 # set the max budget across all providers +budget_duration: Optional[ + str +] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). _openai_completion_params = [ "functions", "function_call", diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index d5dc841cb..8a059c507 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -306,6 +306,10 @@ class LiteLLM_VerificationToken(LiteLLMBase): user_id: Union[str, None] max_parallel_requests: Union[int, None] metadata: Dict[str, str] = {} + tpm_limit: Optional[int] = None + rpm_limit: Optional[int] = None + budget_duration: Optional[str] = None + budget_reset_at: Optional[datetime] = None class LiteLLM_Config(LiteLLMBase): diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 747d67fbb..4aa06bff7 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1154,6 +1154,7 @@ async def generate_key_helper_fn( metadata: Optional[dict] = {}, tpm_limit: Optional[int] = None, rpm_limit: Optional[int] = None, + query_type: Literal["insert_data", "update_data"] = "insert_data", ): global prisma_client, custom_db_client @@ -1196,6 +1197,12 @@ async def generate_key_helper_fn( duration_s = _duration_in_seconds(duration=key_budget_duration) key_reset_at = datetime.utcnow() + timedelta(seconds=duration_s) + if budget_duration is None: # one-time budget + reset_at = None + else: + duration_s = _duration_in_seconds(duration=budget_duration) + reset_at = datetime.utcnow() + timedelta(seconds=duration_s) + aliases_json = json.dumps(aliases) config_json = json.dumps(config) metadata_json = json.dumps(metadata) @@ -1216,6 +1223,8 @@ async def generate_key_helper_fn( "max_parallel_requests": max_parallel_requests, "tpm_limit": tpm_limit, "rpm_limit": rpm_limit, + "budget_duration": budget_duration, + "budget_reset_at": reset_at, } key_data = { "token": token, @@ -1237,13 +1246,18 @@ async def generate_key_helper_fn( if prisma_client is not None: ## CREATE USER (If necessary) verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}") - user_row = await prisma_client.insert_data( - data=user_data, table_name="user" - ) + if query_type == "insert_data": + user_row = await prisma_client.insert_data( + data=user_data, table_name="user" + ) + ## use default user model list if no key-specific model list provided + if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore + key_data["models"] = user_row.models + elif query_type == "update_data": + user_row = await prisma_client.update_data( + data=user_data, table_name="user" + ) - ## use default user model list if no key-specific model list provided - if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore - key_data["models"] = user_row.models ## CREATE KEY verbose_proxy_logger.debug(f"prisma_client: Creating Key={key_data}") await prisma_client.insert_data(data=key_data, table_name="key") @@ -1551,6 +1565,25 @@ async def startup_event(): await generate_key_helper_fn( duration=None, models=[], aliases={}, config={}, spend=0, token=master_key ) + + if ( + prisma_client is not None + and litellm.max_budget > 0 + and litellm.budget_duration is not None + ): + # add proxy budget to db in the user table + await generate_key_helper_fn( + user_id="litellm-proxy-budget", + duration=None, + models=[], + aliases={}, + config={}, + spend=0, + max_budget=litellm.max_budget, + budget_duration=litellm.budget_duration, + query_type="update_data", + ) + verbose_proxy_logger.debug( f"custom_db_client client {custom_db_client}. Master_key: {master_key}" ) diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index cbc7da399..441c3515f 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -17,6 +17,8 @@ model LiteLLM_UserTable { max_parallel_requests Int? tpm_limit BigInt? rpm_limit BigInt? + budget_duration String? + budget_reset_at DateTime? } // required for token gen diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 1e27c0d82..06b1194b1 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -425,12 +425,21 @@ class PrismaClient: status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error: invalid user key - token does not exist", ) - elif user_id is not None: - response = await self.db.litellm_usertable.find_unique( # type: ignore - where={ - "user_id": user_id, - } - ) + elif user_id is not None or ( + table_name is not None and table_name == "user" + ): + if query_type == "find_unique": + response = await self.db.litellm_usertable.find_unique( # type: ignore + where={ + "user_id": user_id, # type: ignore + } + ) + elif query_type == "find_all" and reset_at is not None: + response = await self.db.litellm_usertable.find_many( + where={ # type:ignore + "budget_reset_at": {"lt": reset_at}, + } + ) return response elif table_name == "user" and query_type == "find_all": response = await self.db.litellm_usertable.find_many( # type: ignore @@ -597,10 +606,16 @@ class PrismaClient: + "\033[0m" ) return {"token": token, "data": db_data} - elif user_id is not None: + elif ( + user_id is not None + or (table_name is not None and table_name == "user") + and query_type == "update" + ): """ If data['spend'] + data['user'], update the user table with spend info as well """ + if user_id is None: + user_id = db_data["user_id"] update_user_row = await self.db.litellm_usertable.update( where={"user_id": user_id}, # type: ignore data={**db_data}, # type: ignore @@ -650,6 +665,30 @@ class PrismaClient: print_verbose( "\033[91m" + f"DB Token Table update succeeded" + "\033[0m" ) + elif ( + table_name is not None + and table_name == "user" + and query_type == "update_many" + and data_list is not None + and isinstance(data_list, list) + ): + """ + Batch write update queries + """ + batcher = self.db.batch_() + for idx, user in enumerate(data_list): + try: + data_json = self.jsonify_object(data=user.model_dump()) + except: + data_json = self.jsonify_object(data=user.dict()) + batcher.litellm_usertable.update( + where={"user_id": user.user_id}, # type: ignore + data={**data_json}, # type: ignore + ) + await batcher.commit() + print_verbose( + "\033[91m" + f"DB User Table update succeeded" + "\033[0m" + ) except Exception as e: asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) @@ -1007,17 +1046,36 @@ async def reset_budget(prisma_client: PrismaClient): Updates db """ if prisma_client is not None: + ### RESET KEY BUDGET ### now = datetime.utcnow() keys_to_reset = await prisma_client.get_data( table_name="key", query_type="find_all", expires=now, reset_at=now ) - for key in keys_to_reset: - key.spend = 0.0 - duration_s = _duration_in_seconds(duration=key.budget_duration) - key.budget_reset_at = key.budget_reset_at + timedelta(seconds=duration_s) + if keys_to_reset is not None and len(keys_to_reset) > 0: + for key in keys_to_reset: + key.spend = 0.0 + duration_s = _duration_in_seconds(duration=key.budget_duration) + key.budget_reset_at = now + timedelta(seconds=duration_s) - if len(keys_to_reset) > 0: await prisma_client.update_data( query_type="update_many", data_list=keys_to_reset, table_name="key" ) + + ### RESET USER BUDGET ### + now = datetime.utcnow() + users_to_reset = await prisma_client.get_data( + table_name="user", query_type="find_all", reset_at=now + ) + + verbose_proxy_logger.debug(f"users_to_reset from get_data: {users_to_reset}") + + if users_to_reset is not None and len(users_to_reset) > 0: + for user in users_to_reset: + user.spend = 0.0 + duration_s = _duration_in_seconds(duration=user.budget_duration) + user.budget_reset_at = now + timedelta(seconds=duration_s) + + await prisma_client.update_data( + query_type="update_many", data_list=users_to_reset, table_name="user" + ) diff --git a/schema.prisma b/schema.prisma index cbc7da399..441c3515f 100644 --- a/schema.prisma +++ b/schema.prisma @@ -17,6 +17,8 @@ model LiteLLM_UserTable { max_parallel_requests Int? tpm_limit BigInt? rpm_limit BigInt? + budget_duration String? + budget_reset_at DateTime? } // required for token gen