From da13ec2b6480c6aa223ba31f0f5e6e9c765dc427 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Thu, 6 Mar 2025 23:13:30 -0800 Subject: [PATCH] Support master key rotations (#9041) * feat(key_management_endpoints.py): adding support for rotating master key * feat(key_management_endpoints.py): support decryption-re-encryption of models in db, when master key rotated * fix(user_api_key_auth.py): raise valid token is None error earlier enables easier debugging with api key hash in error message * feat(key_management_endpoints.py): rotate any env vars * fix(key_management_endpoints.py): uncomment check * fix: fix linting error --- litellm/proxy/_types.py | 25 ++++ litellm/proxy/auth/user_api_key_auth.py | 7 ++ .../common_utils/encrypt_decrypt_utils.py | 5 +- .../key_management_endpoints.py | 118 +++++++++++++++++- .../model_management_endpoints.py | 18 ++- litellm/proxy/proxy_server.py | 72 +++++++---- litellm/router.py | 1 + .../test_key_management.py | 1 + 8 files changed, 214 insertions(+), 33 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 7ac2b17c15..1e23acff86 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -664,6 +664,7 @@ class RegenerateKeyRequest(GenerateKeyRequest): duration: Optional[str] = None spend: Optional[float] = None metadata: Optional[dict] = None + new_master_key: Optional[str] = None class KeyRequest(LiteLLMPydanticObjectBase): @@ -688,6 +689,30 @@ class LiteLLM_ModelTable(LiteLLMPydanticObjectBase): model_config = ConfigDict(protected_namespaces=()) +class LiteLLM_ProxyModelTable(LiteLLMPydanticObjectBase): + model_id: str + model_name: str + litellm_params: dict + model_info: dict + created_by: str + updated_by: str + + @model_validator(mode="before") + @classmethod + def check_potential_json_str(cls, values): + if isinstance(values.get("litellm_params"), str): + try: + values["litellm_params"] = json.loads(values["litellm_params"]) + except json.JSONDecodeError: + pass + if isinstance(values.get("model_info"), str): + try: + values["model_info"] = json.loads(values["model_info"]) + except json.JSONDecodeError: + pass + return values + + class NewUserRequest(GenerateRequestBase): max_budget: Optional[float] = None user_email: Optional[str] = None diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index ecefc64d67..7ce097e0d7 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -786,6 +786,13 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 ) valid_token = None + if valid_token is None: + raise Exception( + "Invalid proxy server token passed. Received API Key (hashed)={}. 
Unable to find token in cache or `LiteLLM_VerificationTokenTable`".format( + api_key + ) + ) + user_obj: Optional[LiteLLM_UserTable] = None valid_token_dict: dict = {} if valid_token is not None: diff --git a/litellm/proxy/common_utils/encrypt_decrypt_utils.py b/litellm/proxy/common_utils/encrypt_decrypt_utils.py index ac2caa9a01..7be60d9dc6 100644 --- a/litellm/proxy/common_utils/encrypt_decrypt_utils.py +++ b/litellm/proxy/common_utils/encrypt_decrypt_utils.py @@ -1,5 +1,6 @@ import base64 import os +from typing import Optional from litellm._logging import verbose_proxy_logger @@ -19,9 +20,9 @@ def _get_salt_key(): return salt_key -def encrypt_value_helper(value: str): +def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None): - signing_key = _get_salt_key() + signing_key = new_encryption_key or _get_salt_key() try: if isinstance(value, str): diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index bef954be83..363f6f3d67 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -35,18 +35,23 @@ from litellm.proxy.auth.auth_checks import ( ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks +from litellm.proxy.management_endpoints.model_management_endpoints import ( + _add_model_to_db, from litellm.proxy.management_endpoints.common_utils import ( _is_user_team_admin, _set_object_metadata_field, ) from litellm.proxy.management_helpers.utils import management_endpoint_wrapper +from litellm.proxy.spend_tracking.spend_tracking_utils import _is_master_key from litellm.proxy.utils import ( PrismaClient, _hash_token_if_needed, handle_exception_on_proxy, + jsonify_object, ) from litellm.router import Router from litellm.secret_managers.main import get_secret +from litellm.types.router import Deployment from litellm.types.utils import ( BudgetConfig, PersonalUIKeyGenerationConfig, @@ -1525,14 +1530,98 @@ async def delete_key_aliases( ) +async def _rotate_master_key( + prisma_client: PrismaClient, + user_api_key_dict: UserAPIKeyAuth, + current_master_key: str, + new_master_key: str, +) -> None: + """ + Rotate the master key + + 1. Get the values from the DB + - Get models from DB + - Get config from DB + 2. Decrypt the values + - ModelTable + - [{"model_name": "str", "litellm_params": {}}] + - ConfigTable + 3. Encrypt the values with the new master key + 4. Update the values in the DB + """ + from litellm.proxy.proxy_server import proxy_config + + try: + models: Optional[List] = ( + await prisma_client.db.litellm_proxymodeltable.find_many() + ) + except Exception: + models = None + # 2. 
process model table + if models: + decrypted_models = proxy_config.decrypt_model_list_from_db(new_models=models) + verbose_proxy_logger.info( + "ABLE TO DECRYPT MODELS - len(decrypted_models): %s", len(decrypted_models) + ) + new_models = [] + for model in decrypted_models: + new_model = await _add_model_to_db( + model_params=Deployment(**model), + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + new_encryption_key=new_master_key, + should_create_model_in_db=False, + ) + if new_model: + new_models.append(jsonify_object(new_model.model_dump())) + verbose_proxy_logger.info("Resetting proxy model table") + await prisma_client.db.litellm_proxymodeltable.delete_many() + verbose_proxy_logger.info("Creating %s models", len(new_models)) + await prisma_client.db.litellm_proxymodeltable.create_many( + data=new_models, + ) + # 3. process config table + try: + config = await prisma_client.db.litellm_config.find_many() + except Exception: + config = None + + if config: + """If environment_variables is found, decrypt it and encrypt it with the new master key""" + environment_variables_dict = {} + for c in config: + if c.param_name == "environment_variables": + environment_variables_dict = c.param_value + + if environment_variables_dict: + decrypted_env_vars = proxy_config._decrypt_and_set_db_env_variables( + environment_variables=environment_variables_dict + ) + encrypted_env_vars = proxy_config._encrypt_env_variables( + environment_variables=decrypted_env_vars, + new_encryption_key=new_master_key, + ) + + if encrypted_env_vars: + await prisma_client.db.litellm_config.update( + where={"param_name": "environment_variables"}, + data={"param_value": jsonify_object(encrypted_env_vars)}, + ) + + @router.post( "/key/{key:path}/regenerate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], ) +@router.post( + "/key/regenerate", + tags=["key management"], + dependencies=[Depends(user_api_key_auth)], +) @management_endpoint_wrapper async def regenerate_key_fn( - key: str, + key: Optional[str] = None, data: Optional[RegenerateKeyRequest] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), litellm_changed_by: Optional[str] = Header( @@ -1590,6 +1679,7 @@ async def regenerate_key_fn( from litellm.proxy.proxy_server import ( hash_token, + master_key, premium_user, prisma_client, proxy_logging_obj, @@ -1602,7 +1692,9 @@ async def regenerate_key_fn( ) # Check if key exists, raise exception if key is not in the DB - + key = data.key if data and data.key else key + if not key: + raise HTTPException(status_code=400, detail={"error": "No key passed in."}) ### 1. Create New copy that is duplicate of existing key ###################################################################### @@ -1617,6 +1709,27 @@ async def regenerate_key_fn( detail={"error": "DB not connected. 
prisma_client is None"}, ) + _is_master_key_valid = _is_master_key(api_key=key, _master_key=master_key) + + if master_key is not None and data and _is_master_key_valid: + if data.new_master_key is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "New master key is required."}, + ) + await _rotate_master_key( + prisma_client=prisma_client, + user_api_key_dict=user_api_key_dict, + current_master_key=master_key, + new_master_key=data.new_master_key, + ) + return GenerateKeyResponse( + key=data.new_master_key, + token=data.new_master_key, + key_name=data.new_master_key, + expires=None, + ) + if "sk" not in key: hashed_api_key = key else: @@ -1698,6 +1811,7 @@ async def regenerate_key_fn( return response except Exception as e: + verbose_proxy_logger.exception("Error regenerating key: %s", e) raise handle_exception_on_proxy(e) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index e91937c0e6..42ee2a5a40 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -21,6 +21,7 @@ from litellm._logging import verbose_proxy_logger from litellm.constants import LITELLM_PROXY_ADMIN_NAME from litellm.proxy._types import ( CommonProxyErrors, + LiteLLM_ProxyModelTable, LitellmUserRoles, PrismaCompatibleUpdateDBModel, ProxyErrorTypes, @@ -227,12 +228,16 @@ async def _add_model_to_db( model_params: Deployment, user_api_key_dict: UserAPIKeyAuth, prisma_client: PrismaClient, -): + new_encryption_key: Optional[str] = None, + should_create_model_in_db: bool = True, +) -> Optional[LiteLLM_ProxyModelTable]: # encrypt litellm params # _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True) _orignal_litellm_model_name = model_params.litellm_params.model for k, v in _litellm_params_dict.items(): - encrypted_value = encrypt_value_helper(value=v) + encrypted_value = encrypt_value_helper( + value=v, new_encryption_key=new_encryption_key + ) model_params.litellm_params[k] = encrypted_value _data: dict = { "model_id": model_params.model_info.id, @@ -246,9 +251,12 @@ async def _add_model_to_db( } if model_params.model_info.id is not None: _data["model_id"] = model_params.model_info.id - model_response = await prisma_client.db.litellm_proxymodeltable.create( - data=_data # type: ignore - ) + if should_create_model_in_db: + model_response = await prisma_client.db.litellm_proxymodeltable.create( + data=_data # type: ignore + ) + else: + model_response = LiteLLM_ProxyModelTable(**_data) return model_response diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0fda92b878..1eb5996657 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2419,6 +2419,33 @@ class ProxyConfig: added_models += 1 return added_models + def decrypt_model_list_from_db(self, new_models: list) -> list: + _model_list: list = [] + for m in new_models: + _litellm_params = m.litellm_params + if isinstance(_litellm_params, dict): + # decrypt values + for k, v in _litellm_params.items(): + decrypted_value = decrypt_value_helper(value=v) + _litellm_params[k] = decrypted_value + _litellm_params = LiteLLM_Params(**_litellm_params) + else: + verbose_proxy_logger.error( + f"Invalid model added to proxy db. Invalid litellm params. 
litellm_params={_litellm_params}" + ) + continue # skip to next model + + _model_info = self.get_model_info_with_id(model=m) + _model_list.append( + Deployment( + model_name=m.model_name, + litellm_params=_litellm_params, + model_info=_model_info, + ).to_json(exclude_none=True) + ) + + return _model_list + async def _update_llm_router( self, new_models: list, @@ -2430,29 +2457,9 @@ class ProxyConfig: if llm_router is None and master_key is not None: verbose_proxy_logger.debug(f"len new_models: {len(new_models)}") - _model_list: list = [] - for m in new_models: - _litellm_params = m.litellm_params - if isinstance(_litellm_params, dict): - # decrypt values - for k, v in _litellm_params.items(): - decrypted_value = decrypt_value_helper(value=v) - _litellm_params[k] = decrypted_value - _litellm_params = LiteLLM_Params(**_litellm_params) - else: - verbose_proxy_logger.error( - f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}" - ) - continue # skip to next model - - _model_info = self.get_model_info_with_id(model=m) - _model_list.append( - Deployment( - model_name=m.model_name, - litellm_params=_litellm_params, - model_info=_model_info, - ).to_json(exclude_none=True) - ) + _model_list: list = self.decrypt_model_list_from_db( + new_models=new_models + ) if len(_model_list) > 0: verbose_proxy_logger.debug(f"_model_list: {_model_list}") llm_router = litellm.Router( @@ -2541,7 +2548,21 @@ class ProxyConfig: environment_variables = config_data.get("environment_variables", {}) self._decrypt_and_set_db_env_variables(environment_variables) - def _decrypt_and_set_db_env_variables(self, environment_variables: dict) -> None: + def _encrypt_env_variables( + self, environment_variables: dict, new_encryption_key: Optional[str] = None + ) -> dict: + """ + Encrypts a dictionary of environment variables and returns them. + """ + encrypted_env_vars = {} + for k, v in environment_variables.items(): + encrypted_value = encrypt_value_helper( + value=v, new_encryption_key=new_encryption_key + ) + encrypted_env_vars[k] = encrypted_value + return encrypted_env_vars + + def _decrypt_and_set_db_env_variables(self, environment_variables: dict) -> dict: """ Decrypts a dictionary of environment variables and then sets them in the environment @@ -2549,15 +2570,18 @@ class ProxyConfig: environment_variables: dict - dictionary of environment variables to decrypt and set eg. 
`{"LANGFUSE_PUBLIC_KEY": "kFiKa1VZukMmD8RB6WXB9F......."}` """ + decrypted_env_vars = {} for k, v in environment_variables.items(): try: decrypted_value = decrypt_value_helper(value=v) if decrypted_value is not None: os.environ[k] = decrypted_value + decrypted_env_vars[k] = decrypted_value except Exception as e: verbose_proxy_logger.error( "Error setting env variable: %s - %s", k, str(e) ) + return decrypted_env_vars async def _add_router_settings_from_db_config( self, diff --git a/litellm/router.py b/litellm/router.py index c2382098b0..aba9e16104 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -4462,6 +4462,7 @@ class Router: """ # check if deployment already exists _deployment_model_id = deployment.model_info.id or "" + _deployment_on_router: Optional[Deployment] = self.get_deployment( model_id=_deployment_model_id ) diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index 9d6c24db0e..46fdc0b2cf 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -162,6 +162,7 @@ async def test_regenerate_api_key(prisma_client): print(result) # regenerate the key + print("regenerating key: {}".format(generated_key)) new_key = await regenerate_key_fn( key=generated_key, user_api_key_dict=UserAPIKeyAuth(
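For reference, a minimal sketch of exercising the master-key rotation added above. It assumes a proxy listening at `http://localhost:4000` and bearer-token auth with the current master key; the base URL and both key values are placeholders. The `new_master_key` body field is the one added to `RegenerateKeyRequest` in `litellm/proxy/_types.py`, and the path matches the existing `/key/{key:path}/regenerate` route.

```python
# Minimal sketch: rotate the proxy master key via the regenerate endpoint.
# PROXY_BASE_URL and both keys are placeholders for your deployment.
import requests

PROXY_BASE_URL = "http://localhost:4000"   # assumed default proxy address
CURRENT_MASTER_KEY = "sk-1234"             # master key being rotated (placeholder)
NEW_MASTER_KEY = "sk-new-master-key"       # replacement key (placeholder)

response = requests.post(
    f"{PROXY_BASE_URL}/key/{CURRENT_MASTER_KEY}/regenerate",
    headers={"Authorization": f"Bearer {CURRENT_MASTER_KEY}"},
    json={"new_master_key": NEW_MASTER_KEY},
    timeout=60,
)
response.raise_for_status()
print(response.json())  # GenerateKeyResponse echoing the new master key
```

Because the patch also registers the handler at `/key/regenerate` and falls back to `data.key` when no path key is given, the key can instead be sent as `"key"` in the JSON body. After a successful rotation the proxy presumably has to be brought up with the new master key, since the DB values are now encrypted against it.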
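The `LiteLLM_ProxyModelTable` model added to `_types.py` coerces `litellm_params` and `model_info` from JSON strings, which is what lets rows read back from the DB round-trip through `_rotate_master_key` and `_add_model_to_db(should_create_model_in_db=False)`. A small illustration; the field values are hypothetical:

```python
import json

from litellm.proxy._types import LiteLLM_ProxyModelTable  # class added in this PR

row = {
    "model_id": "model-123",  # hypothetical DB row values
    "model_name": "gpt-4o",
    "litellm_params": json.dumps({"model": "openai/gpt-4o"}),  # stored as a JSON string
    "model_info": json.dumps({"id": "model-123"}),
    "created_by": "admin",
    "updated_by": "admin",
}

parsed = LiteLLM_ProxyModelTable(**row)
assert isinstance(parsed.litellm_params, dict)  # before-validator parsed the JSON string
assert isinstance(parsed.model_info, dict)
```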
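At the value level, the rotation hinges on the new `new_encryption_key` parameter of `encrypt_value_helper`: decrypt with the current salt/master key, then re-encrypt with the key that is about to become the master key. A sketch of that round trip, assuming `decrypt_value_helper` lives in the same `encrypt_decrypt_utils` module (its import is not shown in this diff):

```python
from typing import Optional

from litellm.proxy.common_utils.encrypt_decrypt_utils import (
    decrypt_value_helper,  # assumed to live alongside encrypt_value_helper
    encrypt_value_helper,
)


def reencrypt_value(stored_value: str, new_master_key: str) -> Optional[str]:
    """Decrypt with the current salt/master key, re-encrypt with the new one."""
    plaintext = decrypt_value_helper(value=stored_value)  # uses the existing salt key
    if plaintext is None:
        return None
    return encrypt_value_helper(value=plaintext, new_encryption_key=new_master_key)
```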
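The same pattern covers the DB-stored environment variables: `_decrypt_and_set_db_env_variables` now returns the decrypted dict so `_rotate_master_key` can feed it straight into the new `_encrypt_env_variables`. A sketch of that hand-off, assuming a running proxy with `proxy_config` initialised; the ciphertext and new key are placeholders:

```python
# Sketch of the env-var re-encryption hand-off performed by _rotate_master_key.
from litellm.proxy.proxy_server import proxy_config

encrypted_from_db = {"LANGFUSE_PUBLIC_KEY": "<ciphertext under the old master key>"}

# Decrypts with the current salt/master key (and also sets the vars in os.environ).
decrypted_env_vars = proxy_config._decrypt_and_set_db_env_variables(
    environment_variables=encrypted_from_db
)
re_encrypted = proxy_config._encrypt_env_variables(
    environment_variables=decrypted_env_vars,
    new_encryption_key="sk-new-master-key",  # placeholder new master key
)
# `re_encrypted` is what _rotate_master_key writes back to
# litellm_config.environment_variables via jsonify_object().
```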