forked from phoenix/litellm-mirror
fix(internal_user_endpoints.py): support adding guardrails on /user/update
Fixes https://github.com/BerriAI/litellm/issues/6942
This commit is contained in:
parent
aa1621757c
commit
a67dfa367e
3 changed files with 54 additions and 30 deletions
|
@ -2183,3 +2183,11 @@ PassThroughEndpointLoggingResultValues = Union[
|
||||||
class PassThroughEndpointLoggingTypedDict(TypedDict):
|
class PassThroughEndpointLoggingTypedDict(TypedDict):
|
||||||
result: Optional[PassThroughEndpointLoggingResultValues]
|
result: Optional[PassThroughEndpointLoggingResultValues]
|
||||||
kwargs: dict
|
kwargs: dict
|
||||||
|
|
||||||
|
|
||||||
|
LiteLLM_ManagementEndpoint_MetadataFields = [
|
||||||
|
"model_rpm_limit",
|
||||||
|
"model_tpm_limit",
|
||||||
|
"guardrails",
|
||||||
|
"tags",
|
||||||
|
]
|
||||||
|
|
|
@ -32,6 +32,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||||
duration_in_seconds,
|
duration_in_seconds,
|
||||||
generate_key_helper_fn,
|
generate_key_helper_fn,
|
||||||
|
prepare_metadata_fields,
|
||||||
)
|
)
|
||||||
from litellm.proxy.management_helpers.utils import (
|
from litellm.proxy.management_helpers.utils import (
|
||||||
add_new_member,
|
add_new_member,
|
||||||
|
@ -459,6 +460,7 @@ async def user_update(
|
||||||
"user_id": "test-litellm-user-4",
|
"user_id": "test-litellm-user-4",
|
||||||
"user_role": "proxy_admin_viewer"
|
"user_role": "proxy_admin_viewer"
|
||||||
}'
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated.
|
- user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated.
|
||||||
|
@ -491,7 +493,7 @@ async def user_update(
|
||||||
- duration: Optional[str] - [NOT IMPLEMENTED].
|
- duration: Optional[str] - [NOT IMPLEMENTED].
|
||||||
- key_alias: Optional[str] - [NOT IMPLEMENTED].
|
- key_alias: Optional[str] - [NOT IMPLEMENTED].
|
||||||
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
from litellm.proxy.proxy_server import prisma_client
|
from litellm.proxy.proxy_server import prisma_client
|
||||||
|
|
||||||
|
@ -504,10 +506,15 @@ async def user_update(
|
||||||
# get non default values for key
|
# get non default values for key
|
||||||
non_default_values = {}
|
non_default_values = {}
|
||||||
for k, v in data_json.items():
|
for k, v in data_json.items():
|
||||||
if v is not None and v not in (
|
if (
|
||||||
[],
|
v is not None
|
||||||
{},
|
and v
|
||||||
0,
|
not in (
|
||||||
|
[],
|
||||||
|
{},
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
and k not in LiteLLM_ManagementEndpoint_MetadataFields
|
||||||
): # models default to [], spend defaults to 0, we should not reset these values
|
): # models default to [], spend defaults to 0, we should not reset these values
|
||||||
non_default_values[k] = v
|
non_default_values[k] = v
|
||||||
|
|
||||||
|
@ -543,6 +550,8 @@ async def user_update(
|
||||||
)
|
)
|
||||||
non_default_values["budget_reset_at"] = user_reset_at
|
non_default_values["budget_reset_at"] = user_reset_at
|
||||||
|
|
||||||
|
non_default_values = prepare_metadata_fields(data, non_default_values)
|
||||||
|
|
||||||
## ADD USER, IF NEW ##
|
## ADD USER, IF NEW ##
|
||||||
verbose_proxy_logger.debug("/user/update: Received data = %s", data)
|
verbose_proxy_logger.debug("/user/update: Received data = %s", data)
|
||||||
response: Optional[Any] = None
|
response: Optional[Any] = None
|
||||||
|
|
|
@ -452,12 +452,39 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
raise handle_exception_on_proxy(e)
|
raise handle_exception_on_proxy(e)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_metadata_fields(data: BaseModel, non_default_values: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated
|
||||||
|
"""
|
||||||
|
non_default_values.setdefault("metadata", {})
|
||||||
|
data_json = data.model_dump(exclude_unset=True)
|
||||||
|
try:
|
||||||
|
for k, v in data_json.items():
|
||||||
|
if k == "model_tpm_limit" or k == "model_rpm_limit":
|
||||||
|
if k not in non_default_values["metadata"]:
|
||||||
|
non_default_values["metadata"][k] = {}
|
||||||
|
non_default_values["metadata"][k].update(v)
|
||||||
|
|
||||||
|
if k == "tags" or k == "guardrails":
|
||||||
|
if k not in non_default_values["metadata"]:
|
||||||
|
non_default_values["metadata"][k] = []
|
||||||
|
non_default_values["metadata"][k].extend(v)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.exception(
|
||||||
|
"litellm.proxy.proxy_server.prepare_metadata_fields(): Exception occured - {}".format(
|
||||||
|
str(e)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return non_default_values
|
||||||
|
|
||||||
|
|
||||||
def prepare_key_update_data(
|
def prepare_key_update_data(
|
||||||
data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row
|
data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row
|
||||||
):
|
):
|
||||||
data_json: dict = data.model_dump(exclude_unset=True)
|
data_json: dict = data.model_dump(exclude_unset=True)
|
||||||
data_json.pop("key", None)
|
data_json.pop("key", None)
|
||||||
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
|
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails", "tags"]
|
||||||
non_default_values = {}
|
non_default_values = {}
|
||||||
for k, v in data_json.items():
|
for k, v in data_json.items():
|
||||||
if k in _metadata_fields:
|
if k in _metadata_fields:
|
||||||
|
@ -483,29 +510,9 @@ def prepare_key_update_data(
|
||||||
non_default_values["budget_reset_at"] = key_reset_at
|
non_default_values["budget_reset_at"] = key_reset_at
|
||||||
non_default_values["budget_duration"] = budget_duration
|
non_default_values["budget_duration"] = budget_duration
|
||||||
|
|
||||||
_metadata = existing_key_row.metadata or {}
|
non_default_values = prepare_metadata_fields(
|
||||||
|
data=data, non_default_values=non_default_values
|
||||||
if data.model_tpm_limit:
|
)
|
||||||
if "model_tpm_limit" not in _metadata:
|
|
||||||
_metadata["model_tpm_limit"] = {}
|
|
||||||
_metadata["model_tpm_limit"].update(data.model_tpm_limit)
|
|
||||||
non_default_values["metadata"] = _metadata
|
|
||||||
|
|
||||||
if data.model_rpm_limit:
|
|
||||||
if "model_rpm_limit" not in _metadata:
|
|
||||||
_metadata["model_rpm_limit"] = {}
|
|
||||||
_metadata["model_rpm_limit"].update(data.model_rpm_limit)
|
|
||||||
non_default_values["metadata"] = _metadata
|
|
||||||
|
|
||||||
if data.tags:
|
|
||||||
if "tags" not in _metadata:
|
|
||||||
_metadata["tags"] = []
|
|
||||||
_metadata["tags"].extend(data.tags)
|
|
||||||
non_default_values["metadata"] = _metadata
|
|
||||||
|
|
||||||
if data.guardrails:
|
|
||||||
_metadata["guardrails"] = data.guardrails
|
|
||||||
non_default_values["metadata"] = _metadata
|
|
||||||
|
|
||||||
return non_default_values
|
return non_default_values
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue