mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
(Bug fix) - Allow setting null
for max_budget
, rpm_limit
, tpm_limit
when updating values on a team (#7912)
* fix update_team * fix test_key_limit_modifications
This commit is contained in:
parent
6b810bd815
commit
b056bc0ac3
2 changed files with 73 additions and 1 deletions
|
@ -461,7 +461,7 @@ async def update_team(
|
||||||
detail={"error": f"Team not found, passed team_id={data.team_id}"},
|
detail={"error": f"Team not found, passed team_id={data.team_id}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_kv = data.json(exclude_none=True)
|
updated_kv = data.json(exclude_unset=True)
|
||||||
|
|
||||||
# Check budget_duration and budget_reset_at
|
# Check budget_duration and budget_reset_at
|
||||||
if data.budget_duration is not None:
|
if data.budget_duration is not None:
|
||||||
|
|
|
@ -2,6 +2,7 @@ import pytest
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import json
|
import json
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
async def make_calls_until_budget_exceeded(session, key: str, call_function, **kwargs):
|
async def make_calls_until_budget_exceeded(session, key: str, call_function, **kwargs):
|
||||||
|
@ -206,3 +207,74 @@ async def test_chat_completion_budget_update():
|
||||||
pytest.fail(
|
pytest.fail(
|
||||||
f"Request should succeed after budget update but got error: {e}"
|
f"Request should succeed after budget update but got error: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"field",
|
||||||
|
[
|
||||||
|
"max_budget",
|
||||||
|
"rpm_limit",
|
||||||
|
"tpm_limit",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_key_limit_modifications(field):
|
||||||
|
# Create initial key
|
||||||
|
client = AsyncClient(base_url="http://0.0.0.0:4000")
|
||||||
|
key_data = {"max_budget": None, "rpm_limit": None, "tpm_limit": None}
|
||||||
|
headers = {"Authorization": "Bearer sk-1234"}
|
||||||
|
response = await client.post("/key/generate", json=key_data, headers=headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
generate_key_response = response.json()
|
||||||
|
print("generate_key_response: ", json.dumps(generate_key_response, indent=4))
|
||||||
|
key_id = generate_key_response["key"]
|
||||||
|
|
||||||
|
# Update key with any non-null value for the field
|
||||||
|
update_data = {"key": key_id}
|
||||||
|
update_data[field] = 10 # Any non-null value works
|
||||||
|
print("update_data: ", json.dumps(update_data, indent=4))
|
||||||
|
response = await client.post(f"/key/update", json=update_data, headers=headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()[field] is not None
|
||||||
|
|
||||||
|
# Reset limit to null
|
||||||
|
print(f"resetting {field} to null")
|
||||||
|
update_data[field] = None
|
||||||
|
response = await client.post(f"/key/update", json=update_data, headers=headers)
|
||||||
|
print("response: ", json.dumps(response.json(), indent=4))
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()[field] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"field",
|
||||||
|
[
|
||||||
|
"max_budget",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_team_limit_modifications(field):
|
||||||
|
# Create initial team
|
||||||
|
client = AsyncClient(base_url="http://0.0.0.0:4000")
|
||||||
|
team_data = {"max_budget": None, "rpm_limit": None, "tpm_limit": None}
|
||||||
|
headers = {"Authorization": "Bearer sk-1234"}
|
||||||
|
response = await client.post("/team/new", json=team_data, headers=headers)
|
||||||
|
print("response: ", json.dumps(response.json(), indent=4))
|
||||||
|
assert response.status_code == 200
|
||||||
|
team_id = response.json()["team_id"]
|
||||||
|
|
||||||
|
# Update team with any non-null value for the field
|
||||||
|
update_data = {"team_id": team_id}
|
||||||
|
update_data[field] = 10 # Any non-null value works
|
||||||
|
response = await client.post(f"/team/update", json=update_data, headers=headers)
|
||||||
|
print("response: ", json.dumps(response.json(), indent=4))
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["data"][field] is not None
|
||||||
|
|
||||||
|
# Reset limit to null
|
||||||
|
print(f"resetting {field} to null")
|
||||||
|
update_data[field] = None
|
||||||
|
response = await client.post(f"/team/update", json=update_data, headers=headers)
|
||||||
|
print("response: ", json.dumps(response.json(), indent=4))
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["data"][field] is None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue