"""
|
|
Allow proxy admin to add/update/delete models in the db
|
|
|
|
Currently most endpoints are in `proxy_server.py`, but those should be moved here over time.
|
|
|
|
Endpoints here:
|
|
|
|
model/{model_id}/update - PATCH endpoint for model update.
|
|
"""
|
|
|
|
#### MODEL MANAGEMENT ####

import json
import uuid
from typing import Optional, cast

from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel

from litellm._logging import verbose_proxy_logger
from litellm.constants import LITELLM_PROXY_ADMIN_NAME
from litellm.proxy._types import (
    CommonProxyErrors,
    LiteLLM_ProxyModelTable,
    LitellmUserRoles,
    PrismaCompatibleUpdateDBModel,
    ProxyErrorTypes,
    ProxyException,
    TeamModelAddRequest,
    UpdateTeamRequest,
    UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
from litellm.proxy.management_endpoints.team_endpoints import (
    team_model_add,
    update_team,
)
from litellm.proxy.utils import PrismaClient
from litellm.types.router import (
    Deployment,
    DeploymentTypedDict,
    LiteLLMParamsTypedDict,
    updateDeployment,
)
from litellm.utils import get_utc_datetime

router = APIRouter()


async def get_db_model(
    model_id: str, prisma_client: PrismaClient
) -> Optional[Deployment]:
    db_model = cast(
        Optional[BaseModel],
        await prisma_client.db.litellm_proxymodeltable.find_unique(
            where={"model_id": model_id}
        ),
    )

    if not db_model:
        return None

    deployment_pydantic_obj = Deployment(**db_model.model_dump(exclude_none=True))
    return deployment_pydantic_obj


def update_db_model(
    db_model: Deployment, updated_patch: updateDeployment
) -> PrismaCompatibleUpdateDBModel:
    merged_deployment_dict = DeploymentTypedDict(
        model_name=db_model.model_name,
        litellm_params=LiteLLMParamsTypedDict(
            **db_model.litellm_params.model_dump(exclude_none=True)  # type: ignore
        ),
    )
    # update model name
    if updated_patch.model_name:
        merged_deployment_dict["model_name"] = updated_patch.model_name

    # update litellm params
    if updated_patch.litellm_params:
        # Encrypt any sensitive values
        encrypted_params = {
            k: encrypt_value_helper(v)
            for k, v in updated_patch.litellm_params.model_dump(
                exclude_none=True
            ).items()
        }

        merged_deployment_dict["litellm_params"].update(encrypted_params)  # type: ignore

    # update model info
    if updated_patch.model_info:
        if "model_info" not in merged_deployment_dict:
            merged_deployment_dict["model_info"] = {}
        merged_deployment_dict["model_info"].update(
            updated_patch.model_info.model_dump(exclude_none=True)
        )

    # convert to prisma compatible format
    prisma_compatible_model_dict = PrismaCompatibleUpdateDBModel()
    if "model_name" in merged_deployment_dict:
        prisma_compatible_model_dict["model_name"] = merged_deployment_dict[
            "model_name"
        ]

    if "litellm_params" in merged_deployment_dict:
        prisma_compatible_model_dict["litellm_params"] = json.dumps(
            merged_deployment_dict["litellm_params"]
        )

    if "model_info" in merged_deployment_dict:
        prisma_compatible_model_dict["model_info"] = json.dumps(
            merged_deployment_dict["model_info"]
        )
    return prisma_compatible_model_dict
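

# Illustrative sketch (not part of the original module): given a Deployment already loaded
# via get_db_model, a patch that only sets model_name leaves the stored litellm_params
# untouched and yields a prisma-ready dict with JSON-encoded fields. The new model name
# below is a placeholder value.
def _example_merge_model_name_patch(
    db_model: Deployment,
) -> PrismaCompatibleUpdateDBModel:
    patch = updateDeployment(model_name="renamed-model")  # hypothetical new public name
    return update_db_model(db_model=db_model, updated_patch=patch)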


@router.patch(
    "/model/{model_id}/update",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def patch_model(
    model_id: str,  # Get model_id from path parameter
    patch_data: updateDeployment,  # Create a specific schema for PATCH operations
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    PATCH Endpoint for partial model updates.

    Only updates the fields specified in the request while preserving other existing values.
    Follows proper PATCH semantics by only modifying provided fields.

    Args:
        model_id: The ID of the model to update
        patch_data: The fields to update and their new values
        user_api_key_dict: User authentication information

    Returns:
        Updated model information

    Raises:
        ProxyException: For various error conditions including authentication and database errors
    """
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        llm_router,
        prisma_client,
        store_model_in_db,
    )

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        # Verify model exists and is stored in DB
        if not store_model_in_db:
            raise ProxyException(
                message="Model updates only supported for DB-stored models",
                type=ProxyErrorTypes.validation_error.value,
                code=status.HTTP_400_BAD_REQUEST,
                param=None,
            )

        # Fetch existing model
        db_model = await get_db_model(model_id=model_id, prisma_client=prisma_client)

        if db_model is None:
            # Check if model exists in config but not DB
            if llm_router and llm_router.get_deployment(model_id=model_id) is not None:
                raise ProxyException(
                    message="Cannot edit config-based model. Store model in DB via /model/new first.",
                    type=ProxyErrorTypes.validation_error.value,
                    code=status.HTTP_400_BAD_REQUEST,
                    param=None,
                )
            raise ProxyException(
                message=f"Model {model_id} not found on proxy.",
                type=ProxyErrorTypes.not_found_error,
                code=status.HTTP_404_NOT_FOUND,
                param=None,
            )

        # Create update dictionary only for provided fields
        update_data = update_db_model(db_model=db_model, updated_patch=patch_data)

        # Add metadata about update
        update_data["updated_by"] = (
            user_api_key_dict.user_id or litellm_proxy_admin_name
        )
        update_data["updated_at"] = cast(str, get_utc_datetime())

        # Perform partial update
        updated_model = await prisma_client.db.litellm_proxymodeltable.update(
            where={"model_id": model_id},
            data=update_data,
        )

        return updated_model

    except Exception as e:
        verbose_proxy_logger.exception(f"Error in patch_model: {str(e)}")

        if isinstance(e, (HTTPException, ProxyException)):
            raise e

        raise ProxyException(
            message=f"Error updating model: {str(e)}",
            type=ProxyErrorTypes.internal_server_error,
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            param=None,
        )
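

# Illustrative sketch (not part of the original module): what a client-side call to the
# PATCH endpoint above might look like. The base URL, admin key, model id, and new
# api_key value are all placeholder values; httpx is assumed to be available on the client.
async def _example_patch_model_call() -> None:
    import httpx  # assumed client-side dependency

    async with httpx.AsyncClient(base_url="http://localhost:4000") as client:
        response = await client.patch(
            "/model/ab12cd34-0000-0000-0000-000000000000/update",  # hypothetical model id
            headers={"Authorization": "Bearer sk-1234"},  # hypothetical admin key
            json={"litellm_params": {"api_key": "my-rotated-api-key"}},  # only this field is patched
        )
        response.raise_for_status()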


################################# Helper Functions #################################
####################################################################################
####################################################################################
####################################################################################


async def _add_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
    new_encryption_key: Optional[str] = None,
    should_create_model_in_db: bool = True,
) -> Optional[LiteLLM_ProxyModelTable]:
    # encrypt litellm params #
    _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
    _original_litellm_model_name = model_params.litellm_params.model
    for k, v in _litellm_params_dict.items():
        encrypted_value = encrypt_value_helper(
            value=v, new_encryption_key=new_encryption_key
        )
        model_params.litellm_params[k] = encrypted_value
    _data: dict = {
        "model_id": model_params.model_info.id,
        "model_name": model_params.model_name,
        "litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True),  # type: ignore
        "model_info": model_params.model_info.model_dump_json(  # type: ignore
            exclude_none=True
        ),
        "created_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
        "updated_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
    }
    if model_params.model_info.id is not None:
        _data["model_id"] = model_params.model_info.id
    if should_create_model_in_db:
        model_response = await prisma_client.db.litellm_proxymodeltable.create(
            data=_data  # type: ignore
        )
    else:
        model_response = LiteLLM_ProxyModelTable(**_data)
    return model_response
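

# Illustrative sketch (not part of the original module): each value in litellm_params is
# encrypted individually before storage, and the optional new_encryption_key parameter above
# supports re-encrypting stored params when the master key is rotated. The params below are
# placeholder values.
def _example_encrypt_litellm_params() -> dict:
    raw_params = {"model": "openai/gpt-4o", "api_key": "sk-placeholder"}  # hypothetical values
    return {k: encrypt_value_helper(value=v) for k, v in raw_params.items()}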


async def _add_team_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
):
    """
    If 'team_id' is provided,

    - generate a unique 'model_name' for the model (e.g. 'model_name_{team_id}_{uuid}')
    - store the model in the db with the unique 'model_name'
    - store a team model alias mapping {"model_name": "model_name_{team_id}_{uuid}"}
    """
    _team_id = model_params.model_info.team_id
    if _team_id is None:
        return None
    original_model_name = model_params.model_name
    if original_model_name:
        model_params.model_info.team_public_model_name = original_model_name

    unique_model_name = f"model_name_{_team_id}_{uuid.uuid4()}"

    model_params.model_name = unique_model_name

    ## CREATE MODEL IN DB ##
    model_response = await _add_model_to_db(
        model_params=model_params,
        user_api_key_dict=user_api_key_dict,
        prisma_client=prisma_client,
    )

    ## CREATE MODEL ALIAS IN DB ##
    await update_team(
        data=UpdateTeamRequest(
            team_id=_team_id,
            model_aliases={original_model_name: unique_model_name},
        ),
        user_api_key_dict=user_api_key_dict,
        http_request=Request(scope={"type": "http"}),
    )

    # add model to team object
    await team_model_add(
        data=TeamModelAddRequest(
            team_id=_team_id,
            models=[original_model_name],
        ),
        http_request=Request(scope={"type": "http"}),
        user_api_key_dict=user_api_key_dict,
    )

    return model_response
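

# Illustrative sketch (not part of the original module): the per-team model name generated
# above follows the pattern "model_name_{team_id}_{uuid}", and the team keeps an alias from
# its public model name to that unique db name. The team id and model name are placeholders.
def _example_team_model_alias() -> dict:
    team_id = "team-abc123"  # hypothetical team id
    team_public_model_name = "gpt-4o"  # name the team's users call
    unique_model_name = f"model_name_{team_id}_{uuid.uuid4()}"
    return {team_public_model_name: unique_model_name}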


def check_if_team_id_matches_key(
    team_id: Optional[str], user_api_key_dict: UserAPIKeyAuth
) -> bool:
    can_make_call = True
    if (
        user_api_key_dict.user_role
        and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
    ):
        return True
    if team_id is None:
        if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
            can_make_call = False
    else:
        if user_api_key_dict.team_id != team_id:
            can_make_call = False
    return can_make_call
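

# Illustrative sketch (not part of the original module): expected outcomes of the team check
# above. The keys and team ids below are hypothetical, and UserAPIKeyAuth fields not set here
# are assumed to default to None.
def _example_team_id_check() -> None:
    admin_key = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN)
    team_key = UserAPIKeyAuth(team_id="team-abc123")
    # proxy admins can act on any team
    assert check_if_team_id_matches_key(team_id="team-xyz", user_api_key_dict=admin_key) is True
    # non-admin keys can only act on their own team
    assert check_if_team_id_matches_key(team_id="team-abc123", user_api_key_dict=team_key) is True
    assert check_if_team_id_matches_key(team_id="team-xyz", user_api_key_dict=team_key) is False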