diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index c4d58e76a..857399034 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -43,7 +43,7 @@ from litellm.proxy.utils import handle_exception_on_proxy router = APIRouter() -def _update_internal_user_params(data_json: dict, data: NewUserRequest) -> dict: +def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> dict: if "user_id" in data_json and data_json["user_id"] is None: data_json["user_id"] = str(uuid.uuid4()) auto_create_key = data_json.pop("auto_create_key", True) @@ -146,7 +146,7 @@ async def new_user( from litellm.proxy.proxy_server import general_settings, proxy_logging_obj data_json = data.json() # type: ignore - data_json = _update_internal_user_params(data_json, data) + data_json = _update_internal_new_user_params(data_json, data) response = await generate_key_helper_fn(request_type="user", **data_json) # Admin UI Logic @@ -439,6 +439,52 @@ async def user_info( # noqa: PLR0915 raise handle_exception_on_proxy(e) +def _update_internal_user_params(data_json: dict, data: UpdateUserRequest) -> dict: + non_default_values = {} + for k, v in data_json.items(): + if ( + v is not None + and v + not in ( + [], + {}, + 0, + ) + and k not in LiteLLM_ManagementEndpoint_MetadataFields + ): # models default to [], spend defaults to 0, we should not reset these values + non_default_values[k] = v + + is_internal_user = False + if data.user_role == LitellmUserRoles.INTERNAL_USER: + is_internal_user = True + + if "budget_duration" in non_default_values: + duration_s = duration_in_seconds(duration=non_default_values["budget_duration"]) + user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + non_default_values["budget_reset_at"] = user_reset_at + + if "max_budget" not in non_default_values: + if ( + is_internal_user and litellm.max_internal_user_budget is not None + ): # applies internal user limits, if user role updated + non_default_values["max_budget"] = litellm.max_internal_user_budget + + if ( + "budget_duration" not in non_default_values + ): # applies internal user limits, if user role updated + if is_internal_user and litellm.internal_user_budget_duration is not None: + non_default_values["budget_duration"] = ( + litellm.internal_user_budget_duration + ) + duration_s = duration_in_seconds( + duration=non_default_values["budget_duration"] + ) + user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + non_default_values["budget_reset_at"] = user_reset_at + + return non_default_values + + @router.post( "/user/update", tags=["Internal User management"], @@ -504,51 +550,9 @@ async def user_update( raise Exception("Not connected to DB!") # get non default values for key - non_default_values = {} - for k, v in data_json.items(): - if ( - v is not None - and v - not in ( - [], - {}, - 0, - ) - and k not in LiteLLM_ManagementEndpoint_MetadataFields - ): # models default to [], spend defaults to 0, we should not reset these values - non_default_values[k] = v - - is_internal_user = False - if data.user_role == LitellmUserRoles.INTERNAL_USER: - is_internal_user = True - - if "budget_duration" in non_default_values: - duration_s = duration_in_seconds( - duration=non_default_values["budget_duration"] - ) - user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) - non_default_values["budget_reset_at"] = user_reset_at - - if "max_budget" not in non_default_values: - if ( - is_internal_user and litellm.max_internal_user_budget is not None - ): # applies internal user limits, if user role updated - non_default_values["max_budget"] = litellm.max_internal_user_budget - - if ( - "budget_duration" not in non_default_values - ): # applies internal user limits, if user role updated - if is_internal_user and litellm.internal_user_budget_duration is not None: - non_default_values["budget_duration"] = ( - litellm.internal_user_budget_duration - ) - duration_s = duration_in_seconds( - duration=non_default_values["budget_duration"] - ) - user_reset_at = datetime.now(timezone.utc) + timedelta( - seconds=duration_s - ) - non_default_values["budget_reset_at"] = user_reset_at + non_default_values = _update_internal_user_params( + data_json=data_json, data=data + ) existing_user_row = await prisma_client.get_data( user_id=data.user_id, table_name="user", query_type="find_unique" diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index ad90b22bb..f6a0456eb 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -17,7 +17,7 @@ import secrets import traceback import uuid from datetime import datetime, timedelta, timezone -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, cast import fastapi from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status @@ -469,20 +469,22 @@ def prepare_metadata_fields( non_default_values["metadata"] = non_default_values["metadata"].copy() non_default_values["metadata"].update(existing_metadata) + casted_metadata = cast(dict, non_default_values["metadata"]) + data_json = data.model_dump(exclude_unset=True) try: for k, v in data_json.items(): if k == "model_tpm_limit" or k == "model_rpm_limit": - if k not in non_default_values["metadata"]: - non_default_values["metadata"][k] = {} - non_default_values["metadata"][k].update(v) + if k not in casted_metadata: + casted_metadata[k] = {} + casted_metadata[k].update(v) if k == "tags" or k == "guardrails": - if k not in non_default_values["metadata"]: - non_default_values["metadata"][k] = [] - seen = set(non_default_values["metadata"][k]) - non_default_values["metadata"][k].extend( + if k not in casted_metadata: + casted_metadata[k] = [] + seen = set(casted_metadata[k]) + casted_metadata[k].extend( x for x in v if x not in seen and not seen.add(x) # type: ignore ) # prevent duplicates from being added + maintain initial order @@ -492,6 +494,8 @@ def prepare_metadata_fields( str(e) ) ) + + non_default_values["metadata"] = casted_metadata return non_default_values