Support master key rotations (#9041)

* feat(key_management_endpoints.py): adding support for rotating master key

* feat(key_management_endpoints.py): support decryption-re-encryption of models in db, when master key rotated

* fix(user_api_key_auth.py): raise valid token is None error earlier

enables easier debugging with api key hash in error message

* feat(key_management_endpoints.py): rotate any env vars

* fix(key_management_endpoints.py): uncomment check

* fix: fix linting error
This commit is contained in:
Krish Dholakia 2025-03-06 23:13:30 -08:00 committed by GitHub
parent fcc57318f8
commit da13ec2b64
8 changed files with 214 additions and 33 deletions

View file

@ -664,6 +664,7 @@ class RegenerateKeyRequest(GenerateKeyRequest):
duration: Optional[str] = None
spend: Optional[float] = None
metadata: Optional[dict] = None
new_master_key: Optional[str] = None
class KeyRequest(LiteLLMPydanticObjectBase):
@ -688,6 +689,30 @@ class LiteLLM_ModelTable(LiteLLMPydanticObjectBase):
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_ProxyModelTable(LiteLLMPydanticObjectBase):
model_id: str
model_name: str
litellm_params: dict
model_info: dict
created_by: str
updated_by: str
@model_validator(mode="before")
@classmethod
def check_potential_json_str(cls, values):
if isinstance(values.get("litellm_params"), str):
try:
values["litellm_params"] = json.loads(values["litellm_params"])
except json.JSONDecodeError:
pass
if isinstance(values.get("model_info"), str):
try:
values["model_info"] = json.loads(values["model_info"])
except json.JSONDecodeError:
pass
return values
class NewUserRequest(GenerateRequestBase):
max_budget: Optional[float] = None
user_email: Optional[str] = None

View file

@ -786,6 +786,13 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
)
valid_token = None
if valid_token is None:
raise Exception(
"Invalid proxy server token passed. Received API Key (hashed)={}. Unable to find token in cache or `LiteLLM_VerificationTokenTable`".format(
api_key
)
)
user_obj: Optional[LiteLLM_UserTable] = None
valid_token_dict: dict = {}
if valid_token is not None:

View file

@ -1,5 +1,6 @@
import base64
import os
from typing import Optional
from litellm._logging import verbose_proxy_logger
@ -19,9 +20,9 @@ def _get_salt_key():
return salt_key
def encrypt_value_helper(value: str):
def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None):
signing_key = _get_salt_key()
signing_key = new_encryption_key or _get_salt_key()
try:
if isinstance(value, str):

View file

@ -35,18 +35,23 @@ from litellm.proxy.auth.auth_checks import (
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_endpoints.model_management_endpoints import (
_add_model_to_db,
from litellm.proxy.management_endpoints.common_utils import (
_is_user_team_admin,
_set_object_metadata_field,
)
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.spend_tracking.spend_tracking_utils import _is_master_key
from litellm.proxy.utils import (
PrismaClient,
_hash_token_if_needed,
handle_exception_on_proxy,
jsonify_object,
)
from litellm.router import Router
from litellm.secret_managers.main import get_secret
from litellm.types.router import Deployment
from litellm.types.utils import (
BudgetConfig,
PersonalUIKeyGenerationConfig,
@ -1525,14 +1530,98 @@ async def delete_key_aliases(
)
async def _rotate_master_key(
prisma_client: PrismaClient,
user_api_key_dict: UserAPIKeyAuth,
current_master_key: str,
new_master_key: str,
) -> None:
"""
Rotate the master key
1. Get the values from the DB
- Get models from DB
- Get config from DB
2. Decrypt the values
- ModelTable
- [{"model_name": "str", "litellm_params": {}}]
- ConfigTable
3. Encrypt the values with the new master key
4. Update the values in the DB
"""
from litellm.proxy.proxy_server import proxy_config
try:
models: Optional[List] = (
await prisma_client.db.litellm_proxymodeltable.find_many()
)
except Exception:
models = None
# 2. process model table
if models:
decrypted_models = proxy_config.decrypt_model_list_from_db(new_models=models)
verbose_proxy_logger.info(
"ABLE TO DECRYPT MODELS - len(decrypted_models): %s", len(decrypted_models)
)
new_models = []
for model in decrypted_models:
new_model = await _add_model_to_db(
model_params=Deployment(**model),
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
new_encryption_key=new_master_key,
should_create_model_in_db=False,
)
if new_model:
new_models.append(jsonify_object(new_model.model_dump()))
verbose_proxy_logger.info("Resetting proxy model table")
await prisma_client.db.litellm_proxymodeltable.delete_many()
verbose_proxy_logger.info("Creating %s models", len(new_models))
await prisma_client.db.litellm_proxymodeltable.create_many(
data=new_models,
)
# 3. process config table
try:
config = await prisma_client.db.litellm_config.find_many()
except Exception:
config = None
if config:
"""If environment_variables is found, decrypt it and encrypt it with the new master key"""
environment_variables_dict = {}
for c in config:
if c.param_name == "environment_variables":
environment_variables_dict = c.param_value
if environment_variables_dict:
decrypted_env_vars = proxy_config._decrypt_and_set_db_env_variables(
environment_variables=environment_variables_dict
)
encrypted_env_vars = proxy_config._encrypt_env_variables(
environment_variables=decrypted_env_vars,
new_encryption_key=new_master_key,
)
if encrypted_env_vars:
await prisma_client.db.litellm_config.update(
where={"param_name": "environment_variables"},
data={"param_value": jsonify_object(encrypted_env_vars)},
)
@router.post(
"/key/{key:path}/regenerate",
tags=["key management"],
dependencies=[Depends(user_api_key_auth)],
)
@router.post(
"/key/regenerate",
tags=["key management"],
dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def regenerate_key_fn(
key: str,
key: Optional[str] = None,
data: Optional[RegenerateKeyRequest] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
litellm_changed_by: Optional[str] = Header(
@ -1590,6 +1679,7 @@ async def regenerate_key_fn(
from litellm.proxy.proxy_server import (
hash_token,
master_key,
premium_user,
prisma_client,
proxy_logging_obj,
@ -1602,7 +1692,9 @@ async def regenerate_key_fn(
)
# Check if key exists, raise exception if key is not in the DB
key = data.key if data and data.key else key
if not key:
raise HTTPException(status_code=400, detail={"error": "No key passed in."})
### 1. Create New copy that is duplicate of existing key
######################################################################
@ -1617,6 +1709,27 @@ async def regenerate_key_fn(
detail={"error": "DB not connected. prisma_client is None"},
)
_is_master_key_valid = _is_master_key(api_key=key, _master_key=master_key)
if master_key is not None and data and _is_master_key_valid:
if data.new_master_key is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "New master key is required."},
)
await _rotate_master_key(
prisma_client=prisma_client,
user_api_key_dict=user_api_key_dict,
current_master_key=master_key,
new_master_key=data.new_master_key,
)
return GenerateKeyResponse(
key=data.new_master_key,
token=data.new_master_key,
key_name=data.new_master_key,
expires=None,
)
if "sk" not in key:
hashed_api_key = key
else:
@ -1698,6 +1811,7 @@ async def regenerate_key_fn(
return response
except Exception as e:
verbose_proxy_logger.exception("Error regenerating key: %s", e)
raise handle_exception_on_proxy(e)

View file

@ -21,6 +21,7 @@ from litellm._logging import verbose_proxy_logger
from litellm.constants import LITELLM_PROXY_ADMIN_NAME
from litellm.proxy._types import (
CommonProxyErrors,
LiteLLM_ProxyModelTable,
LitellmUserRoles,
PrismaCompatibleUpdateDBModel,
ProxyErrorTypes,
@ -227,12 +228,16 @@ async def _add_model_to_db(
model_params: Deployment,
user_api_key_dict: UserAPIKeyAuth,
prisma_client: PrismaClient,
):
new_encryption_key: Optional[str] = None,
should_create_model_in_db: bool = True,
) -> Optional[LiteLLM_ProxyModelTable]:
# encrypt litellm params #
_litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
_orignal_litellm_model_name = model_params.litellm_params.model
for k, v in _litellm_params_dict.items():
encrypted_value = encrypt_value_helper(value=v)
encrypted_value = encrypt_value_helper(
value=v, new_encryption_key=new_encryption_key
)
model_params.litellm_params[k] = encrypted_value
_data: dict = {
"model_id": model_params.model_info.id,
@ -246,9 +251,12 @@ async def _add_model_to_db(
}
if model_params.model_info.id is not None:
_data["model_id"] = model_params.model_info.id
if should_create_model_in_db:
model_response = await prisma_client.db.litellm_proxymodeltable.create(
data=_data # type: ignore
)
else:
model_response = LiteLLM_ProxyModelTable(**_data)
return model_response

View file

@ -2419,17 +2419,7 @@ class ProxyConfig:
added_models += 1
return added_models
async def _update_llm_router(
self,
new_models: list,
proxy_logging_obj: ProxyLogging,
):
global llm_router, llm_model_list, master_key, general_settings
try:
if llm_router is None and master_key is not None:
verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
def decrypt_model_list_from_db(self, new_models: list) -> list:
_model_list: list = []
for m in new_models:
_litellm_params = m.litellm_params
@ -2453,6 +2443,23 @@ class ProxyConfig:
model_info=_model_info,
).to_json(exclude_none=True)
)
return _model_list
async def _update_llm_router(
self,
new_models: list,
proxy_logging_obj: ProxyLogging,
):
global llm_router, llm_model_list, master_key, general_settings
try:
if llm_router is None and master_key is not None:
verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
_model_list: list = self.decrypt_model_list_from_db(
new_models=new_models
)
if len(_model_list) > 0:
verbose_proxy_logger.debug(f"_model_list: {_model_list}")
llm_router = litellm.Router(
@ -2541,7 +2548,21 @@ class ProxyConfig:
environment_variables = config_data.get("environment_variables", {})
self._decrypt_and_set_db_env_variables(environment_variables)
def _decrypt_and_set_db_env_variables(self, environment_variables: dict) -> None:
def _encrypt_env_variables(
self, environment_variables: dict, new_encryption_key: Optional[str] = None
) -> dict:
"""
Encrypts a dictionary of environment variables and returns them.
"""
encrypted_env_vars = {}
for k, v in environment_variables.items():
encrypted_value = encrypt_value_helper(
value=v, new_encryption_key=new_encryption_key
)
encrypted_env_vars[k] = encrypted_value
return encrypted_env_vars
def _decrypt_and_set_db_env_variables(self, environment_variables: dict) -> dict:
"""
Decrypts a dictionary of environment variables and then sets them in the environment
@ -2549,15 +2570,18 @@ class ProxyConfig:
environment_variables: dict - dictionary of environment variables to decrypt and set
eg. `{"LANGFUSE_PUBLIC_KEY": "kFiKa1VZukMmD8RB6WXB9F......."}`
"""
decrypted_env_vars = {}
for k, v in environment_variables.items():
try:
decrypted_value = decrypt_value_helper(value=v)
if decrypted_value is not None:
os.environ[k] = decrypted_value
decrypted_env_vars[k] = decrypted_value
except Exception as e:
verbose_proxy_logger.error(
"Error setting env variable: %s - %s", k, str(e)
)
return decrypted_env_vars
async def _add_router_settings_from_db_config(
self,

View file

@ -4462,6 +4462,7 @@ class Router:
"""
# check if deployment already exists
_deployment_model_id = deployment.model_info.id or ""
_deployment_on_router: Optional[Deployment] = self.get_deployment(
model_id=_deployment_model_id
)

View file

@ -162,6 +162,7 @@ async def test_regenerate_api_key(prisma_client):
print(result)
# regenerate the key
print("regenerating key: {}".format(generated_key))
new_key = await regenerate_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(