diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index ba6df9a2e..f7d1658e1 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -476,6 +476,7 @@ def prepare_key_update_data( duration_s = duration_in_seconds(duration=budget_duration) key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) non_default_values["budget_reset_at"] = key_reset_at + non_default_values["budget_duration"] = budget_duration _metadata = existing_key_row.metadata or {} diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py index e6f8ca541..e34ed2b3e 100644 --- a/tests/proxy_unit_tests/test_key_generate_prisma.py +++ b/tests/proxy_unit_tests/test_key_generate_prisma.py @@ -23,7 +23,7 @@ import os import sys import traceback import uuid -from datetime import datetime +from datetime import datetime, timezone from dotenv import load_dotenv from fastapi import Request @@ -1305,6 +1305,8 @@ def test_generate_and_update_key(prisma_client): data=UpdateKeyRequest( key=generated_key, models=["ada", "babbage", "curie", "davinci"], + budget_duration="1mo", + max_budget=100, ), ) @@ -1333,6 +1335,27 @@ def test_generate_and_update_key(prisma_client): } assert result["info"]["models"] == ["ada", "babbage", "curie", "davinci"] assert result["info"]["team_id"] == _team_2 + assert result["info"]["budget_duration"] == "1mo" + assert result["info"]["max_budget"] == 100 + + # budget_reset_at should be 30 days from now + assert result["info"]["budget_reset_at"] is not None + budget_reset_at = result["info"]["budget_reset_at"].replace( + tzinfo=timezone.utc + ) + current_time = datetime.now(timezone.utc) + + print( + "days between now and budget_reset_at", + (budget_reset_at - current_time).days, + ) + # assert budget_reset_at is 30 days from now + assert ( + abs( + (budget_reset_at - current_time).total_seconds() - 30 * 24 * 60 * 60 + ) + <= 10 + ) # cleanup - delete key delete_key_request = KeyRequest(keys=[generated_key]) @@ -2613,6 +2636,15 @@ async def test_create_update_team(prisma_client): _updated_info["budget_reset_at"], datetime.datetime ) + # budget_reset_at should be 2 days from now + budget_reset_at = _updated_info["budget_reset_at"].replace(tzinfo=timezone.utc) + current_time = datetime.datetime.now(timezone.utc) + + # assert budget_reset_at is 2 days from now + assert ( + abs((budget_reset_at - current_time).total_seconds() - 2 * 24 * 60 * 60) <= 10 + ) + # now hit team_info try: response = await team_info( @@ -2756,6 +2788,56 @@ async def test_update_user_role(prisma_client): print("result from user auth with new key", result) +@pytest.mark.asyncio() +async def test_update_user_unit_test(prisma_client): + """ + Unit test for /user/update + + Ensure that params are updated for UpdateUserRequest + """ + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + key = await new_user( + data=NewUserRequest( + user_email="test@test.com", + ) + ) + + print(key) + + user_info = await user_update( + data=UpdateUserRequest( + user_id=key.user_id, + team_id="1234", + max_budget=100, + budget_duration="10d", + tpm_limit=100, + rpm_limit=100, + metadata={"very-new-metadata": "something"}, + ) + ) + + print("user_info", user_info) + assert user_info is not None + _user_info = user_info["data"].model_dump() + + assert _user_info["user_id"] == key.user_id + assert _user_info["team_id"] == "1234" + assert _user_info["max_budget"] == 100 + assert _user_info["budget_duration"] == "10d" + assert _user_info["tpm_limit"] == 100 + assert _user_info["rpm_limit"] == 100 + assert _user_info["metadata"] == {"very-new-metadata": "something"} + + # budget reset at should be 10 days from now + budget_reset_at = _user_info["budget_reset_at"].replace(tzinfo=timezone.utc) + current_time = datetime.now(timezone.utc) + assert ( + abs((budget_reset_at - current_time).total_seconds() - 10 * 24 * 60 * 60) <= 10 + ) + + @pytest.mark.asyncio() async def test_custom_api_key_header_name(prisma_client): """ """