Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
Support master key rotations (#9041)
* feat(key_management_endpoints.py): add support for rotating the master key
* feat(key_management_endpoints.py): support decrypting and re-encrypting models in the DB when the master key is rotated
* fix(user_api_key_auth.py): raise the "valid token is None" error earlier; this surfaces the API key hash in the error message for easier debugging
* feat(key_management_endpoints.py): rotate any env vars
* fix(key_management_endpoints.py): uncomment check
* fix: fix linting error
This commit is contained in:
parent fcc57318f8
commit da13ec2b64

8 changed files with 214 additions and 33 deletions
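Before the per-file hunks, a minimal sketch of how an operator might invoke the rotation this commit adds. The route and the new_master_key field come from the diff below; the base URL and both key values are placeholders, not part of the commit.

# Minimal sketch; base URL and keys are placeholders.
import requests

CURRENT_MASTER_KEY = "sk-1234"        # key currently configured on the proxy
NEW_MASTER_KEY = "sk-new-master-key"  # key everything should be re-encrypted under

resp = requests.post(
    f"http://localhost:4000/key/{CURRENT_MASTER_KEY}/regenerate",
    headers={"Authorization": f"Bearer {CURRENT_MASTER_KEY}"},
    json={"new_master_key": NEW_MASTER_KEY},
)
print(resp.json())  # on success, key/token/key_name echo the new master key

# After rotation, the proxy's configured master key (e.g. the LITELLM_MASTER_KEY
# environment variable) presumably still has to be updated to the new value.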
@@ -664,6 +664,7 @@ class RegenerateKeyRequest(GenerateKeyRequest):
    duration: Optional[str] = None
    spend: Optional[float] = None
    metadata: Optional[dict] = None
    new_master_key: Optional[str] = None


class KeyRequest(LiteLLMPydanticObjectBase):
@@ -688,6 +689,30 @@ class LiteLLM_ModelTable(LiteLLMPydanticObjectBase):
    model_config = ConfigDict(protected_namespaces=())


class LiteLLM_ProxyModelTable(LiteLLMPydanticObjectBase):
    model_id: str
    model_name: str
    litellm_params: dict
    model_info: dict
    created_by: str
    updated_by: str

    @model_validator(mode="before")
    @classmethod
    def check_potential_json_str(cls, values):
        if isinstance(values.get("litellm_params"), str):
            try:
                values["litellm_params"] = json.loads(values["litellm_params"])
            except json.JSONDecodeError:
                pass
        if isinstance(values.get("model_info"), str):
            try:
                values["model_info"] = json.loads(values["model_info"])
            except json.JSONDecodeError:
                pass
        return values


class NewUserRequest(GenerateRequestBase):
    max_budget: Optional[float] = None
    user_email: Optional[str] = None
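A minimal sketch of why the check_potential_json_str validator matters: rows read back from the DB may carry litellm_params and model_info as JSON strings, and the "before" validator lets them be passed straight into the Pydantic model. All field values below are illustrative.

# Illustrative only; assumes LiteLLM_ProxyModelTable (defined above) is importable.
row = {
    "model_id": "mid-123",
    "model_name": "gpt-4o",
    "litellm_params": '{"model": "openai/gpt-4o", "api_key": "<encrypted>"}',  # JSON string from DB
    "model_info": '{"id": "mid-123"}',
    "created_by": "admin",
    "updated_by": "admin",
}

table_row = LiteLLM_ProxyModelTable(**row)
assert isinstance(table_row.litellm_params, dict)  # validator parsed the JSON string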
@@ -786,6 +786,13 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
        )
        valid_token = None

    if valid_token is None:
        raise Exception(
            "Invalid proxy server token passed. Received API Key (hashed)={}. Unable to find token in cache or `LiteLLM_VerificationTokenTable`".format(
                api_key
            )
        )

    user_obj: Optional[LiteLLM_UserTable] = None
    valid_token_dict: dict = {}
    if valid_token is not None:
@@ -1,5 +1,6 @@
import base64
import os
from typing import Optional

from litellm._logging import verbose_proxy_logger

@@ -19,9 +20,9 @@ def _get_salt_key():
    return salt_key


def encrypt_value_helper(value: str):
def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None):

    signing_key = _get_salt_key()
    signing_key = new_encryption_key or _get_salt_key()

    try:
        if isinstance(value, str):
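The new new_encryption_key parameter lets the rotation path encrypt values under the incoming master key while the running proxy still decrypts under the current one. As a generic illustration of that override pattern only, here is a sketch using cryptography's Fernet as a stand-in symmetric cipher; it is not litellm's actual encryption scheme.

# Generic sketch of "encrypt under an explicit new key, default to the current key".
from typing import Optional
from cryptography.fernet import Fernet

CURRENT_KEY = Fernet.generate_key()

def encrypt_value(value: str, new_encryption_key: Optional[bytes] = None) -> bytes:
    key = new_encryption_key or CURRENT_KEY  # same override shape as encrypt_value_helper
    return Fernet(key).encrypt(value.encode())

def decrypt_value(token: bytes, key: bytes = CURRENT_KEY) -> str:
    return Fernet(key).decrypt(token).decode()

# Rotation round trip: decrypt with the old key, re-encrypt with the new key.
NEW_KEY = Fernet.generate_key()
old_ciphertext = encrypt_value("my-openai-api-key")
rotated = encrypt_value(decrypt_value(old_ciphertext), new_encryption_key=NEW_KEY)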
@@ -35,18 +35,23 @@ from litellm.proxy.auth.auth_checks import (
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_endpoints.model_management_endpoints import (
    _add_model_to_db,
from litellm.proxy.management_endpoints.common_utils import (
    _is_user_team_admin,
    _set_object_metadata_field,
)
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.spend_tracking.spend_tracking_utils import _is_master_key
from litellm.proxy.utils import (
    PrismaClient,
    _hash_token_if_needed,
    handle_exception_on_proxy,
    jsonify_object,
)
from litellm.router import Router
from litellm.secret_managers.main import get_secret
from litellm.types.router import Deployment
from litellm.types.utils import (
    BudgetConfig,
    PersonalUIKeyGenerationConfig,
@@ -1525,14 +1530,98 @@ async def delete_key_aliases(
    )


async def _rotate_master_key(
    prisma_client: PrismaClient,
    user_api_key_dict: UserAPIKeyAuth,
    current_master_key: str,
    new_master_key: str,
) -> None:
    """
    Rotate the master key

    1. Get the values from the DB
        - Get models from DB
        - Get config from DB
    2. Decrypt the values
        - ModelTable
            - [{"model_name": "str", "litellm_params": {}}]
        - ConfigTable
    3. Encrypt the values with the new master key
    4. Update the values in the DB
    """
    from litellm.proxy.proxy_server import proxy_config

    try:
        models: Optional[List] = (
            await prisma_client.db.litellm_proxymodeltable.find_many()
        )
    except Exception:
        models = None
    # 2. process model table
    if models:
        decrypted_models = proxy_config.decrypt_model_list_from_db(new_models=models)
        verbose_proxy_logger.info(
            "ABLE TO DECRYPT MODELS - len(decrypted_models): %s", len(decrypted_models)
        )
        new_models = []
        for model in decrypted_models:
            new_model = await _add_model_to_db(
                model_params=Deployment(**model),
                user_api_key_dict=user_api_key_dict,
                prisma_client=prisma_client,
                new_encryption_key=new_master_key,
                should_create_model_in_db=False,
            )
            if new_model:
                new_models.append(jsonify_object(new_model.model_dump()))
        verbose_proxy_logger.info("Resetting proxy model table")
        await prisma_client.db.litellm_proxymodeltable.delete_many()
        verbose_proxy_logger.info("Creating %s models", len(new_models))
        await prisma_client.db.litellm_proxymodeltable.create_many(
            data=new_models,
        )
    # 3. process config table
    try:
        config = await prisma_client.db.litellm_config.find_many()
    except Exception:
        config = None

    if config:
        """If environment_variables is found, decrypt it and encrypt it with the new master key"""
        environment_variables_dict = {}
        for c in config:
            if c.param_name == "environment_variables":
                environment_variables_dict = c.param_value

        if environment_variables_dict:
            decrypted_env_vars = proxy_config._decrypt_and_set_db_env_variables(
                environment_variables=environment_variables_dict
            )
            encrypted_env_vars = proxy_config._encrypt_env_variables(
                environment_variables=decrypted_env_vars,
                new_encryption_key=new_master_key,
            )

            if encrypted_env_vars:
                await prisma_client.db.litellm_config.update(
                    where={"param_name": "environment_variables"},
                    data={"param_value": jsonify_object(encrypted_env_vars)},
                )


@router.post(
    "/key/{key:path}/regenerate",
    tags=["key management"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/key/regenerate",
    tags=["key management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def regenerate_key_fn(
    key: str,
    key: Optional[str] = None,
    data: Optional[RegenerateKeyRequest] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
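The function above re-keys two stores: every row of the proxy model table and the environment_variables entry of the config table, using a wholesale delete-and-recreate for the model table. As a minimal standalone sketch of the same re-key pattern, under the assumption of generic decrypt/encrypt callables (these are placeholders, not the litellm helpers):

# Hypothetical sketch of the re-key pattern: decrypt every stored secret with the
# old key, re-encrypt with the new one, then replace the stored rows wholesale.
from typing import Callable, Dict, List

def rekey_rows(
    rows: List[Dict[str, str]],
    decrypt_with: Callable[[str], str],
    encrypt_with: Callable[[str], str],
) -> List[Dict[str, str]]:
    rekeyed = []
    for row in rows:
        # decrypt each field under the old key, then encrypt under the new key
        rekeyed.append({k: encrypt_with(decrypt_with(v)) for k, v in row.items()})
    return rekeyed

Replacing the whole table only after every row has been rebuilt in memory keeps the DB from holding a mix of old-key and new-key ciphertexts mid-rotation.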
@@ -1590,6 +1679,7 @@ async def regenerate_key_fn(

    from litellm.proxy.proxy_server import (
        hash_token,
        master_key,
        premium_user,
        prisma_client,
        proxy_logging_obj,
@@ -1602,7 +1692,9 @@ async def regenerate_key_fn(
    )

    # Check if key exists, raise exception if key is not in the DB

    key = data.key if data and data.key else key
    if not key:
        raise HTTPException(status_code=400, detail={"error": "No key passed in."})
    ### 1. Create New copy that is duplicate of existing key
    ######################################################################
@@ -1617,6 +1709,27 @@ async def regenerate_key_fn(
            detail={"error": "DB not connected. prisma_client is None"},
        )

    _is_master_key_valid = _is_master_key(api_key=key, _master_key=master_key)

    if master_key is not None and data and _is_master_key_valid:
        if data.new_master_key is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"error": "New master key is required."},
            )
        await _rotate_master_key(
            prisma_client=prisma_client,
            user_api_key_dict=user_api_key_dict,
            current_master_key=master_key,
            new_master_key=data.new_master_key,
        )
        return GenerateKeyResponse(
            key=data.new_master_key,
            token=data.new_master_key,
            key_name=data.new_master_key,
            expires=None,
        )

    if "sk" not in key:
        hashed_api_key = key
    else:
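When the key being regenerated is the master key and new_master_key is supplied, the endpoint re-encrypts the DB and echoes the new master key back in the standard key response. A hedged sketch of what a caller might see, using the field names from the return statement above (the serialized GenerateKeyResponse may carry additional fields):

# Illustrative expected payload, assuming new_master_key="sk-new-master-key".
expected = {
    "key": "sk-new-master-key",
    "token": "sk-new-master-key",
    "key_name": "sk-new-master-key",
    "expires": None,
}

Because the proxy reads its master key from configuration, the operator presumably still has to point the running deployment at the new value (and at a new salt key, if the deployment relies on the salt key defaulting to the master key) before the rotated secrets can be decrypted again.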
@@ -1698,6 +1811,7 @@ async def regenerate_key_fn(

        return response
    except Exception as e:
        verbose_proxy_logger.exception("Error regenerating key: %s", e)
        raise handle_exception_on_proxy(e)
@@ -21,6 +21,7 @@ from litellm._logging import verbose_proxy_logger
from litellm.constants import LITELLM_PROXY_ADMIN_NAME
from litellm.proxy._types import (
    CommonProxyErrors,
    LiteLLM_ProxyModelTable,
    LitellmUserRoles,
    PrismaCompatibleUpdateDBModel,
    ProxyErrorTypes,
@@ -227,12 +228,16 @@ async def _add_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
):
    new_encryption_key: Optional[str] = None,
    should_create_model_in_db: bool = True,
) -> Optional[LiteLLM_ProxyModelTable]:
    # encrypt litellm params #
    _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
    _orignal_litellm_model_name = model_params.litellm_params.model
    for k, v in _litellm_params_dict.items():
        encrypted_value = encrypt_value_helper(value=v)
        encrypted_value = encrypt_value_helper(
            value=v, new_encryption_key=new_encryption_key
        )
        model_params.litellm_params[k] = encrypted_value
    _data: dict = {
        "model_id": model_params.model_info.id,
@@ -246,9 +251,12 @@ async def _add_model_to_db(
    }
    if model_params.model_info.id is not None:
        _data["model_id"] = model_params.model_info.id
    if should_create_model_in_db:
        model_response = await prisma_client.db.litellm_proxymodeltable.create(
            data=_data  # type: ignore
        )
    else:
        model_response = LiteLLM_ProxyModelTable(**_data)
    return model_response
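During rotation, _add_model_to_db is called with should_create_model_in_db=False so it only re-encrypts the params and returns an in-memory LiteLLM_ProxyModelTable row, which the caller bulk-inserts after wiping the table. A hedged usage sketch of that call site, with the surrounding names assumed to exist as in _rotate_master_key and the key value a placeholder:

# Hypothetical call site mirroring the new flags.
rekeyed_row = await _add_model_to_db(
    model_params=Deployment(**decrypted_model),   # decrypted_model from decrypt_model_list_from_db
    user_api_key_dict=user_api_key_dict,
    prisma_client=prisma_client,
    new_encryption_key="sk-new-master-key",       # placeholder new master key
    should_create_model_in_db=False,              # build the row, don't write it yet
)
if rekeyed_row is not None:
    new_models.append(jsonify_object(rekeyed_row.model_dump()))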
@@ -2419,17 +2419,7 @@ class ProxyConfig:
            added_models += 1
        return added_models

    async def _update_llm_router(
        self,
        new_models: list,
        proxy_logging_obj: ProxyLogging,
    ):
        global llm_router, llm_model_list, master_key, general_settings

        try:
            if llm_router is None and master_key is not None:
                verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")

    def decrypt_model_list_from_db(self, new_models: list) -> list:
        _model_list: list = []
        for m in new_models:
            _litellm_params = m.litellm_params
@@ -2453,6 +2443,23 @@ class ProxyConfig:
                    model_info=_model_info,
                ).to_json(exclude_none=True)
            )

        return _model_list

    async def _update_llm_router(
        self,
        new_models: list,
        proxy_logging_obj: ProxyLogging,
    ):
        global llm_router, llm_model_list, master_key, general_settings

        try:
            if llm_router is None and master_key is not None:
                verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")

                _model_list: list = self.decrypt_model_list_from_db(
                    new_models=new_models
                )
                if len(_model_list) > 0:
                    verbose_proxy_logger.debug(f"_model_list: {_model_list}")
                    llm_router = litellm.Router(
@@ -2541,7 +2548,21 @@ class ProxyConfig:
        environment_variables = config_data.get("environment_variables", {})
        self._decrypt_and_set_db_env_variables(environment_variables)

    def _decrypt_and_set_db_env_variables(self, environment_variables: dict) -> None:
    def _encrypt_env_variables(
        self, environment_variables: dict, new_encryption_key: Optional[str] = None
    ) -> dict:
        """
        Encrypts a dictionary of environment variables and returns them.
        """
        encrypted_env_vars = {}
        for k, v in environment_variables.items():
            encrypted_value = encrypt_value_helper(
                value=v, new_encryption_key=new_encryption_key
            )
            encrypted_env_vars[k] = encrypted_value
        return encrypted_env_vars

    def _decrypt_and_set_db_env_variables(self, environment_variables: dict) -> dict:
        """
        Decrypts a dictionary of environment variables and then sets them in the environment
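Together these two helpers give the env-var round trip used by _rotate_master_key: decrypt the stored blob under the current salt/master key, then re-encrypt it under the new one. A hedged sketch of that round trip; the stored value is taken from the docstring example below, everything else is a placeholder:

# Hypothetical round trip during rotation.
stored_env_vars = {"LANGFUSE_PUBLIC_KEY": "kFiKa1VZukMmD8RB6WXB9F......."}  # encrypted, from the config table

plaintext = proxy_config._decrypt_and_set_db_env_variables(
    environment_variables=stored_env_vars
)  # plaintext dict (also exported to os.environ)

re_encrypted = proxy_config._encrypt_env_variables(
    environment_variables=plaintext,
    new_encryption_key="sk-new-master-key",  # placeholder
)  # ready to be written back to the config table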
@@ -2549,15 +2570,18 @@ class ProxyConfig:
        environment_variables: dict - dictionary of environment variables to decrypt and set
        eg. `{"LANGFUSE_PUBLIC_KEY": "kFiKa1VZukMmD8RB6WXB9F......."}`
        """
        decrypted_env_vars = {}
        for k, v in environment_variables.items():
            try:
                decrypted_value = decrypt_value_helper(value=v)
                if decrypted_value is not None:
                    os.environ[k] = decrypted_value
                    decrypted_env_vars[k] = decrypted_value
            except Exception as e:
                verbose_proxy_logger.error(
                    "Error setting env variable: %s - %s", k, str(e)
                )
        return decrypted_env_vars

    async def _add_router_settings_from_db_config(
        self,
@@ -4462,6 +4462,7 @@ class Router:
        """
        # check if deployment already exists
        _deployment_model_id = deployment.model_info.id or ""

        _deployment_on_router: Optional[Deployment] = self.get_deployment(
            model_id=_deployment_model_id
        )
@@ -162,6 +162,7 @@ async def test_regenerate_api_key(prisma_client):
    print(result)

    # regenerate the key
    print("regenerating key: {}".format(generated_key))
    new_key = await regenerate_key_fn(
        key=generated_key,
        user_api_key_dict=UserAPIKeyAuth(
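For completeness, a hedged sketch of what a companion test for the rotation path could look like, modeled on test_regenerate_api_key above. The fixture names, the UserAPIKeyAuth arguments, and the assumption that "sk-1234" is the configured master key in the test environment are all illustrative, not part of this commit.

# Hypothetical test sketch; mirrors the regenerate-key test style used in this file.
import pytest

@pytest.mark.asyncio
async def test_regenerate_master_key_rotates(prisma_client):
    new_master = "sk-new-master-key"  # placeholder

    result = await regenerate_key_fn(
        key="sk-1234",  # assumed to be the proxy's configured master key here
        data=RegenerateKeyRequest(new_master_key=new_master),
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="1234",
        ),
    )

    # the endpoint echoes the new master key back on success
    assert result.key == new_master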