fix(key_management_endpoints.py): fix /key/update with metadata update

This commit is contained in:
Krrish Dholakia 2024-11-12 22:58:18 +05:30
parent 75446852f8
commit e86a47b0ed
2 changed files with 20 additions and 13 deletions

View file

@ -379,8 +379,8 @@ async def update_key_fn(
) )
try: try:
data_json: dict = data.json() non_default_values: dict = data.model_dump(exclude_unset=True)
key = data_json.pop("key") key = non_default_values.pop("key")
# get the row from db # get the row from db
if prisma_client is None: if prisma_client is None:
raise Exception("Not connected to DB!") raise Exception("Not connected to DB!")
@ -395,10 +395,6 @@ async def update_key_fn(
detail={"error": f"Team not found, passed team_id={data.team_id}"}, detail={"error": f"Team not found, passed team_id={data.team_id}"},
) )
non_default_values = await prepare_key_update_data(
data=data, existing_key_row=existing_key_row
)
response = await prisma_client.update_data( response = await prisma_client.update_data(
token=key, data={**non_default_values, "token": key} token=key, data={**non_default_values, "token": key}
) )
@ -413,7 +409,7 @@ async def update_key_fn(
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True: if litellm.store_audit_logs is True:
_updated_values = json.dumps(data_json, default=str) _updated_values = json.dumps(non_default_values, default=str)
_before_value = existing_key_row.json(exclude_none=True) _before_value = existing_key_row.json(exclude_none=True)
_before_value = json.dumps(_before_value, default=str) _before_value = json.dumps(_before_value, default=str)

View file

@ -66,6 +66,7 @@ async def generate_key(
max_parallel_requests: Optional[int] = None, max_parallel_requests: Optional[int] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
team_id: Optional[str] = None, team_id: Optional[str] = None,
metadata: Optional[dict] = None,
calling_key="sk-1234", calling_key="sk-1234",
): ):
url = "http://0.0.0.0:4000/key/generate" url = "http://0.0.0.0:4000/key/generate"
@ -82,6 +83,7 @@ async def generate_key(
"max_parallel_requests": max_parallel_requests, "max_parallel_requests": max_parallel_requests,
"user_id": user_id, "user_id": user_id,
"team_id": team_id, "team_id": team_id,
"metadata": metadata,
} }
print(f"data: {data}") print(f"data: {data}")
@ -136,16 +138,21 @@ async def test_key_gen_bad_key():
pass pass
async def update_key(session, get_key): async def update_key(session, get_key, metadata: Optional[dict] = None):
""" """
Make sure only models user has access to are returned Make sure only models user has access to are returned
""" """
url = "http://0.0.0.0:4000/key/update" url = "http://0.0.0.0:4000/key/update"
headers = { headers = {
"Authorization": f"Bearer sk-1234", "Authorization": "Bearer sk-1234",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
data = {"key": get_key, "models": ["gpt-4"], "duration": "120s"} data = {"key": get_key}
if metadata is not None:
data["metadata"] = metadata
else:
data.update({"models": ["gpt-4"], "duration": "120s"})
async with session.post(url, headers=headers, json=data) as response: async with session.post(url, headers=headers, json=data) as response:
status = response.status status = response.status
@ -276,20 +283,24 @@ async def chat_completion_streaming(session, key, model="gpt-4"):
return prompt_tokens, completion_tokens return prompt_tokens, completion_tokens
@pytest.mark.parametrize("metadata", [{"test": "test"}, None])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_key_update(): async def test_key_update(metadata):
""" """
Create key Create key
Update key with new model Update key with new model
Test key w/ model Test key w/ model
""" """
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
key_gen = await generate_key(session=session, i=0) key_gen = await generate_key(session=session, i=0, metadata={"test": "test"})
key = key_gen["key"] key = key_gen["key"]
await update_key( assert key_gen["metadata"]["test"] == "test"
updated_key = await update_key(
session=session, session=session,
get_key=key, get_key=key,
metadata=metadata,
) )
assert updated_key["metadata"] == metadata
await update_proxy_budget(session=session) # resets proxy spend await update_proxy_budget(session=session) # resets proxy spend
await chat_completion(session=session, key=key) await chat_completion(session=session, key=key)