diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 680ec54940..103532f39d 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -561,7 +561,11 @@ class TeamBase(LiteLLMBase): metadata: Optional[dict] = None tpm_limit: Optional[int] = None rpm_limit: Optional[int] = None + + # Budget fields max_budget: Optional[float] = None + budget_duration: Optional[str] = None + models: list = [] blocked: bool = False @@ -607,6 +611,7 @@ class UpdateTeamRequest(LiteLLMBase): max_budget: Optional[float] = None models: Optional[list] = None blocked: Optional[bool] = None + budget_duration: Optional[str] = None class DeleteTeamRequest(LiteLLMBase): diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 25b5db03a4..ed7e999744 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -8040,12 +8040,15 @@ async def new_team( ## ADD TO TEAM TABLE complete_team_data = LiteLLM_TeamTable( **data.json(), - max_parallel_requests=user_api_key_dict.max_parallel_requests, - budget_duration=user_api_key_dict.budget_duration, - budget_reset_at=user_api_key_dict.budget_reset_at, model_id=_model_id, ) + # If budget_duration is set, set `budget_reset_at` + if complete_team_data.budget_duration is not None: + duration_s = _duration_in_seconds(duration=complete_team_data.budget_duration) + reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + complete_team_data.budget_reset_at = reset_at + team_row = await prisma_client.insert_data( data=complete_team_data.json(exclude_none=True), table_name="team" ) @@ -8124,6 +8127,15 @@ async def update_team( ) updated_kv = data.json(exclude_none=True) + + # Check budget_duration and budget_reset_at + if data.budget_duration is not None: + duration_s = _duration_in_seconds(duration=data.budget_duration) + reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + + # set the budget_reset_at in DB + updated_kv["budget_reset_at"] = reset_at + team_row = await prisma_client.update_data( update_key_values=updated_kv, data=updated_kv, diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index c1521e856c..528410ca63 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -50,8 +50,10 @@ from litellm.proxy.proxy_server import ( spend_key_fn, view_spend_logs, user_info, + team_info, info_key_fn, new_team, + update_team, chat_completion, completion, embeddings, @@ -73,6 +75,7 @@ from litellm.proxy._types import ( UpdateKeyRequest, GenerateKeyRequest, NewTeamRequest, + UpdateTeamRequest, UserAPIKeyAuth, LiteLLM_UpperboundKeyGenerateParams, ) @@ -2137,3 +2140,96 @@ async def test_reset_spend_authentication(prisma_client): "Tried to access route=/global/spend/reset, which is only for MASTER KEY" in e.message ) + + +@pytest.mark.asyncio() +async def test_create_update_team(prisma_client): + """ + - Set max_budget, budget_duration, max_budget, tpm_limit, rpm_limit + - Assert response has correct values + + - Update max_budget, budget_duration, max_budget, tpm_limit, rpm_limit + - Assert response has correct values + + - Call team_info and assert response has correct values + """ + print("prisma client=", prisma_client) + + master_key = "sk-1234" + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", master_key) + import datetime + + await litellm.proxy.proxy_server.prisma_client.connect() + from litellm.proxy.proxy_server import user_api_key_cache + + _team_id = "test-team_{}".format(uuid.uuid4()) + response = await new_team( + NewTeamRequest( + team_id=_team_id, + max_budget=20, + budget_duration="30d", + tpm_limit=20, + rpm_limit=20, + ), + user_api_key_dict=UserAPIKeyAuth( + user_role="proxy_admin", api_key="sk-1234", user_id="1234" + ), + ) + + print("RESPONSE from new_team", response) + + assert response["team_id"] == _team_id + assert response["max_budget"] == 20 + assert response["tpm_limit"] == 20 + assert response["rpm_limit"] == 20 + assert response["budget_duration"] == "30d" + assert response["budget_reset_at"] is not None and isinstance( + response["budget_reset_at"], datetime.datetime + ) + + # updating team budget duration and reset at + + response = await update_team( + UpdateTeamRequest( + team_id=_team_id, + max_budget=30, + budget_duration="2d", + tpm_limit=30, + rpm_limit=30, + ), + user_api_key_dict=UserAPIKeyAuth( + user_role="proxy_admin", api_key="sk-1234", user_id="1234" + ), + ) + + print("RESPONSE from update_team", response) + _updated_info = response["data"] + _updated_info = dict(_updated_info) + + assert _updated_info["team_id"] == _team_id + assert _updated_info["max_budget"] == 30 + assert _updated_info["tpm_limit"] == 30 + assert _updated_info["rpm_limit"] == 30 + assert _updated_info["budget_duration"] == "2d" + assert _updated_info["budget_reset_at"] is not None and isinstance( + _updated_info["budget_reset_at"], datetime.datetime + ) + + # now hit team_info + response = await team_info(team_id=_team_id) + + print("RESPONSE from team_info", response) + + _team_info = response["team_info"] + _team_info = dict(_team_info) + + assert _team_info["team_id"] == _team_id + assert _team_info["max_budget"] == 30 + assert _team_info["tpm_limit"] == 30 + assert _team_info["rpm_limit"] == 30 + assert _team_info["budget_duration"] == "2d" + assert _team_info["budget_reset_at"] is not None and isinstance( + _team_info["budget_reset_at"], datetime.datetime + )