From a67dfa367e77531c01edfb0f78c8335c434cec53 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 29 Nov 2024 16:20:25 -0800 Subject: [PATCH] fix(internal_user_endpoints.py): support adding guardrails on `/user/update` Fixes https://github.com/BerriAI/litellm/issues/6942 --- litellm/proxy/_types.py | 8 +++ .../internal_user_endpoints.py | 21 +++++-- .../key_management_endpoints.py | 55 +++++++++++-------- 3 files changed, 54 insertions(+), 30 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 965b72642..d2b417c9d 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2183,3 +2183,11 @@ PassThroughEndpointLoggingResultValues = Union[ class PassThroughEndpointLoggingTypedDict(TypedDict): result: Optional[PassThroughEndpointLoggingResultValues] kwargs: dict + + +LiteLLM_ManagementEndpoint_MetadataFields = [ + "model_rpm_limit", + "model_tpm_limit", + "guardrails", + "tags", +] diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index c41975f50..64b754ca7 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -32,6 +32,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.management_endpoints.key_management_endpoints import ( duration_in_seconds, generate_key_helper_fn, + prepare_metadata_fields, ) from litellm.proxy.management_helpers.utils import ( add_new_member, @@ -459,7 +460,8 @@ async def user_update( "user_id": "test-litellm-user-4", "user_role": "proxy_admin_viewer" }' - + ``` + Parameters: - user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated. - user_email: Optional[str] - Specify a user email. @@ -491,7 +493,7 @@ async def user_update( - duration: Optional[str] - [NOT IMPLEMENTED]. - key_alias: Optional[str] - [NOT IMPLEMENTED]. - ``` + """ from litellm.proxy.proxy_server import prisma_client @@ -504,10 +506,15 @@ async def user_update( # get non default values for key non_default_values = {} for k, v in data_json.items(): - if v is not None and v not in ( - [], - {}, - 0, + if ( + v is not None + and v + not in ( + [], + {}, + 0, + ) + and k not in LiteLLM_ManagementEndpoint_MetadataFields ): # models default to [], spend defaults to 0, we should not reset these values non_default_values[k] = v @@ -543,6 +550,8 @@ async def user_update( ) non_default_values["budget_reset_at"] = user_reset_at + non_default_values = prepare_metadata_fields(data, non_default_values) + ## ADD USER, IF NEW ## verbose_proxy_logger.debug("/user/update: Received data = %s", data) response: Optional[Any] = None diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 27d1ec0a4..56cf6f383 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -452,12 +452,39 @@ async def generate_key_fn( # noqa: PLR0915 raise handle_exception_on_proxy(e) +def prepare_metadata_fields(data: BaseModel, non_default_values: dict) -> dict: + """ + Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated + """ + non_default_values.setdefault("metadata", {}) + data_json = data.model_dump(exclude_unset=True) + try: + for k, v in data_json.items(): + if k == "model_tpm_limit" or k == "model_rpm_limit": + if k not in non_default_values["metadata"]: + non_default_values["metadata"][k] = {} + non_default_values["metadata"][k].update(v) + + if k == "tags" or k == "guardrails": + if k not in non_default_values["metadata"]: + non_default_values["metadata"][k] = [] + non_default_values["metadata"][k].extend(v) + + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.prepare_metadata_fields(): Exception occured - {}".format( + str(e) + ) + ) + return non_default_values + + def prepare_key_update_data( data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row ): data_json: dict = data.model_dump(exclude_unset=True) data_json.pop("key", None) - _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"] + _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails", "tags"] non_default_values = {} for k, v in data_json.items(): if k in _metadata_fields: @@ -483,29 +510,9 @@ def prepare_key_update_data( non_default_values["budget_reset_at"] = key_reset_at non_default_values["budget_duration"] = budget_duration - _metadata = existing_key_row.metadata or {} - - if data.model_tpm_limit: - if "model_tpm_limit" not in _metadata: - _metadata["model_tpm_limit"] = {} - _metadata["model_tpm_limit"].update(data.model_tpm_limit) - non_default_values["metadata"] = _metadata - - if data.model_rpm_limit: - if "model_rpm_limit" not in _metadata: - _metadata["model_rpm_limit"] = {} - _metadata["model_rpm_limit"].update(data.model_rpm_limit) - non_default_values["metadata"] = _metadata - - if data.tags: - if "tags" not in _metadata: - _metadata["tags"] = [] - _metadata["tags"].extend(data.tags) - non_default_values["metadata"] = _metadata - - if data.guardrails: - _metadata["guardrails"] = data.guardrails - non_default_values["metadata"] = _metadata + non_default_values = prepare_metadata_fields( + data=data, non_default_values=non_default_values + ) return non_default_values