feat(proxy_server.py): Enabling Admin to control general settings on proxy ui

This commit is contained in:
Krrish Dholakia 2024-05-15 15:26:57 -07:00
parent d9ad7c6218
commit 6a357b4275
3 changed files with 515 additions and 174 deletions

View file

@ -234,6 +234,7 @@ class SpecialModelNames(enum.Enum):
class CommonProxyErrors(enum.Enum):
db_not_connected_error = "DB not connected"
no_llm_router = "No models configured on proxy"
not_allowed_access = "Admin-only endpoint. Not allowed to access this."
@app.exception_handler(ProxyException)
@ -9389,7 +9390,7 @@ async def auth_callback(request: Request):
return RedirectResponse(url=litellm_dashboard_ui)
#### BASIC ENDPOINTS ####
#### CONFIG MANAGEMENT ####
@router.post(
"/config/update",
tags=["config.yaml"],
@ -9525,6 +9526,219 @@ async def update_config(config_info: ConfigYAML):
)
### CONFIG GENERAL SETTINGS
"""
- Update config settings
- Get config settings
Keep it more precise, to prevent overwrite other values unintentially
"""
@router.post(
"/config/field/update",
tags=["config.yaml"],
dependencies=[Depends(user_api_key_auth)],
)
async def update_config_general_settings(
data: ConfigFieldUpdate,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Update a specific field in litellm general settings
"""
global prisma_client
## VALIDATION ##
"""
- Check if prisma_client is None
- Check if user allowed to call this endpoint (admin-only)
- Check if param in general settings
- Check if config value is valid type
"""
if prisma_client is None:
raise HTTPException(
status_code=400,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != "proxy_admin":
raise HTTPException(
status_code=400,
detail={"error": CommonProxyErrors.not_allowed_access.value},
)
if data.field_name not in ConfigGeneralSettings.model_fields:
raise HTTPException(
status_code=400,
detail={"error": "Invalid field={} passed in.".format(data.field_name)},
)
try:
cgs = ConfigGeneralSettings(**{data.field_name: data.field_value})
except:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid type of field value={} passed in.".format(
type(data.field_value),
)
},
)
## get general settings from db
db_general_settings = await prisma_client.db.litellm_config.find_first(
where={"param_name": "general_settings"}
)
### update value
if db_general_settings is None or db_general_settings.param_value is None:
general_settings = {}
else:
general_settings = dict(db_general_settings.param_value)
## update db
general_settings[data.field_name] = data.field_value
response = await prisma_client.db.litellm_config.upsert(
where={"param_name": "general_settings"},
data={
"create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore
"update": {"param_value": json.dumps(general_settings)}, # type: ignore
},
)
return response
@router.get(
"/config/field/info",
tags=["config.yaml"],
dependencies=[Depends(user_api_key_auth)],
)
async def get_config_general_settings(
field_name: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global prisma_client
## VALIDATION ##
"""
- Check if prisma_client is None
- Check if user allowed to call this endpoint (admin-only)
- Check if param in general settings
"""
if prisma_client is None:
raise HTTPException(
status_code=400,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != "proxy_admin":
raise HTTPException(
status_code=400,
detail={"error": CommonProxyErrors.not_allowed_access.value},
)
if field_name not in ConfigGeneralSettings.model_fields:
raise HTTPException(
status_code=400,
detail={"error": "Invalid field={} passed in.".format(field_name)},
)
## get general settings from db
db_general_settings = await prisma_client.db.litellm_config.find_first(
where={"param_name": "general_settings"}
)
### pop the value
if db_general_settings is None or db_general_settings.param_value is None:
raise HTTPException(
status_code=400,
detail={"error": "Field name={} not in DB".format(field_name)},
)
else:
general_settings = dict(db_general_settings.param_value)
if field_name in general_settings:
return {
"field_name": field_name,
"field_value": general_settings[field_name],
}
else:
raise HTTPException(
status_code=400,
detail={"error": "Field name={} not in DB".format(field_name)},
)
@router.post(
"/config/field/delete",
tags=["config.yaml"],
dependencies=[Depends(user_api_key_auth)],
)
async def delete_config_general_settings(
field_name: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Delete the db value of this field in litellm general settings. Resets it to it's initial default value on litellm.
"""
global prisma_client
## VALIDATION ##
"""
- Check if prisma_client is None
- Check if user allowed to call this endpoint (admin-only)
- Check if param in general settings
"""
if prisma_client is None:
raise HTTPException(
status_code=400,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if user_api_key_dict.user_role != "proxy_admin":
raise HTTPException(
status_code=400,
detail={"error": CommonProxyErrors.not_allowed_access.value},
)
if field_name not in ConfigGeneralSettings.model_fields:
raise HTTPException(
status_code=400,
detail={"error": "Invalid field={} passed in.".format(field_name)},
)
## get general settings from db
db_general_settings = await prisma_client.db.litellm_config.find_first(
where={"param_name": "general_settings"}
)
### pop the value
if db_general_settings is None or db_general_settings.param_value is None:
raise HTTPException(
status_code=400,
detail={"error": "Field name={} not in config".format(field_name)},
)
else:
general_settings = dict(db_general_settings.param_value)
## update db
general_settings.pop(field_name)
response = await prisma_client.db.litellm_config.upsert(
where={"param_name": "general_settings"},
data={
"create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore
"update": {"param_value": json.dumps(general_settings)}, # type: ignore
},
)
return response
@router.get(
"/get/config/callbacks",
tags=["config.yaml"],
@ -9692,6 +9906,7 @@ async def config_yaml_endpoint(config_info: ConfigYAML):
return {"hello": "world"}
#### BASIC ENDPOINTS ####
@router.get(
"/test",
tags=["health"],