forked from phoenix/litellm-mirror
Merge pull request #3660 from BerriAI/litellm_proxy_ui_general_settings
feat(proxy_server.py): Enabling Admin to control general settings on proxy ui
This commit is contained in:
commit
25e4b34574
7 changed files with 1099 additions and 283 deletions
|
@ -234,6 +234,7 @@ class SpecialModelNames(enum.Enum):
|
|||
class CommonProxyErrors(enum.Enum):
    """Canonical error messages reused by multiple proxy endpoints."""

    # Returned when an endpoint needs the database but no client is connected.
    db_not_connected_error = "DB not connected"
    # Message for requests made while no models are configured on the proxy.
    no_llm_router = "No models configured on proxy"
    # Returned when a caller without the proxy-admin role hits an admin-only route.
    not_allowed_access = "Admin-only endpoint. Not allowed to access this."
|
||||
|
||||
|
||||
@app.exception_handler(ProxyException)
|
||||
|
@ -566,9 +567,9 @@ async def user_api_key_auth(
|
|||
#### ELSE ####
|
||||
if master_key is None:
|
||||
if isinstance(api_key, str):
|
||||
return UserAPIKeyAuth(api_key=api_key)
|
||||
return UserAPIKeyAuth(api_key=api_key, user_role="proxy_admin")
|
||||
else:
|
||||
return UserAPIKeyAuth()
|
||||
return UserAPIKeyAuth(user_role="proxy_admin")
|
||||
elif api_key is None: # only require api key if master key is set
|
||||
raise Exception("No api key passed in.")
|
||||
elif api_key == "":
|
||||
|
@ -659,6 +660,7 @@ async def user_api_key_auth(
|
|||
verbose_proxy_logger.debug("Token from db: %s", valid_token)
|
||||
elif valid_token is not None:
|
||||
verbose_proxy_logger.debug("API Key Cache Hit!")
|
||||
|
||||
user_id_information = None
|
||||
if valid_token:
|
||||
# Got Valid Token from Cache, DB
|
||||
|
@ -1187,7 +1189,18 @@ async def user_api_key_auth(
|
|||
# No token was found when looking up in the DB
|
||||
raise Exception("Invalid token passed")
|
||||
if valid_token_dict is not None:
|
||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||
if user_id_information is not None and _is_user_proxy_admin(
|
||||
user_id_information
|
||||
):
|
||||
return UserAPIKeyAuth(
|
||||
api_key=api_key, user_role="proxy_admin", **valid_token_dict
|
||||
)
|
||||
elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value:
|
||||
return UserAPIKeyAuth(
|
||||
api_key=api_key, user_role="app_owner", **valid_token_dict
|
||||
)
|
||||
else:
|
||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||
else:
|
||||
raise Exception()
|
||||
except Exception as e:
|
||||
|
@ -2795,7 +2808,19 @@ class ProxyConfig:
|
|||
"Error setting env variable: %s - %s", k, str(e)
|
||||
)
|
||||
|
||||
# general_settings
|
||||
# router settings
|
||||
if llm_router is not None and prisma_client is not None:
|
||||
db_router_settings = await prisma_client.db.litellm_config.find_first(
|
||||
where={"param_name": "router_settings"}
|
||||
)
|
||||
if (
|
||||
db_router_settings is not None
|
||||
and db_router_settings.param_value is not None
|
||||
):
|
||||
_router_settings = db_router_settings.param_value
|
||||
llm_router.update_settings(**_router_settings)
|
||||
|
||||
## ALERTING ## [TODO] move this to the _update_general_settings() block
|
||||
_general_settings = config_data.get("general_settings", {})
|
||||
if "alerting" in _general_settings:
|
||||
general_settings["alerting"] = _general_settings["alerting"]
|
||||
|
@ -2819,17 +2844,24 @@ class ProxyConfig:
|
|||
alert_to_webhook_url=general_settings["alert_to_webhook_url"]
|
||||
)
|
||||
|
||||
# router settings
|
||||
if llm_router is not None and prisma_client is not None:
|
||||
db_router_settings = await prisma_client.db.litellm_config.find_first(
|
||||
where={"param_name": "router_settings"}
|
||||
)
|
||||
if (
|
||||
db_router_settings is not None
|
||||
and db_router_settings.param_value is not None
|
||||
):
|
||||
_router_settings = db_router_settings.param_value
|
||||
llm_router.update_settings(**_router_settings)
|
||||
async def _update_general_settings(self, db_general_settings: Optional[Json]):
    """
    Copy the parallel-request limits stored in the DB into the in-memory
    `general_settings`.

    Only `max_parallel_requests` and `global_max_parallel_requests` are
    synced; a key that is absent from the DB value leaves the current
    in-memory setting untouched. A `None` DB value is a no-op.
    """
    global general_settings

    if db_general_settings is None:
        return

    stored_settings = dict(db_general_settings)

    ## MAX PARALLEL REQUESTS ##
    for setting_name in ("max_parallel_requests", "global_max_parallel_requests"):
        if setting_name in stored_settings:
            general_settings[setting_name] = stored_settings[setting_name]
|
||||
|
||||
async def add_deployment(
|
||||
self,
|
||||
|
@ -2837,7 +2869,7 @@ class ProxyConfig:
|
|||
proxy_logging_obj: ProxyLogging,
|
||||
):
|
||||
"""
|
||||
- Check db for new models (last 10 most recently updated)
|
||||
- Check db for new models
|
||||
- Check if model id's in router already
|
||||
- If not, add to router
|
||||
"""
|
||||
|
@ -2850,9 +2882,21 @@ class ProxyConfig:
|
|||
)
|
||||
verbose_proxy_logger.debug(f"llm_router: {llm_router}")
|
||||
new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||
# update llm router
|
||||
await self._update_llm_router(
|
||||
new_models=new_models, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
db_general_settings = await prisma_client.db.litellm_config.find_first(
|
||||
where={"param_name": "general_settings"}
|
||||
)
|
||||
|
||||
# update general settings
|
||||
if db_general_settings is not None:
|
||||
await self._update_general_settings(
|
||||
db_general_settings=db_general_settings.param_value,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"{}\nTraceback:{}".format(str(e), traceback.format_exc())
|
||||
|
@ -3052,27 +3096,6 @@ async def generate_key_helper_fn(
|
|||
data=key_data, table_name="key"
|
||||
)
|
||||
key_data["token_id"] = getattr(create_key_response, "token", None)
|
||||
elif custom_db_client is not None:
|
||||
if table_name is None or table_name == "user":
|
||||
## CREATE USER (If necessary)
|
||||
verbose_proxy_logger.debug(
|
||||
"CustomDBClient: Creating User= %s", user_data
|
||||
)
|
||||
user_row = await custom_db_client.insert_data(
|
||||
value=user_data, table_name="user"
|
||||
)
|
||||
if user_row is None:
|
||||
# GET USER ROW
|
||||
user_row = await custom_db_client.get_data(
|
||||
key=user_id, table_name="user" # type: ignore
|
||||
)
|
||||
|
||||
## use default user model list if no key-specific model list provided
|
||||
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
|
||||
key_data["models"] = user_row.models
|
||||
## CREATE KEY
|
||||
verbose_proxy_logger.debug("CustomDBClient: Creating Key= %s", key_data)
|
||||
await custom_db_client.insert_data(value=key_data, table_name="key")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
if isinstance(e, HTTPException):
|
||||
|
@ -3668,6 +3691,9 @@ async def chat_completion(
|
|||
data["metadata"]["user_api_key_alias"] = getattr(
|
||||
user_api_key_dict, "key_alias", None
|
||||
)
|
||||
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||
data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
|
||||
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||
|
@ -3938,6 +3964,9 @@ async def completion(
|
|||
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||
user_api_key_dict, "team_id", None
|
||||
)
|
||||
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
data["metadata"]["user_api_key_team_alias"] = getattr(
|
||||
user_api_key_dict, "team_alias", None
|
||||
)
|
||||
|
@ -4134,6 +4163,9 @@ async def embeddings(
|
|||
data["metadata"]["user_api_key_alias"] = getattr(
|
||||
user_api_key_dict, "key_alias", None
|
||||
)
|
||||
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||
user_api_key_dict, "team_id", None
|
||||
|
@ -4338,6 +4370,9 @@ async def image_generation(
|
|||
data["metadata"]["user_api_key_alias"] = getattr(
|
||||
user_api_key_dict, "key_alias", None
|
||||
)
|
||||
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||
user_api_key_dict, "team_id", None
|
||||
|
@ -4518,6 +4553,9 @@ async def audio_transcriptions(
|
|||
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||
user_api_key_dict, "team_id", None
|
||||
)
|
||||
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
data["metadata"]["user_api_key_team_alias"] = getattr(
|
||||
user_api_key_dict, "team_alias", None
|
||||
)
|
||||
|
@ -4715,6 +4753,9 @@ async def moderations(
|
|||
"authorization", None
|
||||
) # do not store the original `sk-..` api key in the db
|
||||
data["metadata"]["headers"] = _headers
|
||||
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
data["metadata"]["user_api_key_alias"] = getattr(
|
||||
user_api_key_dict, "key_alias", None
|
||||
)
|
||||
|
@ -9405,7 +9446,7 @@ async def auth_callback(request: Request):
|
|||
return RedirectResponse(url=litellm_dashboard_ui)
|
||||
|
||||
|
||||
#### BASIC ENDPOINTS ####
|
||||
#### CONFIG MANAGEMENT ####
|
||||
@router.post(
|
||||
"/config/update",
|
||||
tags=["config.yaml"],
|
||||
|
@ -9541,6 +9582,299 @@ async def update_config(config_info: ConfigYAML):
|
|||
)
|
||||
|
||||
|
||||
### CONFIG GENERAL SETTINGS
|
||||
"""
|
||||
- Update config settings
|
||||
- Get config settings
|
||||
|
||||
Keep each update narrowly scoped, to prevent overwriting other values unintentionally
|
||||
"""
|
||||
|
||||
|
||||
@router.post(
    "/config/field/update",
    tags=["config.yaml"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_config_general_settings(
    data: ConfigFieldUpdate,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update a specific field in litellm general settings.

    The new value is merged into the `general_settings` row of the config
    table, so every other field already stored in that row is preserved.

    Parameters:
    - data: carries `field_name` (must be a `ConfigGeneralSettings` field)
      and `field_value` (must validate against that field).
    - user_api_key_dict: the authenticated caller; must have the
      `proxy_admin` role.

    Returns the result of the DB upsert.

    Raises `HTTPException` (400) when no DB is connected, the caller is not
    a proxy admin, the field name is unknown, or the value fails validation.
    """
    global prisma_client

    ## VALIDATION ##
    # - Check if prisma_client is None
    # - Check if user allowed to call this endpoint (admin-only)
    # - Check if param in general settings
    # - Check if config value is valid type
    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    if user_api_key_dict.user_role != "proxy_admin":
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.not_allowed_access.value},
        )

    if data.field_name not in ConfigGeneralSettings.model_fields:
        raise HTTPException(
            status_code=400,
            detail={"error": "Invalid field={} passed in.".format(data.field_name)},
        )

    try:
        # Built only for its validation side effect; the instance is discarded.
        ConfigGeneralSettings(**{data.field_name: data.field_value})
    except Exception:
        # `except Exception` (not a bare `except`) so task cancellation and
        # interpreter-exit signals are not converted into a 400 response.
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Invalid type of field value={} passed in.".format(
                    type(data.field_value),
                )
            },
        )

    ## get general settings from db
    db_general_settings = await prisma_client.db.litellm_config.find_first(
        where={"param_name": "general_settings"}
    )

    ### update value
    # Local name deliberately differs from the module-level `general_settings`
    # so the in-memory settings are never confused with the DB copy.
    if db_general_settings is None or db_general_settings.param_value is None:
        updated_settings = {}
    else:
        updated_settings = dict(db_general_settings.param_value)

    updated_settings[data.field_name] = data.field_value

    ## update db
    serialized_settings = json.dumps(updated_settings)
    response = await prisma_client.db.litellm_config.upsert(
        where={"param_name": "general_settings"},
        data={
            "create": {"param_name": "general_settings", "param_value": serialized_settings},  # type: ignore
            "update": {"param_value": serialized_settings},  # type: ignore
        },
    )

    return response
|
||||
|
||||
|
||||
@router.get(
    "/config/field/info",
    tags=["config.yaml"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_config_general_settings(
    field_name: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Return the value stored in the DB for one general-settings field.

    Responds with `{"field_name": ..., "field_value": ...}`.

    Raises `HTTPException` (400) when no DB is connected, the caller is not
    a proxy admin, the field name is not a `ConfigGeneralSettings` field, or
    the field has no value stored in the DB.
    """
    global prisma_client

    ## VALIDATION ##
    # - Check if prisma_client is None
    # - Check if user allowed to call this endpoint (admin-only)
    # - Check if param in general settings
    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    if user_api_key_dict.user_role != "proxy_admin":
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.not_allowed_access.value},
        )

    if field_name not in ConfigGeneralSettings.model_fields:
        raise HTTPException(
            status_code=400,
            detail={"error": "Invalid field={} passed in.".format(field_name)},
        )

    ## get general settings from db
    db_row = await prisma_client.db.litellm_config.find_first(
        where={"param_name": "general_settings"}
    )

    ### look up the value
    stored_settings = None
    if db_row is not None and db_row.param_value is not None:
        stored_settings = dict(db_row.param_value)

    # A missing row, an empty row, and a row without this field all produce
    # the same "not in DB" error.
    if stored_settings is None or field_name not in stored_settings:
        raise HTTPException(
            status_code=400,
            detail={"error": "Field name={} not in DB".format(field_name)},
        )

    return {
        "field_name": field_name,
        "field_value": stored_settings[field_name],
    }
|
||||
|
||||
|
||||
@router.get(
    "/config/list",
    tags=["config.yaml"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_config_list(
    config_type: Literal["general_settings"],
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> List[ConfigList]:
    """
    List the available fields + current values for a given type of setting
    (currently just 'general_settings').

    Each entry reports the value from the in-memory `general_settings` and
    a `stored_in_db` flag: True when the field is present in the DB row,
    False when it is only present in memory, None when it is set nowhere.

    Raises `HTTPException` (400) when no DB is connected or the caller is
    not a proxy admin.
    """
    global prisma_client, general_settings

    ## VALIDATION ##
    # - Check if prisma_client is None
    # - Check if user allowed to call this endpoint (admin-only)
    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    if user_api_key_dict.user_role != "proxy_admin":
        denied_message = "{}, your role={}".format(
            CommonProxyErrors.not_allowed_access.value,
            user_api_key_dict.user_role,
        )
        raise HTTPException(status_code=400, detail={"error": denied_message})

    ## get general settings from db
    db_general_settings = await prisma_client.db.litellm_config.find_first(
        where={"param_name": "general_settings"}
    )

    if db_general_settings is None or db_general_settings.param_value is None:
        db_general_settings_dict = {}
    else:
        db_general_settings_dict = dict(db_general_settings.param_value)

    # Only these fields are exposed, along with the type label reported for each.
    allowed_args = {
        "max_parallel_requests": {"type": "Integer"},
        "global_max_parallel_requests": {"type": "Integer"},
    }

    return_val = []
    for field_name, field_info in ConfigGeneralSettings.model_fields.items():
        if field_name not in allowed_args:
            continue

        if field_name in db_general_settings_dict:
            stored_in_db = True
        elif field_name in general_settings:
            stored_in_db = False
        else:
            stored_in_db = None

        return_val.append(
            ConfigList(
                field_name=field_name,
                field_type=allowed_args[field_name]["type"],
                field_description=field_info.description or "",
                field_value=general_settings.get(field_name, None),
                stored_in_db=stored_in_db,
            )
        )

    return return_val
|
||||
|
||||
|
||||
@router.post(
    "/config/field/delete",
    tags=["config.yaml"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_config_general_settings(
    data: ConfigFieldDelete,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete the db value of this field in litellm general settings. Resets it
    to its initial default value on litellm.

    The field is removed from the `general_settings` row and the remaining
    fields are written back; removing a field that is not stored is not an
    error. Returns the result of the DB upsert.

    Raises `HTTPException` (400) when no DB is connected, the caller is not
    a proxy admin, the field name is not a `ConfigGeneralSettings` field, or
    no general settings are stored in the DB at all.
    """
    global prisma_client

    ## VALIDATION ##
    # - Check if prisma_client is None
    # - Check if user allowed to call this endpoint (admin-only)
    # - Check if param in general settings
    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    if user_api_key_dict.user_role != "proxy_admin":
        denied_message = "{}, your role={}".format(
            CommonProxyErrors.not_allowed_access.value,
            user_api_key_dict.user_role,
        )
        raise HTTPException(status_code=400, detail={"error": denied_message})

    if data.field_name not in ConfigGeneralSettings.model_fields:
        raise HTTPException(
            status_code=400,
            detail={"error": "Invalid field={} passed in.".format(data.field_name)},
        )

    ## get general settings from db
    db_row = await prisma_client.db.litellm_config.find_first(
        where={"param_name": "general_settings"}
    )

    if db_row is None or db_row.param_value is None:
        raise HTTPException(
            status_code=400,
            detail={"error": "Field name={} not in config".format(data.field_name)},
        )

    ### pop the value
    remaining_settings = dict(db_row.param_value)
    remaining_settings.pop(data.field_name, None)

    ## update db
    serialized_settings = json.dumps(remaining_settings)
    response = await prisma_client.db.litellm_config.upsert(
        where={"param_name": "general_settings"},
        data={
            "create": {"param_name": "general_settings", "param_value": serialized_settings},  # type: ignore
            "update": {"param_value": serialized_settings},  # type: ignore
        },
    )

    return response
|
||||
|
||||
|
||||
@router.get(
|
||||
"/get/config/callbacks",
|
||||
tags=["config.yaml"],
|
||||
|
@ -9708,6 +10042,7 @@ async def config_yaml_endpoint(config_info: ConfigYAML):
|
|||
return {"hello": "world"}
|
||||
|
||||
|
||||
#### BASIC ENDPOINTS ####
|
||||
@router.get(
|
||||
"/test",
|
||||
tags=["health"],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue