forked from phoenix/litellm-mirror
Merge pull request #1538 from BerriAI/litellm_use_custom_key_gen
[Feat] Proxy Auth - Use custom_key_generate
This commit is contained in:
commit
5700d60e1a
5 changed files with 238 additions and 6 deletions
|
@ -187,6 +187,7 @@ prisma_client: Optional[PrismaClient] = None
|
|||
custom_db_client: Optional[DBClient] = None
|
||||
user_api_key_cache = DualCache()
|
||||
user_custom_auth = None
|
||||
user_custom_key_generate = None
|
||||
use_background_health_checks = None
|
||||
use_queue = False
|
||||
health_check_interval = None
|
||||
|
@ -874,7 +875,7 @@ class ProxyConfig:
|
|||
"""
|
||||
Load config values into proxy global state
|
||||
"""
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client
|
||||
|
||||
# Load existing config
|
||||
config = await self.get_config(config_file_path=config_file_path)
|
||||
|
@ -1052,6 +1053,12 @@ class ProxyConfig:
|
|||
user_custom_auth = get_instance_fn(
|
||||
value=custom_auth, config_file_path=config_file_path
|
||||
)
|
||||
|
||||
custom_key_generate = general_settings.get("custom_key_generate", None)
|
||||
if custom_key_generate is not None:
|
||||
user_custom_key_generate = get_instance_fn(
|
||||
value=custom_key_generate, config_file_path=config_file_path
|
||||
)
|
||||
## dynamodb
|
||||
database_type = general_settings.get("database_type", None)
|
||||
if database_type is not None and (
|
||||
|
@ -2164,7 +2171,16 @@ async def generate_key_fn(
|
|||
- expires: (datetime) Datetime object for when key expires.
|
||||
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
|
||||
"""
|
||||
global user_custom_key_generate
|
||||
verbose_proxy_logger.debug("entered /key/generate")
|
||||
|
||||
if user_custom_key_generate is not None:
|
||||
result = await user_custom_key_generate(data)
|
||||
decision = result.get("decision", True)
|
||||
message = result.get("message", "Authentication Failed - Custom Auth Rule")
|
||||
if not decision:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=message)
|
||||
|
||||
data_json = data.json() # type: ignore
|
||||
response = await generate_key_helper_fn(**data_json)
|
||||
return GenerateKeyResponse(
|
||||
|
@ -2948,7 +2964,7 @@ async def get_routes():
|
|||
|
||||
@router.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
global prisma_client, master_key, user_custom_auth
|
||||
global prisma_client, master_key, user_custom_auth, user_custom_key_generate
|
||||
if prisma_client:
|
||||
verbose_proxy_logger.debug("Disconnecting from Prisma")
|
||||
await prisma_client.disconnect()
|
||||
|
@ -2958,7 +2974,7 @@ async def shutdown_event():
|
|||
|
||||
|
||||
def cleanup_router_config_variables():
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, prisma_client, custom_db_client
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, prisma_client, custom_db_client
|
||||
|
||||
# Set all variables to None
|
||||
master_key = None
|
||||
|
@ -2966,6 +2982,7 @@ def cleanup_router_config_variables():
|
|||
otel_logging = None
|
||||
user_custom_auth = None
|
||||
user_custom_auth_path = None
|
||||
user_custom_key_generate = None
|
||||
use_background_health_checks = None
|
||||
health_check_interval = None
|
||||
prisma_client = None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue