async def generate_key_fn(data: "GenerateKeyRequest") -> dict:
    """
    Decide whether a key-generation request should be allowed.

    Example custom rule: a key may only be generated when the request
    carries a non-empty ``team_id``.

    The proxy's ``/key/generate`` handler awaits this function and, when the
    returned ``decision`` is falsy, rejects the request with HTTP 403 using
    ``message`` as the error detail.

    Args:
        data (GenerateKeyRequest): The incoming key-generation request.

    Returns:
        dict: ``{"decision": True}`` when the key may be generated, otherwise
        ``{"decision": False, "message": <reason>}``.
    """
    # The request is consumed as a mapping of field name -> value.
    data_json = data.json()  # type: ignore

    # Only `team_id` drives this example rule. Other request fields can be
    # read the same way when writing your own rule, e.g.:
    #   duration, models, aliases, config, spend, user_id,
    #   max_parallel_requests, metadata, tpm_limit, rpm_limit
    team_id = data_json.get("team_id")

    # Truthiness covers both a missing (None) and an empty team id.
    if team_id:
        return {"decision": True}

    # The rejection must carry decision=False; returning True here would let
    # the proxy generate the key anyway and silently drop the message.
    return {
        "decision": False,
        "message": "This violates LiteLLM Proxy Rules. No team id provided.",
    }
None: + user_custom_key_generate = get_instance_fn( + value=custom_key_generate, config_file_path=config_file_path + ) ## dynamodb database_type = general_settings.get("database_type", None) if database_type is not None and ( @@ -2156,7 +2163,16 @@ async def generate_key_fn( - expires: (datetime) Datetime object for when key expires. - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. """ + global user_custom_key_generate verbose_proxy_logger.debug("entered /key/generate") + + if user_custom_key_generate is not None: + result = await user_custom_key_generate(data) + decision = result.get("decision", True) + message = result.get("message", "Authentication Failed - Custom Auth Rule") + if not decision: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=message) + data_json = data.json() # type: ignore response = await generate_key_helper_fn(**data_json) return GenerateKeyResponse( @@ -2924,7 +2940,7 @@ async def get_routes(): @router.on_event("shutdown") async def shutdown_event(): - global prisma_client, master_key, user_custom_auth + global prisma_client, master_key, user_custom_auth, user_custom_key_generate if prisma_client: verbose_proxy_logger.debug("Disconnecting from Prisma") await prisma_client.disconnect() @@ -2934,7 +2950,7 @@ async def shutdown_event(): def cleanup_router_config_variables(): - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, prisma_client, custom_db_client + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, prisma_client, custom_db_client # Set all variables to None master_key = None @@ -2942,6 +2958,7 @@ def cleanup_router_config_variables(): otel_logging = None user_custom_auth = None user_custom_auth_path = None + user_custom_key_generate = None 
use_background_health_checks = None health_check_interval = None prisma_client = None