diff --git a/tests/test_end_users.py b/tests/test_end_users.py
index 83bffdc12..9c8b59753 100644
--- a/tests/test_end_users.py
+++ b/tests/test_end_users.py
@@ -99,7 +99,16 @@ async def generate_key(
 
 
 async def new_end_user(
-    session, i, user_id=str(uuid.uuid4()), model_region=None, default_model=None
+    session,
+    i,
+    user_id=None,
+    model_region=None,
+    default_model=None,
+    budget_id=None,
 ):
+    if user_id is None:
+        # A `str(uuid.uuid4())` default in the signature is evaluated once, at
+        # definition time, so every call relying on it would share one user id.
+        user_id = str(uuid.uuid4())
     url = "http://0.0.0.0:4000/end_user/new"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
@@ -109,6 +118,10 @@ async def new_end_user(
         "default_model": default_model,
     }
 
+    if budget_id is not None:
+        data["budget_id"] = budget_id
+    print("end user data: {}".format(data))
+
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
         response_text = await response.text()
@@ -123,6 +136,31 @@ async def new_end_user(
         return await response.json()
 
 
+async def new_budget(session, i, budget_id=None):
+    """
+    POST a new budget with `tpm_limit = 2` to the proxy's /budget/new endpoint.
+
+    Raises an Exception if the proxy does not answer with a 200 status code.
+    """
+    url = "http://0.0.0.0:4000/budget/new"
+    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
+    data = {
+        "budget_id": budget_id,
+        "tpm_limit": 2,
+    }
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(f"Response {i} (Status code: {status}):")
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request {i} did not return a 200 status code: {status}")
+
+
 @pytest.mark.asyncio
 async def test_end_user_new():
     """
@@ -170,3 +208,63 @@ async def test_end_user_specific_region():
         )
 
         assert result.headers.get("x-litellm-model-region") == "eu"
+
+
+@pytest.mark.asyncio
+async def test_end_tpm_limits():
+    """
+    1. budget_id = Create Budget with tpm_limit = 2
+    2. create end_user with budget_id
+    3. Make /chat/completions call
+    4. Sleep 1 second
+    5. Make /chat/completions call -> expect this to fail because rate limit hit
+    """
+    async with aiohttp.ClientSession() as session:
+        # create a budget with a unique "free-tier-<uuid>" budget_id
+        budget_id = f"free-tier-{uuid.uuid4()}"
+        await new_budget(session, 0, budget_id=budget_id)
+        await asyncio.sleep(2)
+
+        end_user_id = str(uuid.uuid4())
+
+        await new_end_user(
+            session=session, i=0, user_id=end_user_id, budget_id=budget_id
+        )
+
+        ## MAKE CALL ##
+        key_gen = await generate_key(session=session, i=0, models=[])
+
+        key = key_gen["key"]
+
+        # chat completion 1
+        client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
+
+        result = await client.chat.completions.create(
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": "Hey!"}],
+            user=end_user_id,
+        )
+
+        print("\nchat completion result 1=", result)
+
+        await asyncio.sleep(1)
+
+        # chat completion 2
+        try:
+            result = await client.chat.completions.create(
+                model="fake-openai-endpoint",
+                messages=[{"role": "user", "content": "Hey!"}],
+                user=end_user_id,
+            )
+            # pytest.fail raises an outcome exception that derives from BaseException,
+            # so the `except Exception` below does not swallow it.
+            pytest.fail(
+                "User crossed their limit - this should have failed. instead got result = {}".format(
+                    result
+                )
+            )
+        except Exception as e:
+            print("got exception 2 =", e)
+            assert "Crossed TPM, RPM Limit" in str(
+                e
+            ), f"Expected 'Crossed TPM, RPM Limit' but got {str(e)}"