forked from phoenix/litellm-mirror
fix(key_management_endpoints.py): fix /key/update with metadata update
This commit is contained in:
parent
75446852f8
commit
e86a47b0ed
2 changed files with 20 additions and 13 deletions
|
@ -379,8 +379,8 @@ async def update_key_fn(
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data_json: dict = data.json()
|
non_default_values: dict = data.model_dump(exclude_unset=True)
|
||||||
key = data_json.pop("key")
|
key = non_default_values.pop("key")
|
||||||
# get the row from db
|
# get the row from db
|
||||||
if prisma_client is None:
|
if prisma_client is None:
|
||||||
raise Exception("Not connected to DB!")
|
raise Exception("Not connected to DB!")
|
||||||
|
@ -395,10 +395,6 @@ async def update_key_fn(
|
||||||
detail={"error": f"Team not found, passed team_id={data.team_id}"},
|
detail={"error": f"Team not found, passed team_id={data.team_id}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
non_default_values = await prepare_key_update_data(
|
|
||||||
data=data, existing_key_row=existing_key_row
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await prisma_client.update_data(
|
response = await prisma_client.update_data(
|
||||||
token=key, data={**non_default_values, "token": key}
|
token=key, data={**non_default_values, "token": key}
|
||||||
)
|
)
|
||||||
|
@ -413,7 +409,7 @@ async def update_key_fn(
|
||||||
|
|
||||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||||
if litellm.store_audit_logs is True:
|
if litellm.store_audit_logs is True:
|
||||||
_updated_values = json.dumps(data_json, default=str)
|
_updated_values = json.dumps(non_default_values, default=str)
|
||||||
|
|
||||||
_before_value = existing_key_row.json(exclude_none=True)
|
_before_value = existing_key_row.json(exclude_none=True)
|
||||||
_before_value = json.dumps(_before_value, default=str)
|
_before_value = json.dumps(_before_value, default=str)
|
||||||
|
|
|
@ -66,6 +66,7 @@ async def generate_key(
|
||||||
max_parallel_requests: Optional[int] = None,
|
max_parallel_requests: Optional[int] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
team_id: Optional[str] = None,
|
team_id: Optional[str] = None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
calling_key="sk-1234",
|
calling_key="sk-1234",
|
||||||
):
|
):
|
||||||
url = "http://0.0.0.0:4000/key/generate"
|
url = "http://0.0.0.0:4000/key/generate"
|
||||||
|
@ -82,6 +83,7 @@ async def generate_key(
|
||||||
"max_parallel_requests": max_parallel_requests,
|
"max_parallel_requests": max_parallel_requests,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"team_id": team_id,
|
"team_id": team_id,
|
||||||
|
"metadata": metadata,
|
||||||
}
|
}
|
||||||
|
|
||||||
print(f"data: {data}")
|
print(f"data: {data}")
|
||||||
|
@ -136,16 +138,21 @@ async def test_key_gen_bad_key():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def update_key(session, get_key):
|
async def update_key(session, get_key, metadata: Optional[dict] = None):
|
||||||
"""
|
"""
|
||||||
Make sure only models user has access to are returned
|
Make sure only models user has access to are returned
|
||||||
"""
|
"""
|
||||||
url = "http://0.0.0.0:4000/key/update"
|
url = "http://0.0.0.0:4000/key/update"
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer sk-1234",
|
"Authorization": "Bearer sk-1234",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
data = {"key": get_key, "models": ["gpt-4"], "duration": "120s"}
|
data = {"key": get_key}
|
||||||
|
|
||||||
|
if metadata is not None:
|
||||||
|
data["metadata"] = metadata
|
||||||
|
else:
|
||||||
|
data.update({"models": ["gpt-4"], "duration": "120s"})
|
||||||
|
|
||||||
async with session.post(url, headers=headers, json=data) as response:
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
status = response.status
|
status = response.status
|
||||||
|
@ -276,20 +283,24 @@ async def chat_completion_streaming(session, key, model="gpt-4"):
|
||||||
return prompt_tokens, completion_tokens
|
return prompt_tokens, completion_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("metadata", [{"test": "test"}, None])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_key_update():
|
async def test_key_update(metadata):
|
||||||
"""
|
"""
|
||||||
Create key
|
Create key
|
||||||
Update key with new model
|
Update key with new model
|
||||||
Test key w/ model
|
Test key w/ model
|
||||||
"""
|
"""
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
key_gen = await generate_key(session=session, i=0)
|
key_gen = await generate_key(session=session, i=0, metadata={"test": "test"})
|
||||||
key = key_gen["key"]
|
key = key_gen["key"]
|
||||||
await update_key(
|
assert key_gen["metadata"]["test"] == "test"
|
||||||
|
updated_key = await update_key(
|
||||||
session=session,
|
session=session,
|
||||||
get_key=key,
|
get_key=key,
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
assert updated_key["metadata"] == metadata
|
||||||
await update_proxy_budget(session=session) # resets proxy spend
|
await update_proxy_budget(session=session) # resets proxy spend
|
||||||
await chat_completion(session=session, key=key)
|
await chat_completion(session=session, key=key)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue