From 65ad44aebdd7991b989d4d38676d203ba649b570 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Nov 2024 17:52:36 -0800 Subject: [PATCH] fix: fix key management errors --- .../key_management_endpoints.py | 27 +++++------- .../test_key_management.py | 44 +++++++++++++++++++ tests/test_keys.py | 1 + 3 files changed, 57 insertions(+), 15 deletions(-) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index f6a0456eb..c18ea3b7e 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -394,7 +394,8 @@ async def generate_key_fn( # noqa: PLR0915 } ) _budget_id = getattr(_budget, "budget_id", None) - data_json = data.json() # type: ignore + data_json = data.model_dump(exclude_unset=True, exclude_none=True) # type: ignore + # if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users if "max_budget" in data_json: data_json["key_max_budget"] = data_json.pop("max_budget", None) @@ -464,24 +465,20 @@ def prepare_metadata_fields( # Handle None cases for metadata if non_default_values["metadata"] is None: non_default_values["metadata"] = existing_metadata.copy() - else: - # Create a copy to avoid modifying the original - non_default_values["metadata"] = non_default_values["metadata"].copy() - non_default_values["metadata"].update(existing_metadata) casted_metadata = cast(dict, non_default_values["metadata"]) - data_json = data.model_dump(exclude_unset=True) + data_json = data.model_dump(exclude_unset=True, exclude_none=True) try: for k, v in data_json.items(): if k == "model_tpm_limit" or k == "model_rpm_limit": - if k not in casted_metadata: + if k not in casted_metadata or casted_metadata[k] is None: casted_metadata[k] = {} casted_metadata[k].update(v) if k == "tags" or k == "guardrails": - if k not in casted_metadata: + if k not in casted_metadata or casted_metadata[k] is None: casted_metadata[k] = [] seen = set(casted_metadata[k]) casted_metadata[k].extend( @@ -959,11 +956,11 @@ async def generate_key_helper_fn( # noqa: PLR0915 request_type: Literal[ "user", "key" ], # identifies if this request is from /user/new or /key/generate - duration: Optional[str], - models: list, - aliases: dict, - config: dict, - spend: float, + duration: Optional[str] = None, + models: list = [], + aliases: dict = {}, + config: dict = {}, + spend: float = 0.0, key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key key_budget_duration: Optional[str] = None, budget_id: Optional[float] = None, # budget id <-> LiteLLM_BudgetTable @@ -992,8 +989,8 @@ async def generate_key_helper_fn( # noqa: PLR0915 allowed_cache_controls: Optional[list] = [], permissions: Optional[dict] = {}, model_max_budget: Optional[dict] = {}, - model_rpm_limit: Optional[dict] = {}, - model_tpm_limit: Optional[dict] = {}, + model_rpm_limit: Optional[dict] = None, + model_tpm_limit: Optional[dict] = None, guardrails: Optional[list] = None, teams: Optional[list] = None, organization_id: Optional[str] = None, diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index d0b1ab294..7a2764e3f 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -693,3 +693,47 @@ def test_personal_key_generation_check(): ), data=GenerateKeyRequest(), ) + + +def test_prepare_metadata_fields(): + from litellm.proxy.management_endpoints.key_management_endpoints import ( + prepare_metadata_fields, + ) + + new_metadata = {"test": "new"} + old_metadata = {"test": "test"} + + args = { + "data": UpdateKeyRequest( + key_alias=None, + duration=None, + models=[], + spend=None, + max_budget=None, + user_id=None, + team_id=None, + max_parallel_requests=None, + metadata=new_metadata, + tpm_limit=None, + rpm_limit=None, + budget_duration=None, + allowed_cache_controls=[], + soft_budget=None, + config={}, + permissions={}, + model_max_budget={}, + send_invite_email=None, + model_rpm_limit=None, + model_tpm_limit=None, + guardrails=None, + blocked=None, + aliases={}, + key="sk-1qGQUJJTcljeaPfzgWRrXQ", + tags=None, + ), + "non_default_values": {"metadata": new_metadata}, + "existing_metadata": {"tags": None, **old_metadata}, + } + + non_default_values = prepare_metadata_fields(**args) + assert non_default_values == {"metadata": new_metadata} diff --git a/tests/test_keys.py b/tests/test_keys.py index a569634bc..eaf9369d8 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -300,6 +300,7 @@ async def test_key_update(metadata): get_key=key, metadata=metadata, ) + print(f"updated_key['metadata']: {updated_key['metadata']}") assert updated_key["metadata"] == metadata await update_proxy_budget(session=session) # resets proxy spend await chat_completion(session=session, key=key)