Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
Allow editing model api key + provider on UI (#8406)
* fix(parallel_request_limiter.py): add back parallel request information to the max parallel request limiter. Resolves https://github.com/BerriAI/litellm/issues/8392
* test: mark flaky test to handle time-based tracking issues
* feat(model_management_endpoints.py): expose new `/model/{model_id}/update` PATCH endpoint. Allows updating specific values of a model in db - makes it easy for the admin to know this by calling it a PATCH
* feat(edit_model_modal.tsx): allow user to update llm provider + api key on the ui
* fix: fix linting error
This commit is contained in: parent 0d2e723e95, commit e4411e4815
7 changed files with 285 additions and 11 deletions
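For illustration, a minimal sketch of how a client could call the new PATCH endpoint this commit adds. The proxy URL, admin key, model id, and new param values below are placeholders, not values from this commit:

# Sketch: partially updating a DB-stored model via the new endpoint.
# Assumes a LiteLLM proxy at localhost:4000; key/id/values are placeholders.
import httpx

response = httpx.patch(
    "http://localhost:4000/model/example-model-id/update",
    headers={"Authorization": "Bearer sk-admin-key"},
    json={
        "litellm_params": {
            "custom_llm_provider": "azure",
            "api_key": "my-new-provider-key",
        }
    },
)
response.raise_for_status()
print(response.json())

Only the fields present in the JSON body are changed; everything else on the stored deployment is preserved, per the PATCH semantics described in the new endpoint's docstring.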
litellm/proxy/_types.py

@@ -1919,7 +1919,9 @@ class ProxyException(Exception):
 
 
 class CommonProxyErrors(str, enum.Enum):
-    db_not_connected_error = "DB not connected"
+    db_not_connected_error = (
+        "DB not connected. See https://docs.litellm.ai/docs/proxy/virtual_keys"
+    )
     no_llm_router = "No models configured on proxy"
     not_allowed_access = "Admin-only endpoint. Not allowed to access this."
     not_premium_user = "You must be a LiteLLM Enterprise user to use this feature. If you have a license please set `LITELLM_LICENSE` in your env. Get a 7 day trial key here: https://www.litellm.ai/#trial. \nPricing: https://www.litellm.ai/#pricing"
@@ -1940,6 +1942,7 @@ class ProxyErrorTypes(str, enum.Enum):
     internal_server_error = "internal_server_error"
     bad_request_error = "bad_request_error"
     not_found_error = "not_found_error"
+    validation_error = "bad_request_error"
 
 
 DB_CONNECTION_ERROR_TYPES = (httpx.ConnectError, httpx.ReadError, httpx.ReadTimeout)
@@ -2407,3 +2410,11 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
             )
         super().__init__(**kwargs)
+
+
+class PrismaCompatibleUpdateDBModel(TypedDict, total=False):
+    model_name: str
+    litellm_params: str
+    model_info: str
+    updated_at: str
+    updated_by: str
litellm/proxy/hooks/parallel_request_limiter.py

@@ -93,7 +93,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         else:
             raise HTTPException(
                 status_code=429,
-                detail=f"LiteLLM Rate Limit Handler for rate limit type = {rate_limit_type}. Crossed TPM, RPM Limit. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}",
+                detail=f"LiteLLM Rate Limit Handler for rate limit type = {rate_limit_type}. Crossed TPM / RPM / Max Parallel Request Limit. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}, current max_parallel_requests: {current['current_requests']}, max_parallel_requests: {max_parallel_requests}",
                 headers={"retry-after": str(self.time_to_next_minute())},
             )
         return new_val
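The handler sets a retry-after header (seconds until the next minute window) alongside the expanded detail string. A client-side sketch of honoring it, assuming a proxy at localhost:4000 and a placeholder key (not part of this commit):

# Sketch: backing off on the limiter's 429 using its retry-after header.
import time

import httpx

def post_with_backoff(payload: dict, max_attempts: int = 3) -> httpx.Response:
    for attempt in range(max_attempts):
        response = httpx.post(
            "http://localhost:4000/v1/chat/completions",  # assumed proxy URL
            headers={"Authorization": "Bearer sk-my-key"},  # placeholder key
            json=payload,
        )
        if response.status_code != 429 or attempt == max_attempts - 1:
            return response
        # retry-after is populated from time_to_next_minute() in the handler
        time.sleep(float(response.headers.get("retry-after", "1")))
    raise RuntimeError("unreachable for max_attempts >= 1")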
litellm/proxy/management_endpoints/model_management_endpoints.py (new file, 208 lines)
@@ -0,0 +1,208 @@ (entire file added)
"""
Allow proxy admin to add/update/delete models in the db

Currently most endpoints are in `proxy_server.py`, but those should be moved here over time.

Endpoints here:

model/{model_id}/update - PATCH endpoint for model update.
"""

#### MODEL MANAGEMENT ####

import json
from typing import Optional, cast

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
    CommonProxyErrors,
    PrismaCompatibleUpdateDBModel,
    ProxyErrorTypes,
    ProxyException,
    UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
from litellm.proxy.utils import PrismaClient
from litellm.types.router import (
    Deployment,
    DeploymentTypedDict,
    LiteLLMParamsTypedDict,
    updateDeployment,
)
from litellm.utils import get_utc_datetime

router = APIRouter()


async def get_db_model(
    model_id: str, prisma_client: PrismaClient
) -> Optional[Deployment]:
    db_model = cast(
        Optional[BaseModel],
        await prisma_client.db.litellm_proxymodeltable.find_unique(
            where={"model_id": model_id}
        ),
    )

    if not db_model:
        return None

    deployment_pydantic_obj = Deployment(**db_model.model_dump(exclude_none=True))
    return deployment_pydantic_obj


def update_db_model(
    db_model: Deployment, updated_patch: updateDeployment
) -> PrismaCompatibleUpdateDBModel:
    merged_deployment_dict = DeploymentTypedDict(
        model_name=db_model.model_name,
        litellm_params=LiteLLMParamsTypedDict(
            **db_model.litellm_params.model_dump(exclude_none=True)  # type: ignore
        ),
    )
    # update model name
    if updated_patch.model_name:
        merged_deployment_dict["model_name"] = updated_patch.model_name

    # update litellm params
    if updated_patch.litellm_params:
        # Encrypt any sensitive values
        encrypted_params = {
            k: encrypt_value_helper(v)
            for k, v in updated_patch.litellm_params.model_dump(
                exclude_none=True
            ).items()
        }

        merged_deployment_dict["litellm_params"].update(encrypted_params)  # type: ignore

    # update model info
    if updated_patch.model_info:
        if "model_info" not in merged_deployment_dict:
            merged_deployment_dict["model_info"] = {}
        merged_deployment_dict["model_info"].update(
            updated_patch.model_info.model_dump(exclude_none=True)
        )

    # convert to prisma compatible format

    prisma_compatible_model_dict = PrismaCompatibleUpdateDBModel()
    if "model_name" in merged_deployment_dict:
        prisma_compatible_model_dict["model_name"] = merged_deployment_dict[
            "model_name"
        ]

    if "litellm_params" in merged_deployment_dict:
        prisma_compatible_model_dict["litellm_params"] = json.dumps(
            merged_deployment_dict["litellm_params"]
        )

    if "model_info" in merged_deployment_dict:
        prisma_compatible_model_dict["model_info"] = json.dumps(
            merged_deployment_dict["model_info"]
        )
    return prisma_compatible_model_dict


@router.patch(
    "/model/{model_id}/update",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def patch_model(
    model_id: str,  # Get model_id from path parameter
    patch_data: updateDeployment,  # Create a specific schema for PATCH operations
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    PATCH Endpoint for partial model updates.

    Only updates the fields specified in the request while preserving other existing values.
    Follows proper PATCH semantics by only modifying provided fields.

    Args:
        model_id: The ID of the model to update
        patch_data: The fields to update and their new values
        user_api_key_dict: User authentication information

    Returns:
        Updated model information

    Raises:
        ProxyException: For various error conditions including authentication and database errors
    """
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        llm_router,
        prisma_client,
        store_model_in_db,
    )

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        # Verify model exists and is stored in DB
        if not store_model_in_db:
            raise ProxyException(
                message="Model updates only supported for DB-stored models",
                type=ProxyErrorTypes.validation_error.value,
                code=status.HTTP_400_BAD_REQUEST,
                param=None,
            )

        # Fetch existing model
        db_model = await get_db_model(model_id=model_id, prisma_client=prisma_client)

        if db_model is None:
            # Check if model exists in config but not DB
            if llm_router and llm_router.get_deployment(model_id=model_id) is not None:
                raise ProxyException(
                    message="Cannot edit config-based model. Store model in DB via /model/new first.",
                    type=ProxyErrorTypes.validation_error.value,
                    code=status.HTTP_400_BAD_REQUEST,
                    param=None,
                )
            raise ProxyException(
                message=f"Model {model_id} not found on proxy.",
                type=ProxyErrorTypes.not_found_error,
                code=status.HTTP_404_NOT_FOUND,
                param=None,
            )

        # Create update dictionary only for provided fields
        update_data = update_db_model(db_model=db_model, updated_patch=patch_data)

        # Add metadata about update
        update_data["updated_by"] = (
            user_api_key_dict.user_id or litellm_proxy_admin_name
        )
        update_data["updated_at"] = cast(str, get_utc_datetime())

        # Perform partial update
        updated_model = await prisma_client.db.litellm_proxymodeltable.update(
            where={"model_id": model_id},
            data=update_data,
        )

        return updated_model

    except Exception as e:
        verbose_proxy_logger.exception(f"Error in patch_model: {str(e)}")

        if isinstance(e, (HTTPException, ProxyException)):
            raise e

        raise ProxyException(
            message=f"Error updating model: {str(e)}",
            type=ProxyErrorTypes.internal_server_error,
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            param=None,
        )
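To make the merge semantics of update_db_model concrete, a simplified, self-contained sketch using plain dicts in place of litellm's Deployment/updateDeployment models; encrypt is a stand-in for encrypt_value_helper:

import json

# Simplified stand-in for update_db_model's PATCH merge: only the keys the
# patch provides overwrite the stored deployment; litellm_params values are
# "encrypted" first; JSON columns are serialized to strings for Prisma.
def merge_patch(stored: dict, patch: dict, encrypt=lambda v: v) -> dict:
    merged = {
        "model_name": stored["model_name"],
        "litellm_params": dict(stored["litellm_params"]),
    }
    if patch.get("model_name"):
        merged["model_name"] = patch["model_name"]
    if patch.get("litellm_params"):
        merged["litellm_params"].update(
            {k: encrypt(v) for k, v in patch["litellm_params"].items()}
        )
    return {
        "model_name": merged["model_name"],
        "litellm_params": json.dumps(merged["litellm_params"]),
    }

stored = {
    "model_name": "gpt-4",
    "litellm_params": {"model": "azure/gpt-4", "api_key": "old-key"},
}
print(merge_patch(stored, {"litellm_params": {"api_key": "new-key"}}))
# -> {'model_name': 'gpt-4',
#     'litellm_params': '{"model": "azure/gpt-4", "api_key": "new-key"}'}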
litellm/proxy/proxy_server.py

@@ -196,6 +196,9 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
 from litellm.proxy.management_endpoints.key_management_endpoints import (
     router as key_management_router,
 )
+from litellm.proxy.management_endpoints.model_management_endpoints import (
+    router as model_management_router,
+)
 from litellm.proxy.management_endpoints.organization_endpoints import (
     router as organization_router,
 )
@@ -6042,6 +6045,11 @@ async def update_model(
     model_params: updateDeployment,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
+    """
+    Old endpoint for model update. Makes a PUT request.
+
+    Use `/model/{model_id}/update` to PATCH the stored model in db.
+    """
     global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj
     try:
         import base64
@@ -8924,3 +8932,4 @@ app.include_router(ui_crud_endpoints_router)
 app.include_router(openai_files_router)
 app.include_router(team_callback_router)
 app.include_router(budget_management_router)
+app.include_router(model_management_router)
test_parallel_request_limiter.py

@@ -65,7 +65,40 @@ async def test_global_max_parallel_requests():
         )
         pytest.fail("Expected call to fail")
     except Exception as e:
-        pass
+        print(e)
+
+
+@pytest.mark.flaky(retries=6, delay=1)
+@pytest.mark.asyncio
+async def test_key_max_parallel_requests():
+    """
+    Ensure the error str returned contains parallel request information.
+
+    Relevant Issue: https://github.com/BerriAI/litellm/issues/8392
+    """
+    _api_key = "sk-12345"
+    _api_key = hash_token("sk-12345")
+    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
+    local_cache = DualCache()
+    parallel_request_handler = MaxParallelRequestsHandler(
+        internal_usage_cache=InternalUsageCache(dual_cache=local_cache)
+    )
+
+    parallel_limit_reached = False
+    for _ in range(3):
+        try:
+            await parallel_request_handler.async_pre_call_hook(
+                user_api_key_dict=user_api_key_dict,
+                cache=local_cache,
+                data={},
+                call_type="",
+            )
+            await asyncio.sleep(1)
+        except Exception as e:
+            if "current max_parallel_requests" in str(e):
+                parallel_limit_reached = True
+
+    assert parallel_limit_reached
+
+
 @pytest.mark.asyncio
edit_model_modal.tsx

@@ -54,7 +54,7 @@ const EditModelModal: React.FC<EditModelModalProps> = ({

   return (
     <Modal
-      title={"Edit Model " + model_name}
+      title={"Edit '" + model_name + "' LiteLLM Params"}
       visible={visible}
       width={800}
       footer={null}
|
@ -88,6 +88,15 @@ const EditModelModal: React.FC<EditModelModalProps> = ({
|
||||||
<Form.Item className="mt-8" label="api_base" name="api_base">
|
<Form.Item className="mt-8" label="api_base" name="api_base">
|
||||||
<TextInput />
|
<TextInput />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
<Form.Item className="mt-8" label="api_key" name="api_key">
|
||||||
|
<TextInput />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item className="mt-8" label="custom_llm_provider" name="custom_llm_provider">
|
||||||
|
<TextInput />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item className="mt-8" label="model" name="model">
|
||||||
|
<TextInput />
|
||||||
|
</Form.Item>
|
||||||
<Form.Item
|
<Form.Item
|
||||||
label="organization"
|
label="organization"
|
||||||
name="organization"
|
name="organization"
|
||||||
|
|
|
ui/litellm-dashboard/src/components/model_dashboard.tsx

@@ -418,17 +418,21 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({

     for (const [key, value] of Object.entries(formValues)) {
       if (key !== "model_id") {
-        newLiteLLMParams[key] = value;
+        // Empty string means user wants to null the value
+        newLiteLLMParams[key] = value === "" ? null : value;
       } else {
-        model_info_model_id = value;
+        model_info_model_id = value === "" ? null : value;
       }
     }

-    let payload = {
-      litellm_params: newLiteLLMParams,
-      model_info: {
+    let payload: {
+      litellm_params: Record<string, any> | undefined;
+      model_info: { id: any } | undefined;
+    } = {
+      litellm_params: Object.keys(newLiteLLMParams).length > 0 ? newLiteLLMParams : undefined,
+      model_info: model_info_model_id !== undefined ? {
         id: model_info_model_id,
-      },
+      } : undefined,
     };

     console.log("handleEditSubmit payload:", payload);