mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
refactor location of proxy
This commit is contained in:
parent
baa5564f95
commit
ce58c53ff1
413 changed files with 2087 additions and 2088 deletions
|
@ -1,869 +0,0 @@
|
|||
"""
|
||||
Allow proxy admin to add/update/delete models in the db
|
||||
|
||||
Currently most endpoints are in `proxy_server.py`, but those should be moved here over time.
|
||||
|
||||
Endpoints here:
|
||||
|
||||
model/{model_id}/update - PATCH endpoint for model update.
|
||||
"""
|
||||
|
||||
#### MODEL MANAGEMENT ####
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Dict, List, Literal, Optional, Union, cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import LITELLM_PROXY_ADMIN_NAME
|
||||
from litellm.proxy._types import (
|
||||
CommonProxyErrors,
|
||||
LiteLLM_ProxyModelTable,
|
||||
LiteLLM_TeamTable,
|
||||
LitellmTableNames,
|
||||
LitellmUserRoles,
|
||||
ModelInfoDelete,
|
||||
PrismaCompatibleUpdateDBModel,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
TeamModelAddRequest,
|
||||
UpdateTeamRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
|
||||
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
|
||||
from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
team_model_add,
|
||||
update_team,
|
||||
)
|
||||
from litellm.proxy.management_helpers.audit_logs import create_object_audit_log
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.router import (
|
||||
Deployment,
|
||||
DeploymentTypedDict,
|
||||
LiteLLMParamsTypedDict,
|
||||
updateDeployment,
|
||||
)
|
||||
from litellm.utils import get_utc_datetime
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def get_db_model(
    model_id: str, prisma_client: PrismaClient
) -> Optional[Deployment]:
    """Fetch a model row from the DB by ``model_id``.

    Returns the row converted into a `Deployment` object, or ``None`` when
    no row with that id exists.
    """
    row = await prisma_client.db.litellm_proxymodeltable.find_unique(
        where={"model_id": model_id}
    )
    record = cast(Optional[BaseModel], row)
    if record is None:
        return None
    # Drop None fields so Deployment's own defaults/validators apply.
    return Deployment(**record.model_dump(exclude_none=True))
|
||||
|
||||
|
||||
def update_db_model(
    db_model: Deployment, updated_patch: updateDeployment
) -> PrismaCompatibleUpdateDBModel:
    """Merge a PATCH payload into an existing deployment and serialize it
    for a prisma ``update`` call.

    Only fields present (non-None) on ``updated_patch`` overwrite the stored
    values. Patched ``litellm_params`` values are encrypted before merging.
    ``litellm_params`` / ``model_info`` come back as JSON strings, matching
    the DB column types.
    """
    merged = DeploymentTypedDict(
        model_name=db_model.model_name,
        litellm_params=LiteLLMParamsTypedDict(
            **db_model.litellm_params.model_dump(exclude_none=True)  # type: ignore
        ),
    )

    # Overwrite the model name only when the patch provides one.
    if updated_patch.model_name:
        merged["model_name"] = updated_patch.model_name

    if updated_patch.litellm_params:
        # Sensitive values (api keys etc.) are stored encrypted.
        patch_params = updated_patch.litellm_params.model_dump(exclude_none=True)
        merged["litellm_params"].update(  # type: ignore
            {key: encrypt_value_helper(value) for key, value in patch_params.items()}
        )

    if updated_patch.model_info:
        if "model_info" not in merged:
            merged["model_info"] = {}
        merged["model_info"].update(
            updated_patch.model_info.model_dump(exclude_none=True)
        )

    # Serialize nested dicts into JSON strings for the prisma update payload.
    prisma_update = PrismaCompatibleUpdateDBModel()
    if "model_name" in merged:
        prisma_update["model_name"] = merged["model_name"]
    if "litellm_params" in merged:
        prisma_update["litellm_params"] = json.dumps(merged["litellm_params"])
    if "model_info" in merged:
        prisma_update["model_info"] = json.dumps(merged["model_info"])
    return prisma_update
|
||||
|
||||
|
||||
@router.patch(
    "/model/{model_id}/update",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def patch_model(
    model_id: str,  # Get model_id from path parameter
    patch_data: updateDeployment,  # Create a specific schema for PATCH operations
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    PATCH Endpoint for partial model updates.

    Only updates the fields specified in the request while preserving other existing values.
    Follows proper PATCH semantics by only modifying provided fields.

    Args:
        model_id: The ID of the model to update
        patch_data: The fields to update and their new values
        user_api_key_dict: User authentication information

    Returns:
        Updated model information

    Raises:
        ProxyException: For various error conditions including authentication and database errors
    """
    # Lazy import so we read the live (runtime-initialized) proxy_server
    # globals rather than stale import-time values.
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        llm_router,
        prisma_client,
        store_model_in_db,
    )

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        # Verify model exists and is stored in DB
        if not store_model_in_db:
            raise ProxyException(
                message="Model updates only supported for DB-stored models",
                type=ProxyErrorTypes.validation_error.value,
                code=status.HTTP_400_BAD_REQUEST,
                param=None,
            )

        # Fetch existing model
        db_model = await get_db_model(model_id=model_id, prisma_client=prisma_client)

        if db_model is None:
            # Check if model exists in config but not DB
            if llm_router and llm_router.get_deployment(model_id=model_id) is not None:
                raise ProxyException(
                    message="Cannot edit config-based model. Store model in DB via /model/new first.",
                    type=ProxyErrorTypes.validation_error.value,
                    code=status.HTTP_400_BAD_REQUEST,
                    param=None,
                )
            raise ProxyException(
                message=f"Model {model_id} not found on proxy.",
                type=ProxyErrorTypes.not_found_error,
                code=status.HTTP_404_NOT_FOUND,
                param=None,
            )

        # Create update dictionary only for provided fields
        update_data = update_db_model(db_model=db_model, updated_patch=patch_data)

        # Add metadata about update
        update_data["updated_by"] = (
            user_api_key_dict.user_id or litellm_proxy_admin_name
        )
        update_data["updated_at"] = cast(str, get_utc_datetime())

        # Perform partial update
        updated_model = await prisma_client.db.litellm_proxymodeltable.update(
            where={"model_id": model_id},
            data=update_data,
        )

        return updated_model

    except Exception as e:
        verbose_proxy_logger.exception(f"Error in patch_model: {str(e)}")

        # Preserve the deliberate HTTP/Proxy errors raised above; wrap
        # anything unexpected as a 500 ProxyException.
        if isinstance(e, (HTTPException, ProxyException)):
            raise e

        raise ProxyException(
            message=f"Error updating model: {str(e)}",
            type=ProxyErrorTypes.internal_server_error,
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            param=None,
        )
|
||||
|
||||
|
||||
################################# Helper Functions #################################
|
||||
####################################################################################
|
||||
####################################################################################
|
||||
####################################################################################
|
||||
|
||||
|
||||
async def _add_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
    new_encryption_key: Optional[str] = None,
    should_create_model_in_db: bool = True,
) -> Optional[LiteLLM_ProxyModelTable]:
    """Persist a deployment to LiteLLM_ProxyModelTable.

    Encrypts every litellm_params value in-place on ``model_params`` before
    serializing it for storage.

    Args:
        model_params: the deployment to store.
        user_api_key_dict: caller identity, used for created_by/updated_by.
        prisma_client: DB client.
        new_encryption_key: optional key passed through to
            `encrypt_value_helper`; when None the configured key is used.
        should_create_model_in_db: when False, skip the DB write and return
            the row object that would have been created.

    Returns:
        The created (or constructed) model table row.
    """
    # encrypt litellm params #
    _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
    for k, v in _litellm_params_dict.items():
        encrypted_value = encrypt_value_helper(
            value=v, new_encryption_key=new_encryption_key
        )
        model_params.litellm_params[k] = encrypted_value
    # model_id is taken directly from model_info.id (may be None, in which
    # case the DB assigns its default). The original code re-assigned it a
    # second time behind an `is not None` guard — a no-op, removed. An
    # unused, misspelled local (`_orignal_litellm_model_name`) was also removed.
    _data: dict = {
        "model_id": model_params.model_info.id,
        "model_name": model_params.model_name,
        "litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True),  # type: ignore
        "model_info": model_params.model_info.model_dump_json(  # type: ignore
            exclude_none=True
        ),
        "created_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
        "updated_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
    }
    if should_create_model_in_db:
        model_response = await prisma_client.db.litellm_proxymodeltable.create(
            data=_data  # type: ignore
        )
    else:
        model_response = LiteLLM_ProxyModelTable(**_data)
    return model_response
|
||||
|
||||
|
||||
async def _add_team_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
) -> Optional[LiteLLM_ProxyModelTable]:
    """
    If 'team_id' is provided,

    - generate a unique 'model_name' for the model (e.g. 'model_name_{team_id}_{uuid})
    - store the model in the db with the unique 'model_name'
    - store a team model alias mapping {"model_name": "model_name_{team_id}_{uuid}"}

    Returns None when model_params carries no team_id (nothing to do).
    """
    _team_id = model_params.model_info.team_id
    if _team_id is None:
        return None
    original_model_name = model_params.model_name
    if original_model_name:
        # Keep the user-facing name so the UI can still display it.
        model_params.model_info.team_public_model_name = original_model_name

    # Unique internal name prevents collisions across teams using the same
    # public model name.
    unique_model_name = f"model_name_{_team_id}_{uuid.uuid4()}"

    model_params.model_name = unique_model_name

    ## CREATE MODEL IN DB ##
    model_response = await _add_model_to_db(
        model_params=model_params,
        user_api_key_dict=user_api_key_dict,
        prisma_client=prisma_client,
    )

    ## CREATE MODEL ALIAS IN DB ##
    # Maps the team's public name -> unique internal name.
    await update_team(
        data=UpdateTeamRequest(
            team_id=_team_id,
            model_aliases={original_model_name: unique_model_name},
        ),
        user_api_key_dict=user_api_key_dict,
        http_request=Request(scope={"type": "http"}),
    )

    # add model to team object
    await team_model_add(
        data=TeamModelAddRequest(
            team_id=_team_id,
            models=[original_model_name],
        ),
        http_request=Request(scope={"type": "http"}),
        user_api_key_dict=user_api_key_dict,
    )

    return model_response
|
||||
|
||||
|
||||
class ModelManagementAuthChecks:
    """
    Common auth checks for model management endpoints
    """

    @staticmethod
    def can_user_make_team_model_call(
        team_id: str,
        user_api_key_dict: UserAPIKeyAuth,
        team_obj: Optional[LiteLLM_TeamTable] = None,
        premium_user: bool = False,
    ) -> Literal[True]:
        """Allow team-model operations for proxy admins and team admins.

        Raises:
            HTTPException(403): caller is not a premium user, or is neither
                a proxy admin nor an admin of ``team_obj``.
        """
        # Team-scoped model management is a premium-only feature.
        if premium_user is False:
            raise HTTPException(
                status_code=403,
                detail={"error": CommonProxyErrors.not_premium_user.value},
            )
        if (
            user_api_key_dict.user_role
            and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
        ):
            return True
        elif team_obj is None or not _is_user_team_admin(
            user_api_key_dict=user_api_key_dict, team_obj=team_obj
        ):
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Team ID={} does not match the API key's team ID={}, OR you are not the admin for this team. Check `/user/info` to verify your team admin status.".format(
                        team_id, user_api_key_dict.team_id
                    )
                },
            )
        # Caller is a team admin of team_obj.
        return True

    @staticmethod
    async def allow_team_model_action(
        model_params: Union[Deployment, updateDeployment],
        user_api_key_dict: UserAPIKeyAuth,
        prisma_client: PrismaClient,
        premium_user: bool,
    ) -> Literal[True]:
        """Validate a team-scoped model action end-to-end.

        No-op (returns True) when the model is not team-scoped. Otherwise
        checks premium access, that the team exists in the DB, and that the
        caller may manage models for that team.

        Raises:
            HTTPException(403): non-premium user or insufficient role.
            HTTPException(400): team_id does not exist in the DB.
        """
        if model_params.model_info is None or model_params.model_info.team_id is None:
            return True
        if model_params.model_info.team_id is not None and premium_user is not True:
            raise HTTPException(
                status_code=403,
                detail={"error": CommonProxyErrors.not_premium_user.value},
            )

        _existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": model_params.model_info.team_id}
        )

        if _existing_team_row is None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Team id={} does not exist in db".format(
                        model_params.model_info.team_id
                    )
                },
            )
        existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())

        # Raises on failure; returning here means the caller is authorized.
        ModelManagementAuthChecks.can_user_make_team_model_call(
            team_id=model_params.model_info.team_id,
            user_api_key_dict=user_api_key_dict,
            team_obj=existing_team_row,
            premium_user=premium_user,
        )
        return True

    @staticmethod
    async def can_user_make_model_call(
        model_params: Deployment,
        user_api_key_dict: UserAPIKeyAuth,
        prisma_client: PrismaClient,
        premium_user: bool,
    ) -> Literal[True]:
        """Authorize a model create/update/delete call.

        Team-scoped models (model_info.team_id set) defer to
        `can_user_make_team_model_call`; non-team models require the
        PROXY_ADMIN role.

        Raises:
            HTTPException(400): team_id does not exist in the DB.
            HTTPException(403): caller lacks the required role.
        """
        ## Check team model auth
        if (
            model_params.model_info is not None
            and model_params.model_info.team_id is not None
        ):
            team_obj_row = await prisma_client.db.litellm_teamtable.find_unique(
                where={"team_id": model_params.model_info.team_id}
            )
            if team_obj_row is None:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Team id={} does not exist in db".format(
                            model_params.model_info.team_id
                        )
                    },
                )
            team_obj = LiteLLM_TeamTable(**team_obj_row.model_dump())

            return ModelManagementAuthChecks.can_user_make_team_model_call(
                team_id=model_params.model_info.team_id,
                user_api_key_dict=user_api_key_dict,
                team_obj=team_obj,
                premium_user=premium_user,
            )
        ## Check non-team model auth
        elif user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "User does not have permission to make this model call. Your role={}. You can only make model calls if you are a PROXY_ADMIN or if you are a team admin, by specifying a team_id in the model_info.".format(
                        user_api_key_dict.user_role
                    )
                },
            )
        else:
            return True
        # NOTE: the original trailing `return True` was unreachable — every
        # branch above returns or raises — and has been removed.
|
||||
|
||||
|
||||
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
@router.post(
    "/model/delete",
    description="Allows deleting models in the model list in the config.yaml",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_model(
    model_info: ModelInfoDelete,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """Delete a DB-stored model by id.

    Auth-checks the caller, deletes the DB row, removes the deployment from
    the in-memory router, and records an audit log entry.
    """
    from litellm.proxy.proxy_server import llm_router

    try:
        """
        [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964

        - Check if id in db
        - Delete
        """

        # Lazy import so we read the live runtime globals, not import-time
        # values (re-imports llm_router for the same reason).
        from litellm.proxy.proxy_server import (
            llm_router,
            premium_user,
            prisma_client,
            store_model_in_db,
        )

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )

        model_in_db = await prisma_client.db.litellm_proxymodeltable.find_unique(
            where={"model_id": model_info.id}
        )
        if model_in_db is None:
            raise HTTPException(
                status_code=400,
                detail={"error": f"Model with id={model_info.id} not found in db"},
            )

        # Reconstruct the deployment so the auth check can inspect team_id.
        model_params = Deployment(**model_in_db.model_dump())
        await ModelManagementAuthChecks.can_user_make_model_call(
            model_params=model_params,
            user_api_key_dict=user_api_key_dict,
            prisma_client=prisma_client,
            premium_user=premium_user,
        )

        # update DB
        if store_model_in_db is True:
            """
            - store model_list in db
            - store keys separately
            """
            # encrypt litellm params #
            result = await prisma_client.db.litellm_proxymodeltable.delete(
                where={"model_id": model_info.id}
            )

            if result is None:
                raise HTTPException(
                    status_code=400,
                    detail={"error": f"Model with id={model_info.id} not found in db"},
                )

            ## DELETE FROM ROUTER ##
            if llm_router is not None:
                llm_router.delete_deployment(id=model_info.id)

            ## CREATE AUDIT LOG ##
            # Fire-and-forget: audit logging must not block or fail the response.
            asyncio.create_task(
                create_object_audit_log(
                    object_id=model_info.id,
                    action="deleted",
                    user_api_key_dict=user_api_key_dict,
                    table_name=LitellmTableNames.PROXY_MODEL_TABLE_NAME,
                    before_value=result.model_dump_json(exclude_none=True),
                    after_value=None,
                    litellm_changed_by=user_api_key_dict.user_id,
                    litellm_proxy_admin_name=LITELLM_PROXY_ADMIN_NAME,
                )
            )
            return {"message": f"Model: {result.model_id} deleted successfully"}
        else:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
                },
            )

    except Exception as e:
        verbose_proxy_logger.exception(
            f"Failed to delete model. Due to error - {str(e)}"
        )
        # Re-wrap HTTP errors as ProxyException so the response shape is
        # consistent across model-management endpoints.
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
||||
|
||||
|
||||
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
@router.post(
    "/model/new",
    description="Allows adding new models to the model list in the config.yaml",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def add_new_model(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """Add a new model/deployment to the DB and in-memory router.

    Team-scoped models (model_info.team_id set) are stored under a unique
    internal name plus a team alias; others are stored directly. Emits an
    audit log entry and, when configured, a slack alert.
    """
    # Lazy import so we read the live runtime globals, not import-time values.
    from litellm.proxy.proxy_server import (
        general_settings,
        premium_user,
        prisma_client,
        proxy_config,
        proxy_logging_obj,
        store_model_in_db,
    )

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )

        ## Auth check
        await ModelManagementAuthChecks.can_user_make_model_call(
            model_params=model_params,
            user_api_key_dict=user_api_key_dict,
            prisma_client=prisma_client,
            premium_user=premium_user,
        )

        model_response: Optional[LiteLLM_ProxyModelTable] = None
        # update DB
        if store_model_in_db is True:
            """
            - store model_list in db
            - store keys separately
            """

            try:
                # Capture the name before _add_team_model_to_db may rewrite
                # model_params.model_name to the unique team-internal name.
                _original_litellm_model_name = model_params.model_name
                if model_params.model_info.team_id is None:
                    model_response = await _add_model_to_db(
                        model_params=model_params,
                        user_api_key_dict=user_api_key_dict,
                        prisma_client=prisma_client,
                    )
                else:
                    model_response = await _add_team_model_to_db(
                        model_params=model_params,
                        user_api_key_dict=user_api_key_dict,
                        prisma_client=prisma_client,
                    )
                # Reload deployments so the router picks up the new model.
                await proxy_config.add_deployment(
                    prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
                )
                # don't let failed slack alert block the /model/new response
                _alerting = general_settings.get("alerting", []) or []
                if "slack" in _alerting:
                    # send notification - new model added
                    await proxy_logging_obj.slack_alerting_instance.model_added_alert(
                        model_name=model_params.model_name,
                        litellm_model_name=_original_litellm_model_name,
                        passed_model_info=model_params.model_info,
                    )
            except Exception as e:
                # Deliberately best-effort: failures after the DB write
                # (router reload, alerting) are logged but not surfaced;
                # model_response stays whatever was created before the error.
                verbose_proxy_logger.exception(f"Exception in add_new_model: {e}")

        else:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
                },
            )

        if model_response is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Failed to add model to db. Check your server logs for more details."
                },
            )

        ## CREATE AUDIT LOG ##
        # Fire-and-forget: audit logging must not block the response.
        asyncio.create_task(
            create_object_audit_log(
                object_id=model_response.model_id,
                action="created",
                user_api_key_dict=user_api_key_dict,
                table_name=LitellmTableNames.PROXY_MODEL_TABLE_NAME,
                before_value=None,
                after_value=(
                    model_response.model_dump_json(exclude_none=True)
                    if isinstance(model_response, BaseModel)
                    else None
                ),
                litellm_changed_by=user_api_key_dict.user_id,
                litellm_proxy_admin_name=LITELLM_PROXY_ADMIN_NAME,
            )
        )

        return model_response

    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.add_new_model(): Exception occured - {}".format(
                str(e)
            )
        )
        # Re-wrap HTTP errors as ProxyException for a consistent error shape.
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
||||
|
||||
|
||||
#### MODEL MANAGEMENT ####
@router.post(
    "/model/update",
    description="Edit existing model params",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_model(
    model_params: updateDeployment,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Old endpoint for model update. Makes a PUT request.

    Use `/model/{model_id}/update` to PATCH the stored model in db.
    """
    # Lazy import so we read the live runtime globals, not import-time values.
    # NOTE(review): LITELLM_PROXY_ADMIN_NAME is also imported at module level
    # from litellm.constants; this local import shadows it with proxy_server's
    # value — confirm the two are identical.
    from litellm.proxy.proxy_server import (
        LITELLM_PROXY_ADMIN_NAME,
        llm_router,
        premium_user,
        prisma_client,
        store_model_in_db,
    )

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )

        _model_id = None
        _model_info = getattr(model_params, "model_info", None)
        if _model_info is None:
            raise Exception("model_info not provided")

        _model_id = _model_info.id
        if _model_id is None:
            raise Exception("model_info.id not provided")

        _existing_litellm_params = (
            await prisma_client.db.litellm_proxymodeltable.find_unique(
                where={"model_id": _model_id}
            )
        )

        if _existing_litellm_params is None:
            # Distinguish "model exists only in config" (editable via
            # /model/new) from "model does not exist at all".
            if (
                llm_router is not None
                and llm_router.get_deployment(model_id=_model_id) is not None
            ):
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Can't edit model. Model in config. Store model in db via `/model/new`. to edit."
                    },
                )
            else:
                raise Exception("model not found")
        deployment = Deployment(**_existing_litellm_params.model_dump())

        await ModelManagementAuthChecks.can_user_make_model_call(
            model_params=deployment,
            user_api_key_dict=user_api_key_dict,
            prisma_client=prisma_client,
            premium_user=premium_user,
        )

        # update DB
        if store_model_in_db is True:
            _existing_litellm_params_dict = dict(
                _existing_litellm_params.litellm_params
            )

            if model_params.litellm_params is None:
                raise Exception("litellm_params not provided")

            _new_litellm_params_dict = model_params.litellm_params.dict(
                exclude_none=True
            )

            ### ENCRYPT PARAMS ###
            for k, v in _new_litellm_params_dict.items():
                encrypted_value = encrypt_value_helper(value=v)
                model_params.litellm_params[k] = encrypted_value

            ### MERGE WITH EXISTING DATA ###
            # Incoming non-None values win; otherwise keep the stored value.
            merged_dictionary = {}
            _mp = model_params.litellm_params.dict()

            for key, value in _mp.items():
                if value is not None:
                    merged_dictionary[key] = value
                elif (
                    key in _existing_litellm_params_dict
                    and _existing_litellm_params_dict[key] is not None
                ):
                    merged_dictionary[key] = _existing_litellm_params_dict[key]
                else:
                    pass

            _data: dict = {
                "litellm_params": json.dumps(merged_dictionary),  # type: ignore
                "updated_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
            }
            model_response = await prisma_client.db.litellm_proxymodeltable.update(
                where={"model_id": _model_id},
                data=_data,  # type: ignore
            )

            ## CREATE AUDIT LOG ##
            # Fire-and-forget: audit logging must not block the response.
            asyncio.create_task(
                create_object_audit_log(
                    object_id=_model_id,
                    action="updated",
                    user_api_key_dict=user_api_key_dict,
                    table_name=LitellmTableNames.PROXY_MODEL_TABLE_NAME,
                    before_value=(
                        _existing_litellm_params.model_dump_json(exclude_none=True)
                        if isinstance(_existing_litellm_params, BaseModel)
                        else None
                    ),
                    after_value=(
                        model_response.model_dump_json(exclude_none=True)
                        if isinstance(model_response, BaseModel)
                        else None
                    ),
                    litellm_changed_by=user_api_key_dict.user_id,
                    litellm_proxy_admin_name=LITELLM_PROXY_ADMIN_NAME,
                )
            )

            return model_response
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.update_model(): Exception occured - {}".format(
                str(e)
            )
        )
        # Re-wrap HTTP errors as ProxyException for a consistent error shape.
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
||||
|
||||
|
||||
def _deduplicate_litellm_router_models(models: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Deduplicate models based on their model_info.id field.
|
||||
Returns a list of unique models keeping only the first occurrence of each model ID.
|
||||
|
||||
Args:
|
||||
models: List of model dictionaries containing model_info
|
||||
|
||||
Returns:
|
||||
List of deduplicated model dictionaries
|
||||
"""
|
||||
seen_ids = set()
|
||||
unique_models = []
|
||||
for model in models:
|
||||
model_id = model.get("model_info", {}).get("id", None)
|
||||
if model_id is not None and model_id not in seen_ids:
|
||||
unique_models.append(model)
|
||||
seen_ids.add(model_id)
|
||||
return unique_models
|
Loading…
Add table
Add a link
Reference in a new issue