build(ui/general_settings.tsx): support updating global max parallel requests on the ui

This commit is contained in:
Krrish Dholakia 2024-05-15 19:26:57 -07:00
parent 6a357b4275
commit 153ce0d085
4 changed files with 491 additions and 110 deletions

View file

@ -567,9 +567,9 @@ async def user_api_key_auth(
#### ELSE ####
if master_key is None:
if isinstance(api_key, str):
return UserAPIKeyAuth(api_key=api_key)
return UserAPIKeyAuth(api_key=api_key, user_role="proxy_admin")
else:
return UserAPIKeyAuth()
return UserAPIKeyAuth(user_role="proxy_admin")
elif api_key is None: # only require api key if master key is set
raise Exception("No api key passed in.")
elif api_key == "":
@ -660,6 +660,7 @@ async def user_api_key_auth(
verbose_proxy_logger.debug("Token from db: %s", valid_token)
elif valid_token is not None:
verbose_proxy_logger.debug("API Key Cache Hit!")
user_id_information = None
if valid_token:
# Got Valid Token from Cache, DB
@ -1188,7 +1189,18 @@ async def user_api_key_auth(
# No token was found when looking up in the DB
raise Exception("Invalid token passed")
if valid_token_dict is not None:
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
if user_id_information is not None and _is_user_proxy_admin(
user_id_information
):
return UserAPIKeyAuth(
api_key=api_key, user_role="proxy_admin", **valid_token_dict
)
elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value:
return UserAPIKeyAuth(
api_key=api_key, user_role="app_owner", **valid_token_dict
)
else:
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception()
except Exception as e:
@ -2796,7 +2808,19 @@ class ProxyConfig:
"Error setting env variable: %s - %s", k, str(e)
)
# general_settings
# router settings
if llm_router is not None and prisma_client is not None:
db_router_settings = await prisma_client.db.litellm_config.find_first(
where={"param_name": "router_settings"}
)
if (
db_router_settings is not None
and db_router_settings.param_value is not None
):
_router_settings = db_router_settings.param_value
llm_router.update_settings(**_router_settings)
## ALERTING ## [TODO] move this to the _update_general_settings() block
_general_settings = config_data.get("general_settings", {})
if "alerting" in _general_settings:
general_settings["alerting"] = _general_settings["alerting"]
@ -2820,17 +2844,23 @@ class ProxyConfig:
alert_to_webhook_url=general_settings["alert_to_webhook_url"]
)
# router settings
if llm_router is not None and prisma_client is not None:
db_router_settings = await prisma_client.db.litellm_config.find_first(
where={"param_name": "router_settings"}
)
if (
db_router_settings is not None
and db_router_settings.param_value is not None
):
_router_settings = db_router_settings.param_value
llm_router.update_settings(**_router_settings)
async def _update_general_settings(self, db_general_settings: Optional[Json]):
    """
    Copy the DB-stored parallel-request limits into the in-memory `general_settings`.

    Args:
        db_general_settings: the `general_settings` value read from the DB,
            or None when nothing is stored. It is converted with `dict(...)`,
            so it is expected to be mapping-like.

    Only the keys `max_parallel_requests` and `global_max_parallel_requests`
    are copied, and only when present in the DB value. Keys that are absent
    leave the current in-memory setting untouched.
    """
    if db_general_settings is None:
        return

    db_values = dict(db_general_settings)

    ## MAX PARALLEL REQUESTS ##
    for setting_key in ("max_parallel_requests", "global_max_parallel_requests"):
        if setting_key in db_values:
            general_settings[setting_key] = db_values[setting_key]
async def add_deployment(
self,
@ -2838,7 +2868,7 @@ class ProxyConfig:
proxy_logging_obj: ProxyLogging,
):
"""
- Check db for new models (last 10 most recently updated)
- Check db for new models
- Check if model id's in router already
- If not, add to router
"""
@ -2851,9 +2881,21 @@ class ProxyConfig:
)
verbose_proxy_logger.debug(f"llm_router: {llm_router}")
new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
# update llm router
await self._update_llm_router(
new_models=new_models, proxy_logging_obj=proxy_logging_obj
)
db_general_settings = await prisma_client.db.litellm_config.find_first(
where={"param_name": "general_settings"}
)
# update general settings
if db_general_settings is not None:
await self._update_general_settings(
db_general_settings=db_general_settings.param_value,
)
except Exception as e:
verbose_proxy_logger.error(
"{}\nTraceback:{}".format(str(e), traceback.format_exc())
@ -3053,27 +3095,6 @@ async def generate_key_helper_fn(
data=key_data, table_name="key"
)
key_data["token_id"] = getattr(create_key_response, "token", None)
elif custom_db_client is not None:
if table_name is None or table_name == "user":
## CREATE USER (If necessary)
verbose_proxy_logger.debug(
"CustomDBClient: Creating User= %s", user_data
)
user_row = await custom_db_client.insert_data(
value=user_data, table_name="user"
)
if user_row is None:
# GET USER ROW
user_row = await custom_db_client.get_data(
key=user_id, table_name="user" # type: ignore
)
## use default user model list if no key-specific model list provided
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
key_data["models"] = user_row.models
## CREATE KEY
verbose_proxy_logger.debug("CustomDBClient: Creating Key= %s", key_data)
await custom_db_client.insert_data(value=key_data, table_name="key")
except Exception as e:
traceback.print_exc()
if isinstance(e, HTTPException):
@ -9673,13 +9694,88 @@ async def get_config_general_settings(
)
@router.get(
    "/config/list",
    tags=["config.yaml"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_config_list(
    config_type: Literal["general_settings"],
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> List[ConfigList]:
    """
    List the available fields + current values for a given type of setting (currently just 'general_settings').

    Parameters:
    - config_type: the type of setting to list. Only "general_settings" is accepted;
      the value is not otherwise read by this handler.
    - user_api_key_dict: the authenticated caller, resolved by `user_api_key_auth`.

    Returns one `ConfigList` entry for each `ConfigGeneralSettings` field that is in the
    allow-list below. `stored_in_db` is True when the field is present in the DB row,
    False when it is only present in the in-memory `general_settings`, and None when it
    is in neither. `field_value` is always read from the in-memory `general_settings`.

    Raises HTTPException (status 400) when no DB is connected or when the caller's role
    is not "proxy_admin".
    """
    global prisma_client, general_settings

    ## VALIDATION ##
    # - Check if prisma_client is None
    # - Check if user allowed to call this endpoint (admin-only)
    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    caller_role = user_api_key_dict.user_role
    if caller_role != "proxy_admin":
        raise HTTPException(
            status_code=400,
            detail={
                "error": "{}, your role={}".format(
                    CommonProxyErrors.not_allowed_access.value,
                    caller_role,
                )
            },
        )

    ## get general settings from db
    db_row = await prisma_client.db.litellm_config.find_first(
        where={"param_name": "general_settings"}
    )
    db_param_value = None if db_row is None else db_row.param_value
    persisted_settings = dict(db_param_value) if db_param_value is not None else {}

    # Fields that may be listed, mapped to the type label reported to the caller.
    allowed_args = {
        "max_parallel_requests": {"type": "Integer"},
        "global_max_parallel_requests": {"type": "Integer"},
    }

    listed_fields = []
    for field_name, field_info in ConfigGeneralSettings.model_fields.items():
        if field_name not in allowed_args:
            continue

        # The DB takes precedence when reporting where the value comes from.
        if field_name in persisted_settings:
            is_stored_in_db = True
        elif field_name in general_settings:
            is_stored_in_db = False
        else:
            is_stored_in_db = None

        listed_fields.append(
            ConfigList(
                field_name=field_name,
                field_type=allowed_args[field_name]["type"],
                field_description=field_info.description or "",
                field_value=general_settings.get(field_name, None),
                stored_in_db=is_stored_in_db,
            )
        )

    return listed_fields
@router.post(
"/config/field/delete",
tags=["config.yaml"],
dependencies=[Depends(user_api_key_auth)],
)
async def delete_config_general_settings(
field_name: str,
data: ConfigFieldDelete,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
@ -9701,13 +9797,18 @@ async def delete_config_general_settings(
if user_api_key_dict.user_role != "proxy_admin":
raise HTTPException(
status_code=400,
detail={"error": CommonProxyErrors.not_allowed_access.value},
detail={
"error": "{}, your role={}".format(
CommonProxyErrors.not_allowed_access.value,
user_api_key_dict.user_role,
)
},
)
if field_name not in ConfigGeneralSettings.model_fields:
if data.field_name not in ConfigGeneralSettings.model_fields:
raise HTTPException(
status_code=400,
detail={"error": "Invalid field={} passed in.".format(field_name)},
detail={"error": "Invalid field={} passed in.".format(data.field_name)},
)
## get general settings from db
@ -9719,14 +9820,14 @@ async def delete_config_general_settings(
if db_general_settings is None or db_general_settings.param_value is None:
raise HTTPException(
status_code=400,
detail={"error": "Field name={} not in config".format(field_name)},
detail={"error": "Field name={} not in config".format(data.field_name)},
)
else:
general_settings = dict(db_general_settings.param_value)
## update db
general_settings.pop(field_name)
general_settings.pop(data.field_name, None)
response = await prisma_client.db.litellm_config.upsert(
where={"param_name": "general_settings"},