fix: fix key management errors

This commit is contained in:
Krrish Dholakia 2024-11-30 17:52:36 -08:00
parent 84f3ac7d25
commit 65ad44aebd
3 changed files with 57 additions and 15 deletions

View file

@ -394,7 +394,8 @@ async def generate_key_fn( # noqa: PLR0915
} }
) )
_budget_id = getattr(_budget, "budget_id", None) _budget_id = getattr(_budget, "budget_id", None)
data_json = data.json() # type: ignore data_json = data.model_dump(exclude_unset=True, exclude_none=True) # type: ignore
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users # if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
if "max_budget" in data_json: if "max_budget" in data_json:
data_json["key_max_budget"] = data_json.pop("max_budget", None) data_json["key_max_budget"] = data_json.pop("max_budget", None)
@ -464,24 +465,20 @@ def prepare_metadata_fields(
# Handle None cases for metadata # Handle None cases for metadata
if non_default_values["metadata"] is None: if non_default_values["metadata"] is None:
non_default_values["metadata"] = existing_metadata.copy() non_default_values["metadata"] = existing_metadata.copy()
else:
# Create a copy to avoid modifying the original
non_default_values["metadata"] = non_default_values["metadata"].copy()
non_default_values["metadata"].update(existing_metadata)
casted_metadata = cast(dict, non_default_values["metadata"]) casted_metadata = cast(dict, non_default_values["metadata"])
data_json = data.model_dump(exclude_unset=True) data_json = data.model_dump(exclude_unset=True, exclude_none=True)
try: try:
for k, v in data_json.items(): for k, v in data_json.items():
if k == "model_tpm_limit" or k == "model_rpm_limit": if k == "model_tpm_limit" or k == "model_rpm_limit":
if k not in casted_metadata: if k not in casted_metadata or casted_metadata[k] is None:
casted_metadata[k] = {} casted_metadata[k] = {}
casted_metadata[k].update(v) casted_metadata[k].update(v)
if k == "tags" or k == "guardrails": if k == "tags" or k == "guardrails":
if k not in casted_metadata: if k not in casted_metadata or casted_metadata[k] is None:
casted_metadata[k] = [] casted_metadata[k] = []
seen = set(casted_metadata[k]) seen = set(casted_metadata[k])
casted_metadata[k].extend( casted_metadata[k].extend(
@ -959,11 +956,11 @@ async def generate_key_helper_fn( # noqa: PLR0915
request_type: Literal[ request_type: Literal[
"user", "key" "user", "key"
], # identifies if this request is from /user/new or /key/generate ], # identifies if this request is from /user/new or /key/generate
duration: Optional[str], duration: Optional[str] = None,
models: list, models: list = [],
aliases: dict, aliases: dict = {},
config: dict, config: dict = {},
spend: float, spend: float = 0.0,
key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key
key_budget_duration: Optional[str] = None, key_budget_duration: Optional[str] = None,
budget_id: Optional[float] = None, # budget id <-> LiteLLM_BudgetTable budget_id: Optional[float] = None, # budget id <-> LiteLLM_BudgetTable
@ -992,8 +989,8 @@ async def generate_key_helper_fn( # noqa: PLR0915
allowed_cache_controls: Optional[list] = [], allowed_cache_controls: Optional[list] = [],
permissions: Optional[dict] = {}, permissions: Optional[dict] = {},
model_max_budget: Optional[dict] = {}, model_max_budget: Optional[dict] = {},
model_rpm_limit: Optional[dict] = {}, model_rpm_limit: Optional[dict] = None,
model_tpm_limit: Optional[dict] = {}, model_tpm_limit: Optional[dict] = None,
guardrails: Optional[list] = None, guardrails: Optional[list] = None,
teams: Optional[list] = None, teams: Optional[list] = None,
organization_id: Optional[str] = None, organization_id: Optional[str] = None,

View file

@ -693,3 +693,47 @@ def test_personal_key_generation_check():
), ),
data=GenerateKeyRequest(), data=GenerateKeyRequest(),
) )
def test_prepare_metadata_fields():
from litellm.proxy.management_endpoints.key_management_endpoints import (
prepare_metadata_fields,
)
new_metadata = {"test": "new"}
old_metadata = {"test": "test"}
args = {
"data": UpdateKeyRequest(
key_alias=None,
duration=None,
models=[],
spend=None,
max_budget=None,
user_id=None,
team_id=None,
max_parallel_requests=None,
metadata=new_metadata,
tpm_limit=None,
rpm_limit=None,
budget_duration=None,
allowed_cache_controls=[],
soft_budget=None,
config={},
permissions={},
model_max_budget={},
send_invite_email=None,
model_rpm_limit=None,
model_tpm_limit=None,
guardrails=None,
blocked=None,
aliases={},
key="sk-1qGQUJJTcljeaPfzgWRrXQ",
tags=None,
),
"non_default_values": {"metadata": new_metadata},
"existing_metadata": {"tags": None, **old_metadata},
}
non_default_values = prepare_metadata_fields(**args)
assert non_default_values == {"metadata": new_metadata}

View file

@ -300,6 +300,7 @@ async def test_key_update(metadata):
get_key=key, get_key=key,
metadata=metadata, metadata=metadata,
) )
print(f"updated_key['metadata']: {updated_key['metadata']}")
assert updated_key["metadata"] == metadata assert updated_key["metadata"] == metadata
await update_proxy_budget(session=session) # resets proxy spend await update_proxy_budget(session=session) # resets proxy spend
await chat_completion(session=session, key=key) await chat_completion(session=session, key=key)