Support master key rotations (#9041)

* feat(key_management_endpoints.py): adding support for rotating master key

* feat(key_management_endpoints.py): support decryption-re-encryption of models in db, when master key rotated

* fix(user_api_key_auth.py): raise valid token is None error earlier

enables easier debugging with api key hash in error message

* feat(key_management_endpoints.py): rotate any env vars

* fix(key_management_endpoints.py): uncomment check

* fix: fix linting error
This commit is contained in:
Krish Dholakia 2025-03-06 23:13:30 -08:00 committed by GitHub
parent fcc57318f8
commit da13ec2b64
8 changed files with 214 additions and 33 deletions

View file

@ -35,18 +35,23 @@ from litellm.proxy.auth.auth_checks import (
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_endpoints.model_management_endpoints import (
_add_model_to_db,
from litellm.proxy.management_endpoints.common_utils import (
_is_user_team_admin,
_set_object_metadata_field,
)
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.spend_tracking.spend_tracking_utils import _is_master_key
from litellm.proxy.utils import (
PrismaClient,
_hash_token_if_needed,
handle_exception_on_proxy,
jsonify_object,
)
from litellm.router import Router
from litellm.secret_managers.main import get_secret
from litellm.types.router import Deployment
from litellm.types.utils import (
BudgetConfig,
PersonalUIKeyGenerationConfig,
@ -1525,14 +1530,98 @@ async def delete_key_aliases(
)
async def _rotate_master_key(
prisma_client: PrismaClient,
user_api_key_dict: UserAPIKeyAuth,
current_master_key: str,
new_master_key: str,
) -> None:
"""
Rotate the master key
1. Get the values from the DB
- Get models from DB
- Get config from DB
2. Decrypt the values
- ModelTable
- [{"model_name": "str", "litellm_params": {}}]
- ConfigTable
3. Encrypt the values with the new master key
4. Update the values in the DB
"""
from litellm.proxy.proxy_server import proxy_config
try:
models: Optional[List] = (
await prisma_client.db.litellm_proxymodeltable.find_many()
)
except Exception:
models = None
# 2. process model table
if models:
decrypted_models = proxy_config.decrypt_model_list_from_db(new_models=models)
verbose_proxy_logger.info(
"ABLE TO DECRYPT MODELS - len(decrypted_models): %s", len(decrypted_models)
)
new_models = []
for model in decrypted_models:
new_model = await _add_model_to_db(
model_params=Deployment(**model),
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
new_encryption_key=new_master_key,
should_create_model_in_db=False,
)
if new_model:
new_models.append(jsonify_object(new_model.model_dump()))
verbose_proxy_logger.info("Resetting proxy model table")
await prisma_client.db.litellm_proxymodeltable.delete_many()
verbose_proxy_logger.info("Creating %s models", len(new_models))
await prisma_client.db.litellm_proxymodeltable.create_many(
data=new_models,
)
# 3. process config table
try:
config = await prisma_client.db.litellm_config.find_many()
except Exception:
config = None
if config:
"""If environment_variables is found, decrypt it and encrypt it with the new master key"""
environment_variables_dict = {}
for c in config:
if c.param_name == "environment_variables":
environment_variables_dict = c.param_value
if environment_variables_dict:
decrypted_env_vars = proxy_config._decrypt_and_set_db_env_variables(
environment_variables=environment_variables_dict
)
encrypted_env_vars = proxy_config._encrypt_env_variables(
environment_variables=decrypted_env_vars,
new_encryption_key=new_master_key,
)
if encrypted_env_vars:
await prisma_client.db.litellm_config.update(
where={"param_name": "environment_variables"},
data={"param_value": jsonify_object(encrypted_env_vars)},
)
@router.post(
"/key/{key:path}/regenerate",
tags=["key management"],
dependencies=[Depends(user_api_key_auth)],
)
@router.post(
"/key/regenerate",
tags=["key management"],
dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def regenerate_key_fn(
key: str,
key: Optional[str] = None,
data: Optional[RegenerateKeyRequest] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
litellm_changed_by: Optional[str] = Header(
@ -1590,6 +1679,7 @@ async def regenerate_key_fn(
from litellm.proxy.proxy_server import (
hash_token,
master_key,
premium_user,
prisma_client,
proxy_logging_obj,
@ -1602,7 +1692,9 @@ async def regenerate_key_fn(
)
# Check if key exists, raise exception if key is not in the DB
key = data.key if data and data.key else key
if not key:
raise HTTPException(status_code=400, detail={"error": "No key passed in."})
### 1. Create New copy that is duplicate of existing key
######################################################################
@ -1617,6 +1709,27 @@ async def regenerate_key_fn(
detail={"error": "DB not connected. prisma_client is None"},
)
_is_master_key_valid = _is_master_key(api_key=key, _master_key=master_key)
if master_key is not None and data and _is_master_key_valid:
if data.new_master_key is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "New master key is required."},
)
await _rotate_master_key(
prisma_client=prisma_client,
user_api_key_dict=user_api_key_dict,
current_master_key=master_key,
new_master_key=data.new_master_key,
)
return GenerateKeyResponse(
key=data.new_master_key,
token=data.new_master_key,
key_name=data.new_master_key,
expires=None,
)
if "sk" not in key:
hashed_api_key = key
else:
@ -1698,6 +1811,7 @@ async def regenerate_key_fn(
return response
except Exception as e:
verbose_proxy_logger.exception("Error regenerating key: %s", e)
raise handle_exception_on_proxy(e)