diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 3e43692158..65a6a1794d 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1919,7 +1919,9 @@ class ProxyException(Exception):
 
 
 class CommonProxyErrors(str, enum.Enum):
-    db_not_connected_error = "DB not connected"
+    db_not_connected_error = (
+        "DB not connected. See https://docs.litellm.ai/docs/proxy/virtual_keys"
+    )
     no_llm_router = "No models configured on proxy"
     not_allowed_access = "Admin-only endpoint. Not allowed to access this."
     not_premium_user = "You must be a LiteLLM Enterprise user to use this feature. If you have a license please set `LITELLM_LICENSE` in your env. Get a 7 day trial key here: https://www.litellm.ai/#trial. \nPricing: https://www.litellm.ai/#pricing"
@@ -1940,6 +1942,7 @@ class ProxyErrorTypes(str, enum.Enum):
     internal_server_error = "internal_server_error"
     bad_request_error = "bad_request_error"
     not_found_error = "not_found_error"
+    validation_error = "bad_request_error"
 
 
 DB_CONNECTION_ERROR_TYPES = (httpx.ConnectError, httpx.ReadError, httpx.ReadTimeout)
@@ -2407,3 +2410,11 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
         )
 
         super().__init__(**kwargs)
+
+
+class PrismaCompatibleUpdateDBModel(TypedDict, total=False):
+    model_name: str
+    litellm_params: str
+    model_info: str
+    updated_at: str
+    updated_by: str
diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index c521896547..4cef9696fd 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -93,7 +93,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             else:
                 raise HTTPException(
                     status_code=429,
-                    detail=f"LiteLLM Rate Limit Handler for rate limit type = {rate_limit_type}. Crossed TPM, RPM Limit. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}",
+                    detail=f"LiteLLM Rate Limit Handler for rate limit type = {rate_limit_type}. Crossed TPM / RPM / Max Parallel Request Limit. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}, current max_parallel_requests: {current['current_requests']}, max_parallel_requests: {max_parallel_requests}",
                     headers={"retry-after": str(self.time_to_next_minute())},
                 )
             return new_val
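For context on the richer 429 above: a minimal client-side sketch of surfacing the new counters. The URL, key, and request body are illustrative assumptions, and the exact error-body shape depends on the proxy's exception handler.

```python
import httpx

# Illustrative values -- adjust the proxy URL and virtual key for your setup.
resp = httpx.post(
    "http://localhost:4000/v1/chat/completions",
    headers={"Authorization": "Bearer sk-12345"},
    json={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
)
if resp.status_code == 429:
    # The detail string now reports current max_parallel_requests alongside
    # the TPM/RPM counters, and a retry-after header is set.
    print(resp.text)
    print("retry after:", resp.headers.get("retry-after"), "seconds")
```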
+""" + +#### MODEL MANAGEMENT #### + +import json +from typing import Optional, cast + +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel + +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ( + CommonProxyErrors, + PrismaCompatibleUpdateDBModel, + ProxyErrorTypes, + ProxyException, + UserAPIKeyAuth, +) +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper +from litellm.proxy.utils import PrismaClient +from litellm.types.router import ( + Deployment, + DeploymentTypedDict, + LiteLLMParamsTypedDict, + updateDeployment, +) +from litellm.utils import get_utc_datetime + +router = APIRouter() + + +async def get_db_model( + model_id: str, prisma_client: PrismaClient +) -> Optional[Deployment]: + db_model = cast( + Optional[BaseModel], + await prisma_client.db.litellm_proxymodeltable.find_unique( + where={"model_id": model_id} + ), + ) + + if not db_model: + return None + + deployment_pydantic_obj = Deployment(**db_model.model_dump(exclude_none=True)) + return deployment_pydantic_obj + + +def update_db_model( + db_model: Deployment, updated_patch: updateDeployment +) -> PrismaCompatibleUpdateDBModel: + merged_deployment_dict = DeploymentTypedDict( + model_name=db_model.model_name, + litellm_params=LiteLLMParamsTypedDict( + **db_model.litellm_params.model_dump(exclude_none=True) # type: ignore + ), + ) + # update model name + if updated_patch.model_name: + merged_deployment_dict["model_name"] = updated_patch.model_name + + # update litellm params + if updated_patch.litellm_params: + # Encrypt any sensitive values + encrypted_params = { + k: encrypt_value_helper(v) + for k, v in updated_patch.litellm_params.model_dump( + exclude_none=True + ).items() + } + + merged_deployment_dict["litellm_params"].update(encrypted_params) # type: ignore + + # update model info + if updated_patch.model_info: + if "model_info" not in merged_deployment_dict: + merged_deployment_dict["model_info"] = {} + merged_deployment_dict["model_info"].update( + updated_patch.model_info.model_dump(exclude_none=True) + ) + + # convert to prisma compatible format + + prisma_compatible_model_dict = PrismaCompatibleUpdateDBModel() + if "model_name" in merged_deployment_dict: + prisma_compatible_model_dict["model_name"] = merged_deployment_dict[ + "model_name" + ] + + if "litellm_params" in merged_deployment_dict: + prisma_compatible_model_dict["litellm_params"] = json.dumps( + merged_deployment_dict["litellm_params"] + ) + + if "model_info" in merged_deployment_dict: + prisma_compatible_model_dict["model_info"] = json.dumps( + merged_deployment_dict["model_info"] + ) + return prisma_compatible_model_dict + + +@router.patch( + "/model/{model_id}/update", + tags=["model management"], + dependencies=[Depends(user_api_key_auth)], +) +async def patch_model( + model_id: str, # Get model_id from path parameter + patch_data: updateDeployment, # Create a specific schema for PATCH operations + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + PATCH Endpoint for partial model updates. + + Only updates the fields specified in the request while preserving other existing values. + Follows proper PATCH semantics by only modifying provided fields. 
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 81008fa027..01c85c45c0 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -196,6 +196,9 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
 from litellm.proxy.management_endpoints.key_management_endpoints import (
     router as key_management_router,
 )
+from litellm.proxy.management_endpoints.model_management_endpoints import (
+    router as model_management_router,
+)
 from litellm.proxy.management_endpoints.organization_endpoints import (
     router as organization_router,
 )
@@ -6042,6 +6045,11 @@ async def update_model(
     model_params: updateDeployment,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
+    """
+    Old endpoint for model update. Performs a PUT-style, full-replace update.
+
+    Use `/model/{model_id}/update` to PATCH the stored model in the db instead.
+    """
     global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj
     try:
         import base64
@@ -8924,3 +8932,4 @@ app.include_router(ui_crud_endpoints_router)
 app.include_router(openai_files_router)
 app.include_router(team_callback_router)
 app.include_router(budget_management_router)
+app.include_router(model_management_router)
diff --git a/tests/local_testing/test_parallel_request_limiter.py b/tests/local_testing/test_parallel_request_limiter.py
index bd308ded3f..8ab4b5fa30 100644
--- a/tests/local_testing/test_parallel_request_limiter.py
+++ b/tests/local_testing/test_parallel_request_limiter.py
@@ -65,7 +65,39 @@ async def test_global_max_parallel_requests():
             )
             pytest.fail("Expected call to fail")
         except Exception as e:
-            pass
+            print(e)
+
+
+@pytest.mark.flaky(retries=6, delay=1)
+@pytest.mark.asyncio
+async def test_key_max_parallel_requests():
+    """
+    Ensure the error string returned contains parallel request information.
+
+    Relevant Issue: https://github.com/BerriAI/litellm/issues/8392
+    """
+    _api_key = hash_token("sk-12345")
+    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
+    local_cache = DualCache()
+    parallel_request_handler = MaxParallelRequestsHandler(
+        internal_usage_cache=InternalUsageCache(dual_cache=local_cache)
+    )
+
+    parallel_limit_reached = False
+    for _ in range(3):
+        try:
+            await parallel_request_handler.async_pre_call_hook(
+                user_api_key_dict=user_api_key_dict,
+                cache=local_cache,
+                data={},
+                call_type="",
+            )
+            await asyncio.sleep(1)
+        except Exception as e:
+            if "current max_parallel_requests" in str(e):
+                parallel_limit_reached = True
+
+    assert parallel_limit_reached
 
 
 @pytest.mark.asyncio
diff --git a/ui/litellm-dashboard/src/components/edit_model/edit_model_modal.tsx b/ui/litellm-dashboard/src/components/edit_model/edit_model_modal.tsx
index 052dad18f7..5b5ecf3d8a 100644
--- a/ui/litellm-dashboard/src/components/edit_model/edit_model_modal.tsx
+++ b/ui/litellm-dashboard/src/components/edit_model/edit_model_modal.tsx
@@ -54,7 +54,7 @@ const EditModelModal: React.FC = ({
[JSX hunks lost in extraction: the <Form> markup changes inside EditModelModal could not be recovered.]
@@ ... @@
     for (const [key, value] of Object.entries(formValues)) {
       if (key !== "model_id") {
-        newLiteLLMParams[key] = value;
+        // Empty string means user wants to null the value
+        newLiteLLMParams[key] = value === "" ? null : value;
       } else {
-        model_info_model_id = value;
+        model_info_model_id = value === "" ? null : value;
       }
     }
-
-    let payload = {
-      litellm_params: newLiteLLMParams,
-      model_info: {
+
+    let payload: {
+      litellm_params: Record<string, any> | undefined;
+      model_info: { id: any } | undefined;
+    } = {
+      litellm_params: Object.keys(newLiteLLMParams).length > 0 ? newLiteLLMParams : undefined,
+      model_info: model_info_model_id !== undefined ? {
         id: model_info_model_id,
-      },
+      } : undefined,
     };
     console.log("handleEditSubmit payload:", payload);
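Finally, a hedged end-to-end sketch of exercising the new route. The base URL, key, and model id are placeholders; per the PATCH semantics above, any field left out of the body is preserved on the stored row.

```python
import httpx

BASE_URL = "http://localhost:4000"  # assumed proxy address
MODEL_ID = "<db-model-id>"  # placeholder: id of a model stored via /model/new
HEADERS = {"Authorization": "Bearer sk-1234"}  # assumed admin virtual key

# Only litellm_params.rpm changes; model_name and other params are preserved.
resp = httpx.patch(
    f"{BASE_URL}/model/{MODEL_ID}/update",
    headers=HEADERS,
    json={"litellm_params": {"rpm": 500}},
)
resp.raise_for_status()
print(resp.json())
```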