forked from phoenix/litellm-mirror
fix(key_management_endpoints.py): fix metadata field update logic
This commit is contained in:
parent
b2abc61cc9
commit
ec0f2abae2
2 changed files with 19 additions and 3 deletions
|
@ -550,7 +550,17 @@ async def user_update(
|
||||||
)
|
)
|
||||||
non_default_values["budget_reset_at"] = user_reset_at
|
non_default_values["budget_reset_at"] = user_reset_at
|
||||||
|
|
||||||
non_default_values = prepare_metadata_fields(data, non_default_values)
|
existing_user_row = await prisma_client.get_data(
|
||||||
|
user_id=data.user_id, table_name="user", query_type="find_unique"
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_metadata = existing_user_row.metadata if existing_user_row else {}
|
||||||
|
|
||||||
|
non_default_values = prepare_metadata_fields(
|
||||||
|
data=data,
|
||||||
|
non_default_values=non_default_values,
|
||||||
|
existing_metadata=existing_metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
## ADD USER, IF NEW ##
|
## ADD USER, IF NEW ##
|
||||||
verbose_proxy_logger.debug("/user/update: Received data = %s", data)
|
verbose_proxy_logger.debug("/user/update: Received data = %s", data)
|
||||||
|
|
|
@ -452,12 +452,16 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
raise handle_exception_on_proxy(e)
|
raise handle_exception_on_proxy(e)
|
||||||
|
|
||||||
|
|
||||||
def prepare_metadata_fields(data: BaseModel, non_default_values: dict) -> dict:
|
def prepare_metadata_fields(
|
||||||
|
data: BaseModel, non_default_values: dict, existing_metadata: dict
|
||||||
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated
|
Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated
|
||||||
"""
|
"""
|
||||||
non_default_values.setdefault("metadata", {})
|
non_default_values.setdefault("metadata", {})
|
||||||
|
non_default_values["metadata"].update(existing_metadata)
|
||||||
data_json = data.model_dump(exclude_unset=True)
|
data_json = data.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for k, v in data_json.items():
|
for k, v in data_json.items():
|
||||||
if k == "model_tpm_limit" or k == "model_rpm_limit":
|
if k == "model_tpm_limit" or k == "model_rpm_limit":
|
||||||
|
@ -510,8 +514,10 @@ def prepare_key_update_data(
|
||||||
non_default_values["budget_reset_at"] = key_reset_at
|
non_default_values["budget_reset_at"] = key_reset_at
|
||||||
non_default_values["budget_duration"] = budget_duration
|
non_default_values["budget_duration"] = budget_duration
|
||||||
|
|
||||||
|
_metadata = existing_key_row.metadata or {}
|
||||||
|
|
||||||
non_default_values = prepare_metadata_fields(
|
non_default_values = prepare_metadata_fields(
|
||||||
data=data, non_default_values=non_default_values
|
data=data, non_default_values=non_default_values, existing_metadata=_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
return non_default_values
|
return non_default_values
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue