From ec0f2abae2e658a4a7d552f1401ae86335e65a1b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Nov 2024 13:06:05 -0800 Subject: [PATCH] fix(key_management_endpoints.py): fix metadata field update logic --- .../management_endpoints/internal_user_endpoints.py | 12 +++++++++++- .../management_endpoints/key_management_endpoints.py | 10 ++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 64b754ca7..c4d58e76a 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -550,7 +550,17 @@ async def user_update( ) non_default_values["budget_reset_at"] = user_reset_at - non_default_values = prepare_metadata_fields(data, non_default_values) + existing_user_row = await prisma_client.get_data( + user_id=data.user_id, table_name="user", query_type="find_unique" + ) + + existing_metadata = existing_user_row.metadata if existing_user_row else {} + + non_default_values = prepare_metadata_fields( + data=data, + non_default_values=non_default_values, + existing_metadata=existing_metadata or {}, + ) ## ADD USER, IF NEW ## verbose_proxy_logger.debug("/user/update: Received data = %s", data) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 1aa0cd042..cec7d3d91 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -452,12 +452,16 @@ async def generate_key_fn( # noqa: PLR0915 raise handle_exception_on_proxy(e) -def prepare_metadata_fields(data: BaseModel, non_default_values: dict) -> dict: +def prepare_metadata_fields( + data: BaseModel, non_default_values: dict, existing_metadata: dict +) -> dict: """ Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated """ non_default_values.setdefault("metadata", {}) + non_default_values["metadata"].update(existing_metadata) data_json = data.model_dump(exclude_unset=True) + try: for k, v in data_json.items(): if k == "model_tpm_limit" or k == "model_rpm_limit": @@ -510,8 +514,10 @@ def prepare_key_update_data( non_default_values["budget_reset_at"] = key_reset_at non_default_values["budget_duration"] = budget_duration + _metadata = existing_key_row.metadata or {} + non_default_values = prepare_metadata_fields( - data=data, non_default_values=non_default_values + data=data, non_default_values=non_default_values, existing_metadata=_metadata ) return non_default_values