mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Support master key rotations (#9041)
* feat(key_management_endpoints.py): add support for rotating the master key
* feat(key_management_endpoints.py): support decrypting and re-encrypting models in the DB when the master key is rotated
* fix(user_api_key_auth.py): raise the "valid token is None" error earlier, so the API key hash appears in the error message for easier debugging
* feat(key_management_endpoints.py): rotate any env vars
* fix(key_management_endpoints.py): uncomment check
* fix: fix linting error
This commit is contained in:
parent fcc57318f8
commit da13ec2b64

8 changed files with 214 additions and 33 deletions
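Taken together, the changes below let a proxy admin rotate the master key through the existing key-regenerate endpoint: the current master key is passed as the key being "regenerated", and the request body carries its replacement. A rough usage sketch, with placeholder URL and key values, assuming the proxy accepts the usual `Authorization: Bearer` header (the new `/key/regenerate` route, with no path parameter, instead accepts the key in the request body):

import requests

PROXY_BASE_URL = "http://localhost:4000"   # placeholder
CURRENT_MASTER_KEY = "sk-1234"             # placeholder
NEW_MASTER_KEY = "sk-5678"                 # placeholder

resp = requests.post(
    f"{PROXY_BASE_URL}/key/{CURRENT_MASTER_KEY}/regenerate",
    headers={"Authorization": f"Bearer {CURRENT_MASTER_KEY}"},
    json={"new_master_key": NEW_MASTER_KEY},
)
resp.raise_for_status()
print(resp.json())  # the response echoes the new master key as key/token/key_name

After the call, models and DB-stored environment variables have been re-encrypted with the new key, so the proxy's configured master key presumably needs to be updated to the new value before stored secrets can be decrypted again.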
@@ -664,6 +664,7 @@ class RegenerateKeyRequest(GenerateKeyRequest):
     duration: Optional[str] = None
     spend: Optional[float] = None
     metadata: Optional[dict] = None
+    new_master_key: Optional[str] = None


 class KeyRequest(LiteLLMPydanticObjectBase):
@@ -688,6 +689,30 @@ class LiteLLM_ModelTable(LiteLLMPydanticObjectBase):
     model_config = ConfigDict(protected_namespaces=())


+class LiteLLM_ProxyModelTable(LiteLLMPydanticObjectBase):
+    model_id: str
+    model_name: str
+    litellm_params: dict
+    model_info: dict
+    created_by: str
+    updated_by: str
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_potential_json_str(cls, values):
+        if isinstance(values.get("litellm_params"), str):
+            try:
+                values["litellm_params"] = json.loads(values["litellm_params"])
+            except json.JSONDecodeError:
+                pass
+        if isinstance(values.get("model_info"), str):
+            try:
+                values["model_info"] = json.loads(values["model_info"])
+            except json.JSONDecodeError:
+                pass
+        return values
+
+
 class NewUserRequest(GenerateRequestBase):
     max_budget: Optional[float] = None
     user_email: Optional[str] = None
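The "before" validator above lets rows read back from the DB, where litellm_params and model_info may come out as JSON strings, be parsed transparently into dicts. A minimal usage sketch, assuming the class is exported from litellm.proxy._types (the import added in the model_management_endpoints.py hunk further down), with placeholder field values:

from litellm.proxy._types import LiteLLM_ProxyModelTable

row = LiteLLM_ProxyModelTable(
    model_id="123",
    model_name="gpt-4o",
    litellm_params='{"model": "openai/gpt-4o", "api_key": "encrypted..."}',  # JSON string from DB
    model_info='{"id": "123"}',
    created_by="admin",
    updated_by="admin",
)
# check_potential_json_str parsed the JSON strings before field validation ran
assert isinstance(row.litellm_params, dict)
assert isinstance(row.model_info, dict)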
@@ -786,6 +786,13 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
                 )
                 valid_token = None

+        if valid_token is None:
+            raise Exception(
+                "Invalid proxy server token passed. Received API Key (hashed)={}. Unable to find token in cache or `LiteLLM_VerificationTokenTable`".format(
+                    api_key
+                )
+            )
+
         user_obj: Optional[LiteLLM_UserTable] = None
         valid_token_dict: dict = {}
         if valid_token is not None:
@@ -1,5 +1,6 @@
 import base64
 import os
+from typing import Optional

 from litellm._logging import verbose_proxy_logger

@@ -19,9 +20,9 @@ def _get_salt_key():
     return salt_key


-def encrypt_value_helper(value: str):
+def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None):

-    signing_key = _get_salt_key()
+    signing_key = new_encryption_key or _get_salt_key()

     try:
         if isinstance(value, str):
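With the optional new_encryption_key, stored secrets can be decrypted under the current salt/master key and written back under a new one. A minimal sketch of that re-encryption step; the module path litellm.proxy.common_utils.encrypt_decrypt_utils and the presence of decrypt_value_helper in that same module are assumptions, while the call signatures mirror the diff:

# Hypothetical helper illustrating the rotation pattern enabled by new_encryption_key.
from litellm.proxy.common_utils.encrypt_decrypt_utils import (  # assumed module path
    decrypt_value_helper,
    encrypt_value_helper,
)


def reencrypt_stored_value(encrypted_value: str, new_master_key: str) -> str:
    # decrypt with the currently configured salt / master key ...
    plaintext = decrypt_value_helper(value=encrypted_value)
    # ... then encrypt again, passing the new master key explicitly
    return encrypt_value_helper(value=plaintext, new_encryption_key=new_master_key)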
@@ -35,18 +35,23 @@ from litellm.proxy.auth.auth_checks import (
 )
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
+from litellm.proxy.management_endpoints.model_management_endpoints import (
+    _add_model_to_db,
+)
 from litellm.proxy.management_endpoints.common_utils import (
     _is_user_team_admin,
     _set_object_metadata_field,
 )
 from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
+from litellm.proxy.spend_tracking.spend_tracking_utils import _is_master_key
 from litellm.proxy.utils import (
     PrismaClient,
     _hash_token_if_needed,
     handle_exception_on_proxy,
+    jsonify_object,
 )
 from litellm.router import Router
 from litellm.secret_managers.main import get_secret
+from litellm.types.router import Deployment
 from litellm.types.utils import (
     BudgetConfig,
     PersonalUIKeyGenerationConfig,
@@ -1525,14 +1530,98 @@ async def delete_key_aliases(
     )


+async def _rotate_master_key(
+    prisma_client: PrismaClient,
+    user_api_key_dict: UserAPIKeyAuth,
+    current_master_key: str,
+    new_master_key: str,
+) -> None:
+    """
+    Rotate the master key
+
+    1. Get the values from the DB
+        - Get models from DB
+        - Get config from DB
+    2. Decrypt the values
+        - ModelTable
+            - [{"model_name": "str", "litellm_params": {}}]
+        - ConfigTable
+    3. Encrypt the values with the new master key
+    4. Update the values in the DB
+    """
+    from litellm.proxy.proxy_server import proxy_config
+
+    try:
+        models: Optional[List] = (
+            await prisma_client.db.litellm_proxymodeltable.find_many()
+        )
+    except Exception:
+        models = None
+    # 2. process model table
+    if models:
+        decrypted_models = proxy_config.decrypt_model_list_from_db(new_models=models)
+        verbose_proxy_logger.info(
+            "ABLE TO DECRYPT MODELS - len(decrypted_models): %s", len(decrypted_models)
+        )
+        new_models = []
+        for model in decrypted_models:
+            new_model = await _add_model_to_db(
+                model_params=Deployment(**model),
+                user_api_key_dict=user_api_key_dict,
+                prisma_client=prisma_client,
+                new_encryption_key=new_master_key,
+                should_create_model_in_db=False,
+            )
+            if new_model:
+                new_models.append(jsonify_object(new_model.model_dump()))
+        verbose_proxy_logger.info("Resetting proxy model table")
+        await prisma_client.db.litellm_proxymodeltable.delete_many()
+        verbose_proxy_logger.info("Creating %s models", len(new_models))
+        await prisma_client.db.litellm_proxymodeltable.create_many(
+            data=new_models,
+        )
+    # 3. process config table
+    try:
+        config = await prisma_client.db.litellm_config.find_many()
+    except Exception:
+        config = None
+
+    if config:
+        """If environment_variables is found, decrypt it and encrypt it with the new master key"""
+        environment_variables_dict = {}
+        for c in config:
+            if c.param_name == "environment_variables":
+                environment_variables_dict = c.param_value
+
+        if environment_variables_dict:
+            decrypted_env_vars = proxy_config._decrypt_and_set_db_env_variables(
+                environment_variables=environment_variables_dict
+            )
+            encrypted_env_vars = proxy_config._encrypt_env_variables(
+                environment_variables=decrypted_env_vars,
+                new_encryption_key=new_master_key,
+            )
+
+            if encrypted_env_vars:
+                await prisma_client.db.litellm_config.update(
+                    where={"param_name": "environment_variables"},
+                    data={"param_value": jsonify_object(encrypted_env_vars)},
+                )
+
+
 @router.post(
     "/key/{key:path}/regenerate",
     tags=["key management"],
     dependencies=[Depends(user_api_key_auth)],
 )
+@router.post(
+    "/key/regenerate",
+    tags=["key management"],
+    dependencies=[Depends(user_api_key_auth)],
+)
 @management_endpoint_wrapper
 async def regenerate_key_fn(
-    key: str,
+    key: Optional[str] = None,
     data: Optional[RegenerateKeyRequest] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
     litellm_changed_by: Optional[str] = Header(
@@ -1590,6 +1679,7 @@ async def regenerate_key_fn(

     from litellm.proxy.proxy_server import (
         hash_token,
+        master_key,
         premium_user,
         prisma_client,
         proxy_logging_obj,
@@ -1602,7 +1692,9 @@ async def regenerate_key_fn(
         )

         # Check if key exists, raise exception if key is not in the DB
+        key = data.key if data and data.key else key
+        if not key:
+            raise HTTPException(status_code=400, detail={"error": "No key passed in."})
         ### 1. Create New copy that is duplicate of existing key
         ######################################################################

@@ -1617,6 +1709,27 @@ async def regenerate_key_fn(
                 detail={"error": "DB not connected. prisma_client is None"},
             )

+        _is_master_key_valid = _is_master_key(api_key=key, _master_key=master_key)
+
+        if master_key is not None and data and _is_master_key_valid:
+            if data.new_master_key is None:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail={"error": "New master key is required."},
+                )
+            await _rotate_master_key(
+                prisma_client=prisma_client,
+                user_api_key_dict=user_api_key_dict,
+                current_master_key=master_key,
+                new_master_key=data.new_master_key,
+            )
+            return GenerateKeyResponse(
+                key=data.new_master_key,
+                token=data.new_master_key,
+                key_name=data.new_master_key,
+                expires=None,
+            )
+
         if "sk" not in key:
             hashed_api_key = key
         else:
@@ -1698,6 +1811,7 @@ async def regenerate_key_fn(

         return response
     except Exception as e:
+        verbose_proxy_logger.exception("Error regenerating key: %s", e)
         raise handle_exception_on_proxy(e)


@@ -21,6 +21,7 @@ from litellm._logging import verbose_proxy_logger
 from litellm.constants import LITELLM_PROXY_ADMIN_NAME
 from litellm.proxy._types import (
     CommonProxyErrors,
+    LiteLLM_ProxyModelTable,
     LitellmUserRoles,
     PrismaCompatibleUpdateDBModel,
     ProxyErrorTypes,
@@ -227,12 +228,16 @@ async def _add_model_to_db(
     model_params: Deployment,
     user_api_key_dict: UserAPIKeyAuth,
     prisma_client: PrismaClient,
-):
+    new_encryption_key: Optional[str] = None,
+    should_create_model_in_db: bool = True,
+) -> Optional[LiteLLM_ProxyModelTable]:
     # encrypt litellm params #
     _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
     _orignal_litellm_model_name = model_params.litellm_params.model
     for k, v in _litellm_params_dict.items():
-        encrypted_value = encrypt_value_helper(value=v)
+        encrypted_value = encrypt_value_helper(
+            value=v, new_encryption_key=new_encryption_key
+        )
         model_params.litellm_params[k] = encrypted_value
     _data: dict = {
         "model_id": model_params.model_info.id,
@@ -246,9 +251,12 @@ async def _add_model_to_db(
     }
     if model_params.model_info.id is not None:
         _data["model_id"] = model_params.model_info.id
+    if should_create_model_in_db:
         model_response = await prisma_client.db.litellm_proxymodeltable.create(
             data=_data  # type: ignore
         )
+    else:
+        model_response = LiteLLM_ProxyModelTable(**_data)
     return model_response
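Because should_create_model_in_db=False returns the re-encrypted row instead of writing it, the rotation path above can wipe and bulk-recreate the model table in one pass rather than updating rows individually. A rough sketch of that in-memory path; decrypted_model, user_api_key_dict, prisma_client, and new_master_key are placeholders mirroring the names in _rotate_master_key:

# Hypothetical call: re-encrypt one decrypted model row without touching the DB yet.
row = await _add_model_to_db(
    model_params=Deployment(**decrypted_model),
    user_api_key_dict=user_api_key_dict,
    prisma_client=prisma_client,
    new_encryption_key=new_master_key,
    should_create_model_in_db=False,  # returns a LiteLLM_ProxyModelTable instead of creating it
)
db_ready_row = jsonify_object(row.model_dump()) if row else None  # suitable for create_many(data=[...])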
@@ -2419,17 +2419,7 @@ class ProxyConfig:
                 added_models += 1
         return added_models

-    async def _update_llm_router(
-        self,
-        new_models: list,
-        proxy_logging_obj: ProxyLogging,
-    ):
-        global llm_router, llm_model_list, master_key, general_settings
-
-        try:
-            if llm_router is None and master_key is not None:
-                verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
-
+    def decrypt_model_list_from_db(self, new_models: list) -> list:
         _model_list: list = []
         for m in new_models:
             _litellm_params = m.litellm_params
@@ -2453,6 +2443,23 @@ class ProxyConfig:
                     model_info=_model_info,
                 ).to_json(exclude_none=True)
             )
+
+        return _model_list
+
+    async def _update_llm_router(
+        self,
+        new_models: list,
+        proxy_logging_obj: ProxyLogging,
+    ):
+        global llm_router, llm_model_list, master_key, general_settings
+
+        try:
+            if llm_router is None and master_key is not None:
+                verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
+
+                _model_list: list = self.decrypt_model_list_from_db(
+                    new_models=new_models
+                )
                 if len(_model_list) > 0:
                     verbose_proxy_logger.debug(f"_model_list: {_model_list}")
                     llm_router = litellm.Router(
@@ -2541,7 +2548,21 @@ class ProxyConfig:
         environment_variables = config_data.get("environment_variables", {})
         self._decrypt_and_set_db_env_variables(environment_variables)

-    def _decrypt_and_set_db_env_variables(self, environment_variables: dict) -> None:
+    def _encrypt_env_variables(
+        self, environment_variables: dict, new_encryption_key: Optional[str] = None
+    ) -> dict:
+        """
+        Encrypts a dictionary of environment variables and returns them.
+        """
+        encrypted_env_vars = {}
+        for k, v in environment_variables.items():
+            encrypted_value = encrypt_value_helper(
+                value=v, new_encryption_key=new_encryption_key
+            )
+            encrypted_env_vars[k] = encrypted_value
+        return encrypted_env_vars
+
+    def _decrypt_and_set_db_env_variables(self, environment_variables: dict) -> dict:
         """
         Decrypts a dictionary of environment variables and then sets them in the environment

@@ -2549,15 +2570,18 @@ class ProxyConfig:
             environment_variables: dict - dictionary of environment variables to decrypt and set
             eg. `{"LANGFUSE_PUBLIC_KEY": "kFiKa1VZukMmD8RB6WXB9F......."}`
         """
+        decrypted_env_vars = {}
         for k, v in environment_variables.items():
             try:
                 decrypted_value = decrypt_value_helper(value=v)
                 if decrypted_value is not None:
                     os.environ[k] = decrypted_value
+                    decrypted_env_vars[k] = decrypted_value
             except Exception as e:
                 verbose_proxy_logger.error(
                     "Error setting env variable: %s - %s", k, str(e)
                 )
+        return decrypted_env_vars

     async def _add_router_settings_from_db_config(
         self,
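_decrypt_and_set_db_env_variables now returns the decrypted map (in addition to setting os.environ), which is what lets the rotation path feed its output straight into the new _encrypt_env_variables. A minimal round-trip sketch; proxy_config is assumed to be an initialized ProxyConfig, and stored_env_vars_from_db / new_master_key are placeholders:

# Hypothetical round trip used during master key rotation.
decrypted = proxy_config._decrypt_and_set_db_env_variables(
    environment_variables=stored_env_vars_from_db  # encrypted values read from the config table
)
reencrypted = proxy_config._encrypt_env_variables(
    environment_variables=decrypted,
    new_encryption_key=new_master_key,  # placeholder, e.g. "sk-new-master-key"
)
# `reencrypted` can then be written back to the environment_variables config row.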
@@ -4462,6 +4462,7 @@ class Router:
         """
         # check if deployment already exists
         _deployment_model_id = deployment.model_info.id or ""
+
         _deployment_on_router: Optional[Deployment] = self.get_deployment(
             model_id=_deployment_model_id
         )
@@ -162,6 +162,7 @@ async def test_regenerate_api_key(prisma_client):
     print(result)

     # regenerate the key
+    print("regenerating key: {}".format(generated_key))
     new_key = await regenerate_key_fn(
         key=generated_key,
         user_api_key_dict=UserAPIKeyAuth(