diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py
index 2c240a17f..917456fad 100644
--- a/litellm/proxy/management_endpoints/key_management_endpoints.py
+++ b/litellm/proxy/management_endpoints/key_management_endpoints.py
@@ -379,8 +379,8 @@ async def update_key_fn(
     )
 
     try:
-        data_json: dict = data.json()
-        key = data_json.pop("key")
+        non_default_values: dict = data.model_dump(exclude_unset=True)
+        key = non_default_values.pop("key")
         # get the row from db
         if prisma_client is None:
             raise Exception("Not connected to DB!")
@@ -395,10 +395,6 @@ async def update_key_fn(
                 detail={"error": f"Team not found, passed team_id={data.team_id}"},
             )
 
-        non_default_values = await prepare_key_update_data(
-            data=data, existing_key_row=existing_key_row
-        )
-
         response = await prisma_client.update_data(
             token=key, data={**non_default_values, "token": key}
         )
@@ -413,7 +409,7 @@ async def update_key_fn(
 
         # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
         if litellm.store_audit_logs is True:
-            _updated_values = json.dumps(data_json, default=str)
+            _updated_values = json.dumps(non_default_values, default=str)
 
             _before_value = existing_key_row.json(exclude_none=True)
             _before_value = json.dumps(_before_value, default=str)
diff --git a/tests/test_keys.py b/tests/test_keys.py
index 554a084c9..84344386d 100644
--- a/tests/test_keys.py
+++ b/tests/test_keys.py
@@ -66,6 +66,7 @@ async def generate_key(
     max_parallel_requests: Optional[int] = None,
     user_id: Optional[str] = None,
     team_id: Optional[str] = None,
+    metadata: Optional[dict] = None,
     calling_key="sk-1234",
 ):
     url = "http://0.0.0.0:4000/key/generate"
@@ -82,6 +83,7 @@ async def generate_key(
         "max_parallel_requests": max_parallel_requests,
         "user_id": user_id,
         "team_id": team_id,
+        "metadata": metadata,
     }
 
     print(f"data: {data}")
@@ -136,16 +138,21 @@ async def test_key_gen_bad_key():
         pass
 
 
-async def update_key(session, get_key):
+async def update_key(session, get_key, metadata: Optional[dict] = None):
     """
     Make sure only models user has access to are returned
     """
     url = "http://0.0.0.0:4000/key/update"
     headers = {
-        "Authorization": f"Bearer sk-1234",
+        "Authorization": "Bearer sk-1234",
         "Content-Type": "application/json",
     }
-    data = {"key": get_key, "models": ["gpt-4"], "duration": "120s"}
+    data = {"key": get_key}
+
+    if metadata is not None:
+        data["metadata"] = metadata
+    else:
+        data.update({"models": ["gpt-4"], "duration": "120s"})
 
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
@@ -276,20 +283,24 @@ async def chat_completion_streaming(session, key, model="gpt-4"):
     return prompt_tokens, completion_tokens
 
 
+@pytest.mark.parametrize("metadata", [{"test": "test"}, None])
 @pytest.mark.asyncio
-async def test_key_update():
+async def test_key_update(metadata):
     """
     Create key
     Update key with new model
     Test key w/ model
     """
     async with aiohttp.ClientSession() as session:
-        key_gen = await generate_key(session=session, i=0)
+        key_gen = await generate_key(session=session, i=0, metadata={"test": "test"})
         key = key_gen["key"]
-        await update_key(
+        assert key_gen["metadata"]["test"] == "test"
+        updated_key = await update_key(
             session=session,
             get_key=key,
+            metadata=metadata,
         )
+        assert updated_key["metadata"] == metadata
         await update_proxy_budget(session=session)  # resets proxy spend
         await chat_completion(session=session, key=key)
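
A minimal sketch of the Pydantic behavior the key_management_endpoints.py change relies on, assuming Pydantic v2 and a hypothetical UpdateKeyRequest model defined here for illustration (not litellm's actual class): model_dump(exclude_unset=True) returns only the fields the caller explicitly sent, so a partial /key/update (for example metadata only) no longer overwrites unrelated columns with default values.

# Illustrative sketch only; this UpdateKeyRequest is a stand-in, not litellm's model.
from typing import Optional

from pydantic import BaseModel


class UpdateKeyRequest(BaseModel):
    key: str
    models: Optional[list] = None
    metadata: Optional[dict] = None


req = UpdateKeyRequest(key="sk-abc", metadata={"test": "test"})

# Dumping every field includes ones the caller never set:
# {'key': 'sk-abc', 'models': None, 'metadata': {'test': 'test'}}
print(req.model_dump())

# exclude_unset=True keeps only the explicitly provided fields, so untouched
# columns (here "models") are left alone by the DB update:
# {'key': 'sk-abc', 'metadata': {'test': 'test'}}
print(req.model_dump(exclude_unset=True))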