forked from phoenix/litellm-mirror
v0 using custom_key_generate
This commit is contained in:
parent
ab60f071ff
commit
13eb40e7bd
3 changed files with 61 additions and 6 deletions
|
@ -1,4 +1,4 @@
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import os
|
import os
|
||||||
|
@ -14,3 +14,40 @@ async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||||
raise Exception
|
raise Exception
|
||||||
except:
|
except:
|
||||||
raise Exception
|
raise Exception
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_key_fn(data: GenerateKeyRequest):
|
||||||
|
"""
|
||||||
|
Asynchronously decides if a key should be generated or not based on the provided data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (GenerateKeyRequest): The data to be used for decision making.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if a key should be generated, False otherwise.
|
||||||
|
"""
|
||||||
|
# decide if a key should be generated or not
|
||||||
|
data_json = data.json() # type: ignore
|
||||||
|
|
||||||
|
# Unpacking variables
|
||||||
|
team_id = data_json.get("team_id")
|
||||||
|
duration = data_json.get("duration")
|
||||||
|
models = data_json.get("models")
|
||||||
|
aliases = data_json.get("aliases")
|
||||||
|
config = data_json.get("config")
|
||||||
|
spend = data_json.get("spend")
|
||||||
|
user_id = data_json.get("user_id")
|
||||||
|
max_parallel_requests = data_json.get("max_parallel_requests")
|
||||||
|
metadata = data_json.get("metadata")
|
||||||
|
tpm_limit = data_json.get("tpm_limit")
|
||||||
|
rpm_limit = data_json.get("rpm_limit")
|
||||||
|
|
||||||
|
if team_id is not None and len(team_id) > 0:
|
||||||
|
return {
|
||||||
|
"decision": True,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"decision": True,
|
||||||
|
"message": "This violates LiteLLM Proxy Rules. No team id provided.",
|
||||||
|
}
|
||||||
|
|
|
@ -62,8 +62,9 @@ litellm_settings:
|
||||||
# setting callback class
|
# setting callback class
|
||||||
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
|
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
|
||||||
|
|
||||||
# general_settings:
|
general_settings:
|
||||||
# master_key: sk-1234
|
master_key: sk-1234
|
||||||
|
custom_key_generate: custom_auth.generate_key_fn
|
||||||
# database_type: "dynamo_db"
|
# database_type: "dynamo_db"
|
||||||
# database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
|
# database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
|
||||||
# "billing_mode": "PAY_PER_REQUEST",
|
# "billing_mode": "PAY_PER_REQUEST",
|
||||||
|
|
|
@ -187,6 +187,7 @@ prisma_client: Optional[PrismaClient] = None
|
||||||
custom_db_client: Optional[DBClient] = None
|
custom_db_client: Optional[DBClient] = None
|
||||||
user_api_key_cache = DualCache()
|
user_api_key_cache = DualCache()
|
||||||
user_custom_auth = None
|
user_custom_auth = None
|
||||||
|
user_custom_key_generate = None
|
||||||
use_background_health_checks = None
|
use_background_health_checks = None
|
||||||
use_queue = False
|
use_queue = False
|
||||||
health_check_interval = None
|
health_check_interval = None
|
||||||
|
@ -874,7 +875,7 @@ class ProxyConfig:
|
||||||
"""
|
"""
|
||||||
Load config values into proxy global state
|
Load config values into proxy global state
|
||||||
"""
|
"""
|
||||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client
|
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client
|
||||||
|
|
||||||
# Load existing config
|
# Load existing config
|
||||||
config = await self.get_config(config_file_path=config_file_path)
|
config = await self.get_config(config_file_path=config_file_path)
|
||||||
|
@ -1052,6 +1053,12 @@ class ProxyConfig:
|
||||||
user_custom_auth = get_instance_fn(
|
user_custom_auth = get_instance_fn(
|
||||||
value=custom_auth, config_file_path=config_file_path
|
value=custom_auth, config_file_path=config_file_path
|
||||||
)
|
)
|
||||||
|
|
||||||
|
custom_key_generate = general_settings.get("custom_key_generate", None)
|
||||||
|
if custom_key_generate is not None:
|
||||||
|
user_custom_key_generate = get_instance_fn(
|
||||||
|
value=custom_key_generate, config_file_path=config_file_path
|
||||||
|
)
|
||||||
## dynamodb
|
## dynamodb
|
||||||
database_type = general_settings.get("database_type", None)
|
database_type = general_settings.get("database_type", None)
|
||||||
if database_type is not None and (
|
if database_type is not None and (
|
||||||
|
@ -2156,7 +2163,16 @@ async def generate_key_fn(
|
||||||
- expires: (datetime) Datetime object for when key expires.
|
- expires: (datetime) Datetime object for when key expires.
|
||||||
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
|
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
|
||||||
"""
|
"""
|
||||||
|
global user_custom_key_generate
|
||||||
verbose_proxy_logger.debug("entered /key/generate")
|
verbose_proxy_logger.debug("entered /key/generate")
|
||||||
|
|
||||||
|
if user_custom_key_generate is not None:
|
||||||
|
result = await user_custom_key_generate(data)
|
||||||
|
decision = result.get("decision", True)
|
||||||
|
message = result.get("message", "Authentication Failed - Custom Auth Rule")
|
||||||
|
if not decision:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=message)
|
||||||
|
|
||||||
data_json = data.json() # type: ignore
|
data_json = data.json() # type: ignore
|
||||||
response = await generate_key_helper_fn(**data_json)
|
response = await generate_key_helper_fn(**data_json)
|
||||||
return GenerateKeyResponse(
|
return GenerateKeyResponse(
|
||||||
|
@ -2924,7 +2940,7 @@ async def get_routes():
|
||||||
|
|
||||||
@router.on_event("shutdown")
|
@router.on_event("shutdown")
|
||||||
async def shutdown_event():
|
async def shutdown_event():
|
||||||
global prisma_client, master_key, user_custom_auth
|
global prisma_client, master_key, user_custom_auth, user_custom_key_generate
|
||||||
if prisma_client:
|
if prisma_client:
|
||||||
verbose_proxy_logger.debug("Disconnecting from Prisma")
|
verbose_proxy_logger.debug("Disconnecting from Prisma")
|
||||||
await prisma_client.disconnect()
|
await prisma_client.disconnect()
|
||||||
|
@ -2934,7 +2950,7 @@ async def shutdown_event():
|
||||||
|
|
||||||
|
|
||||||
def cleanup_router_config_variables():
|
def cleanup_router_config_variables():
|
||||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, prisma_client, custom_db_client
|
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, prisma_client, custom_db_client
|
||||||
|
|
||||||
# Set all variables to None
|
# Set all variables to None
|
||||||
master_key = None
|
master_key = None
|
||||||
|
@ -2942,6 +2958,7 @@ def cleanup_router_config_variables():
|
||||||
otel_logging = None
|
otel_logging = None
|
||||||
user_custom_auth = None
|
user_custom_auth = None
|
||||||
user_custom_auth_path = None
|
user_custom_auth_path = None
|
||||||
|
user_custom_key_generate = None
|
||||||
use_background_health_checks = None
|
use_background_health_checks = None
|
||||||
health_check_interval = None
|
health_check_interval = None
|
||||||
prisma_client = None
|
prisma_client = None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue