forked from phoenix/litellm-mirror
fix(internal_user_endpoints.py): support adding guardrails on /user/update
Fixes https://github.com/BerriAI/litellm/issues/6942
This commit is contained in:
parent
aa1621757c
commit
a67dfa367e
3 changed files with 54 additions and 30 deletions
|
@ -2183,3 +2183,11 @@ PassThroughEndpointLoggingResultValues = Union[
|
|||
class PassThroughEndpointLoggingTypedDict(TypedDict):
|
||||
result: Optional[PassThroughEndpointLoggingResultValues]
|
||||
kwargs: dict
|
||||
|
||||
|
||||
LiteLLM_ManagementEndpoint_MetadataFields = [
|
||||
"model_rpm_limit",
|
||||
"model_tpm_limit",
|
||||
"guardrails",
|
||||
"tags",
|
||||
]
|
||||
|
|
|
@ -32,6 +32,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
|||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
duration_in_seconds,
|
||||
generate_key_helper_fn,
|
||||
prepare_metadata_fields,
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import (
|
||||
add_new_member,
|
||||
|
@ -459,7 +460,8 @@ async def user_update(
|
|||
"user_id": "test-litellm-user-4",
|
||||
"user_role": "proxy_admin_viewer"
|
||||
}'
|
||||
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated.
|
||||
- user_email: Optional[str] - Specify a user email.
|
||||
|
@ -491,7 +493,7 @@ async def user_update(
|
|||
- duration: Optional[str] - [NOT IMPLEMENTED].
|
||||
- key_alias: Optional[str] - [NOT IMPLEMENTED].
|
||||
|
||||
```
|
||||
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
|
@ -504,10 +506,15 @@ async def user_update(
|
|||
# get non default values for key
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
if v is not None and v not in (
|
||||
[],
|
||||
{},
|
||||
0,
|
||||
if (
|
||||
v is not None
|
||||
and v
|
||||
not in (
|
||||
[],
|
||||
{},
|
||||
0,
|
||||
)
|
||||
and k not in LiteLLM_ManagementEndpoint_MetadataFields
|
||||
): # models default to [], spend defaults to 0, we should not reset these values
|
||||
non_default_values[k] = v
|
||||
|
||||
|
@ -543,6 +550,8 @@ async def user_update(
|
|||
)
|
||||
non_default_values["budget_reset_at"] = user_reset_at
|
||||
|
||||
non_default_values = prepare_metadata_fields(data, non_default_values)
|
||||
|
||||
## ADD USER, IF NEW ##
|
||||
verbose_proxy_logger.debug("/user/update: Received data = %s", data)
|
||||
response: Optional[Any] = None
|
||||
|
|
|
@ -452,12 +452,39 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
def prepare_metadata_fields(data: BaseModel, non_default_values: dict) -> dict:
|
||||
"""
|
||||
Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated
|
||||
"""
|
||||
non_default_values.setdefault("metadata", {})
|
||||
data_json = data.model_dump(exclude_unset=True)
|
||||
try:
|
||||
for k, v in data_json.items():
|
||||
if k == "model_tpm_limit" or k == "model_rpm_limit":
|
||||
if k not in non_default_values["metadata"]:
|
||||
non_default_values["metadata"][k] = {}
|
||||
non_default_values["metadata"][k].update(v)
|
||||
|
||||
if k == "tags" or k == "guardrails":
|
||||
if k not in non_default_values["metadata"]:
|
||||
non_default_values["metadata"][k] = []
|
||||
non_default_values["metadata"][k].extend(v)
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.prepare_metadata_fields(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
return non_default_values
|
||||
|
||||
|
||||
def prepare_key_update_data(
|
||||
data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row
|
||||
):
|
||||
data_json: dict = data.model_dump(exclude_unset=True)
|
||||
data_json.pop("key", None)
|
||||
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
|
||||
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails", "tags"]
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
if k in _metadata_fields:
|
||||
|
@ -483,29 +510,9 @@ def prepare_key_update_data(
|
|||
non_default_values["budget_reset_at"] = key_reset_at
|
||||
non_default_values["budget_duration"] = budget_duration
|
||||
|
||||
_metadata = existing_key_row.metadata or {}
|
||||
|
||||
if data.model_tpm_limit:
|
||||
if "model_tpm_limit" not in _metadata:
|
||||
_metadata["model_tpm_limit"] = {}
|
||||
_metadata["model_tpm_limit"].update(data.model_tpm_limit)
|
||||
non_default_values["metadata"] = _metadata
|
||||
|
||||
if data.model_rpm_limit:
|
||||
if "model_rpm_limit" not in _metadata:
|
||||
_metadata["model_rpm_limit"] = {}
|
||||
_metadata["model_rpm_limit"].update(data.model_rpm_limit)
|
||||
non_default_values["metadata"] = _metadata
|
||||
|
||||
if data.tags:
|
||||
if "tags" not in _metadata:
|
||||
_metadata["tags"] = []
|
||||
_metadata["tags"].extend(data.tags)
|
||||
non_default_values["metadata"] = _metadata
|
||||
|
||||
if data.guardrails:
|
||||
_metadata["guardrails"] = data.guardrails
|
||||
non_default_values["metadata"] = _metadata
|
||||
non_default_values = prepare_metadata_fields(
|
||||
data=data, non_default_values=non_default_values
|
||||
)
|
||||
|
||||
return non_default_values
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue