diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index ea07a90ad7..6ebfed0eef 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -461,7 +461,7 @@ async def update_team( detail={"error": f"Team not found, passed team_id={data.team_id}"}, ) - updated_kv = data.json(exclude_none=True) + updated_kv = data.json(exclude_unset=True) # Check budget_duration and budget_reset_at if data.budget_duration is not None: diff --git a/tests/otel_tests/test_e2e_budgeting.py b/tests/otel_tests/test_e2e_budgeting.py index 2d852a4e1c..227d049f82 100644 --- a/tests/otel_tests/test_e2e_budgeting.py +++ b/tests/otel_tests/test_e2e_budgeting.py @@ -2,6 +2,7 @@ import pytest import asyncio import aiohttp import json +from httpx import AsyncClient async def make_calls_until_budget_exceeded(session, key: str, call_function, **kwargs): @@ -206,3 +207,74 @@ async def test_chat_completion_budget_update(): pytest.fail( f"Request should succeed after budget update but got error: {e}" ) + + +@pytest.mark.parametrize( + "field", + [ + "max_budget", + "rpm_limit", + "tpm_limit", + ], +) +@pytest.mark.asyncio +async def test_key_limit_modifications(field): + # Create initial key + client = AsyncClient(base_url="http://0.0.0.0:4000") + key_data = {"max_budget": None, "rpm_limit": None, "tpm_limit": None} + headers = {"Authorization": "Bearer sk-1234"} + response = await client.post("/key/generate", json=key_data, headers=headers) + assert response.status_code == 200 + generate_key_response = response.json() + print("generate_key_response: ", json.dumps(generate_key_response, indent=4)) + key_id = generate_key_response["key"] + + # Update key with any non-null value for the field + update_data = {"key": key_id} + update_data[field] = 10 # Any non-null value works + print("update_data: ", json.dumps(update_data, indent=4)) + response = await client.post(f"/key/update", json=update_data, headers=headers) + assert response.status_code == 200 + assert response.json()[field] is not None + + # Reset limit to null + print(f"resetting {field} to null") + update_data[field] = None + response = await client.post(f"/key/update", json=update_data, headers=headers) + print("response: ", json.dumps(response.json(), indent=4)) + assert response.status_code == 200 + assert response.json()[field] is None + + +@pytest.mark.parametrize( + "field", + [ + "max_budget", + ], +) +@pytest.mark.asyncio +async def test_team_limit_modifications(field): + # Create initial team + client = AsyncClient(base_url="http://0.0.0.0:4000") + team_data = {"max_budget": None, "rpm_limit": None, "tpm_limit": None} + headers = {"Authorization": "Bearer sk-1234"} + response = await client.post("/team/new", json=team_data, headers=headers) + print("response: ", json.dumps(response.json(), indent=4)) + assert response.status_code == 200 + team_id = response.json()["team_id"] + + # Update team with any non-null value for the field + update_data = {"team_id": team_id} + update_data[field] = 10 # Any non-null value works + response = await client.post(f"/team/update", json=update_data, headers=headers) + print("response: ", json.dumps(response.json(), indent=4)) + assert response.status_code == 200 + assert response.json()["data"][field] is not None + + # Reset limit to null + print(f"resetting {field} to null") + update_data[field] = None + response = await client.post(f"/team/update", json=update_data, headers=headers) + print("response: ", json.dumps(response.json(), indent=4)) + assert response.status_code == 200 + assert response.json()["data"][field] is None