Merge branch 'main' into litellm_budget_per_key

This commit is contained in:
Ishaan Jaff 2024-01-22 15:49:57 -08:00 committed by GitHub
commit db68774d60
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 731 additions and 183 deletions

View file

@ -187,6 +187,7 @@ prisma_client: Optional[PrismaClient] = None
custom_db_client: Optional[DBClient] = None
user_api_key_cache = DualCache()
user_custom_auth = None
user_custom_key_generate = None
use_background_health_checks = None
use_queue = False
health_check_interval = None
@ -584,7 +585,7 @@ async def track_cost_callback(
"user_api_key_user_id", None
)
verbose_proxy_logger.debug(
verbose_proxy_logger.info(
f"streaming response_cost {response_cost}, for user_id {user_id}"
)
if user_api_key and (
@ -609,7 +610,7 @@ async def track_cost_callback(
user_id = user_id or kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None
)
verbose_proxy_logger.debug(
verbose_proxy_logger.info(
f"response_cost {response_cost}, for user_id {user_id}"
)
if user_api_key and (
@ -896,7 +897,7 @@ class ProxyConfig:
"""
Load config values into proxy global state
"""
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client
# Load existing config
config = await self.get_config(config_file_path=config_file_path)
@ -1074,6 +1075,12 @@ class ProxyConfig:
user_custom_auth = get_instance_fn(
value=custom_auth, config_file_path=config_file_path
)
custom_key_generate = general_settings.get("custom_key_generate", None)
if custom_key_generate is not None:
user_custom_key_generate = get_instance_fn(
value=custom_key_generate, config_file_path=config_file_path
)
## dynamodb
database_type = general_settings.get("database_type", None)
if database_type is not None and (
@ -2189,7 +2196,16 @@ async def generate_key_fn(
- expires: (datetime) Datetime object for when key expires.
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
"""
global user_custom_key_generate
verbose_proxy_logger.debug("entered /key/generate")
if user_custom_key_generate is not None:
result = await user_custom_key_generate(data)
decision = result.get("decision", True)
message = result.get("message", "Authentication Failed - Custom Auth Rule")
if not decision:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=message)
data_json = data.json() # type: ignore
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
@ -2978,7 +2994,7 @@ async def get_routes():
@router.on_event("shutdown")
async def shutdown_event():
global prisma_client, master_key, user_custom_auth
global prisma_client, master_key, user_custom_auth, user_custom_key_generate
if prisma_client:
verbose_proxy_logger.debug("Disconnecting from Prisma")
await prisma_client.disconnect()
@ -2988,7 +3004,7 @@ async def shutdown_event():
def cleanup_router_config_variables():
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, prisma_client, custom_db_client
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, prisma_client, custom_db_client
# Set all variables to None
master_key = None
@ -2996,6 +3012,7 @@ def cleanup_router_config_variables():
otel_logging = None
user_custom_auth = None
user_custom_auth_path = None
user_custom_key_generate = None
use_background_health_checks = None
health_check_interval = None
prisma_client = None