fix(internal_user_endpoints.py): support adding guardrails on /user/update

Fixes https://github.com/BerriAI/litellm/issues/6942
This commit is contained in:
Krrish Dholakia 2024-11-29 16:20:25 -08:00
parent aa1621757c
commit a67dfa367e
3 changed files with 54 additions and 30 deletions

View file

@ -2183,3 +2183,11 @@ PassThroughEndpointLoggingResultValues = Union[
class PassThroughEndpointLoggingTypedDict(TypedDict): class PassThroughEndpointLoggingTypedDict(TypedDict):
result: Optional[PassThroughEndpointLoggingResultValues] result: Optional[PassThroughEndpointLoggingResultValues]
kwargs: dict kwargs: dict
LiteLLM_ManagementEndpoint_MetadataFields = [
"model_rpm_limit",
"model_tpm_limit",
"guardrails",
"tags",
]

View file

@ -32,6 +32,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.key_management_endpoints import ( from litellm.proxy.management_endpoints.key_management_endpoints import (
duration_in_seconds, duration_in_seconds,
generate_key_helper_fn, generate_key_helper_fn,
prepare_metadata_fields,
) )
from litellm.proxy.management_helpers.utils import ( from litellm.proxy.management_helpers.utils import (
add_new_member, add_new_member,
@ -459,6 +460,7 @@ async def user_update(
"user_id": "test-litellm-user-4", "user_id": "test-litellm-user-4",
"user_role": "proxy_admin_viewer" "user_role": "proxy_admin_viewer"
}' }'
```
Parameters: Parameters:
- user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated. - user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated.
@ -491,7 +493,7 @@ async def user_update(
- duration: Optional[str] - [NOT IMPLEMENTED]. - duration: Optional[str] - [NOT IMPLEMENTED].
- key_alias: Optional[str] - [NOT IMPLEMENTED]. - key_alias: Optional[str] - [NOT IMPLEMENTED].
```
""" """
from litellm.proxy.proxy_server import prisma_client from litellm.proxy.proxy_server import prisma_client
@ -504,10 +506,15 @@ async def user_update(
# get non default values for key # get non default values for key
non_default_values = {} non_default_values = {}
for k, v in data_json.items(): for k, v in data_json.items():
if v is not None and v not in ( if (
v is not None
and v
not in (
[], [],
{}, {},
0, 0,
)
and k not in LiteLLM_ManagementEndpoint_MetadataFields
): # models default to [], spend defaults to 0, we should not reset these values ): # models default to [], spend defaults to 0, we should not reset these values
non_default_values[k] = v non_default_values[k] = v
@ -543,6 +550,8 @@ async def user_update(
) )
non_default_values["budget_reset_at"] = user_reset_at non_default_values["budget_reset_at"] = user_reset_at
non_default_values = prepare_metadata_fields(data, non_default_values)
## ADD USER, IF NEW ## ## ADD USER, IF NEW ##
verbose_proxy_logger.debug("/user/update: Received data = %s", data) verbose_proxy_logger.debug("/user/update: Received data = %s", data)
response: Optional[Any] = None response: Optional[Any] = None

View file

@ -452,12 +452,39 @@ async def generate_key_fn( # noqa: PLR0915
raise handle_exception_on_proxy(e) raise handle_exception_on_proxy(e)
def prepare_metadata_fields(data: BaseModel, non_default_values: dict) -> dict:
"""
Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated
"""
non_default_values.setdefault("metadata", {})
data_json = data.model_dump(exclude_unset=True)
try:
for k, v in data_json.items():
if k == "model_tpm_limit" or k == "model_rpm_limit":
if k not in non_default_values["metadata"]:
non_default_values["metadata"][k] = {}
non_default_values["metadata"][k].update(v)
if k == "tags" or k == "guardrails":
if k not in non_default_values["metadata"]:
non_default_values["metadata"][k] = []
non_default_values["metadata"][k].extend(v)
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.prepare_metadata_fields(): Exception occured - {}".format(
str(e)
)
)
return non_default_values
def prepare_key_update_data( def prepare_key_update_data(
data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row
): ):
data_json: dict = data.model_dump(exclude_unset=True) data_json: dict = data.model_dump(exclude_unset=True)
data_json.pop("key", None) data_json.pop("key", None)
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"] _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails", "tags"]
non_default_values = {} non_default_values = {}
for k, v in data_json.items(): for k, v in data_json.items():
if k in _metadata_fields: if k in _metadata_fields:
@ -483,29 +510,9 @@ def prepare_key_update_data(
non_default_values["budget_reset_at"] = key_reset_at non_default_values["budget_reset_at"] = key_reset_at
non_default_values["budget_duration"] = budget_duration non_default_values["budget_duration"] = budget_duration
_metadata = existing_key_row.metadata or {} non_default_values = prepare_metadata_fields(
data=data, non_default_values=non_default_values
if data.model_tpm_limit: )
if "model_tpm_limit" not in _metadata:
_metadata["model_tpm_limit"] = {}
_metadata["model_tpm_limit"].update(data.model_tpm_limit)
non_default_values["metadata"] = _metadata
if data.model_rpm_limit:
if "model_rpm_limit" not in _metadata:
_metadata["model_rpm_limit"] = {}
_metadata["model_rpm_limit"].update(data.model_rpm_limit)
non_default_values["metadata"] = _metadata
if data.tags:
if "tags" not in _metadata:
_metadata["tags"] = []
_metadata["tags"].extend(data.tags)
non_default_values["metadata"] = _metadata
if data.guardrails:
_metadata["guardrails"] = data.guardrails
non_default_values["metadata"] = _metadata
return non_default_values return non_default_values