fix(key_management_endpoints.py): fix metadata field update logic

This commit is contained in:
Krrish Dholakia 2024-11-30 13:06:05 -08:00
parent b2abc61cc9
commit ec0f2abae2
2 changed files with 19 additions and 3 deletions

View file

@ -550,7 +550,17 @@ async def user_update(
)
non_default_values["budget_reset_at"] = user_reset_at
non_default_values = prepare_metadata_fields(data, non_default_values)
existing_user_row = await prisma_client.get_data(
user_id=data.user_id, table_name="user", query_type="find_unique"
)
existing_metadata = existing_user_row.metadata if existing_user_row else {}
non_default_values = prepare_metadata_fields(
data=data,
non_default_values=non_default_values,
existing_metadata=existing_metadata or {},
)
## ADD USER, IF NEW ##
verbose_proxy_logger.debug("/user/update: Received data = %s", data)

View file

@ -452,12 +452,16 @@ async def generate_key_fn( # noqa: PLR0915
raise handle_exception_on_proxy(e)
def prepare_metadata_fields(data: BaseModel, non_default_values: dict) -> dict:
def prepare_metadata_fields(
data: BaseModel, non_default_values: dict, existing_metadata: dict
) -> dict:
"""
Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated
"""
non_default_values.setdefault("metadata", {})
non_default_values["metadata"].update(existing_metadata)
data_json = data.model_dump(exclude_unset=True)
try:
for k, v in data_json.items():
if k == "model_tpm_limit" or k == "model_rpm_limit":
@ -510,8 +514,10 @@ def prepare_key_update_data(
non_default_values["budget_reset_at"] = key_reset_at
non_default_values["budget_duration"] = budget_duration
_metadata = existing_key_row.metadata or {}
non_default_values = prepare_metadata_fields(
data=data, non_default_values=non_default_values
data=data, non_default_values=non_default_values, existing_metadata=_metadata
)
return non_default_values