Merge pull request #3842 from BerriAI/litellm_set_budget_dur

[Fix] Set budget_duration on `/team/new` and `/team/update`
This commit is contained in:
Ishaan Jaff 2024-05-25 16:31:39 -07:00 committed by GitHub
commit e4053b6732
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 116 additions and 3 deletions

View file

@ -561,7 +561,11 @@ class TeamBase(LiteLLMBase):
metadata: Optional[dict] = None
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
# Budget fields
max_budget: Optional[float] = None
budget_duration: Optional[str] = None
models: list = []
blocked: bool = False
@ -607,6 +611,7 @@ class UpdateTeamRequest(LiteLLMBase):
max_budget: Optional[float] = None
models: Optional[list] = None
blocked: Optional[bool] = None
budget_duration: Optional[str] = None
class DeleteTeamRequest(LiteLLMBase):

View file

@ -8040,12 +8040,15 @@ async def new_team(
## ADD TO TEAM TABLE
complete_team_data = LiteLLM_TeamTable(
**data.json(),
max_parallel_requests=user_api_key_dict.max_parallel_requests,
budget_duration=user_api_key_dict.budget_duration,
budget_reset_at=user_api_key_dict.budget_reset_at,
model_id=_model_id,
)
# If budget_duration is set, set `budget_reset_at`
if complete_team_data.budget_duration is not None:
duration_s = _duration_in_seconds(duration=complete_team_data.budget_duration)
reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
complete_team_data.budget_reset_at = reset_at
team_row = await prisma_client.insert_data(
data=complete_team_data.json(exclude_none=True), table_name="team"
)
@ -8124,6 +8127,15 @@ async def update_team(
)
updated_kv = data.json(exclude_none=True)
# Check budget_duration and budget_reset_at
if data.budget_duration is not None:
duration_s = _duration_in_seconds(duration=data.budget_duration)
reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
# set the budget_reset_at in DB
updated_kv["budget_reset_at"] = reset_at
team_row = await prisma_client.update_data(
update_key_values=updated_kv,
data=updated_kv,

View file

@ -50,8 +50,10 @@ from litellm.proxy.proxy_server import (
spend_key_fn,
view_spend_logs,
user_info,
team_info,
info_key_fn,
new_team,
update_team,
chat_completion,
completion,
embeddings,
@ -73,6 +75,7 @@ from litellm.proxy._types import (
UpdateKeyRequest,
GenerateKeyRequest,
NewTeamRequest,
UpdateTeamRequest,
UserAPIKeyAuth,
LiteLLM_UpperboundKeyGenerateParams,
)
@ -2137,3 +2140,96 @@ async def test_reset_spend_authentication(prisma_client):
"Tried to access route=/global/spend/reset, which is only for MASTER KEY"
in e.message
)
@pytest.mark.asyncio()
async def test_create_update_team(prisma_client):
    """
    Check that budget and rate-limit fields survive create, update and read.

    - Create a team with max_budget, budget_duration, tpm_limit, rpm_limit
    - Assert the new_team response has the correct values
    - Update max_budget, budget_duration, tpm_limit, rpm_limit
    - Assert the update_team response has the correct values
    - Call team_info and assert the response has the correct values
    """
    # Imported at function scope so that `datetime.datetime` below resolves
    # against the module itself.
    # NOTE(review): presumably the module-level name `datetime` is bound to
    # something else in this file - confirm before hoisting this import.
    import datetime

    print("prisma client=", prisma_client)

    master_key = "sk-1234"

    # Point the proxy server at the test database client and master key.
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", master_key)

    await litellm.proxy.proxy_server.prisma_client.connect()

    _team_id = "test-team_{}".format(uuid.uuid4())

    def _admin_auth():
        """Build a fresh proxy-admin auth object for each endpoint call."""
        return UserAPIKeyAuth(
            user_role="proxy_admin", api_key="sk-1234", user_id="1234"
        )

    def _assert_team_fields(
        team, *, max_budget, tpm_limit, rpm_limit, budget_duration
    ):
        """Assert one team payload carries the expected budget/limit values."""
        assert team["team_id"] == _team_id
        assert team["max_budget"] == max_budget
        assert team["tpm_limit"] == tpm_limit
        assert team["rpm_limit"] == rpm_limit
        assert team["budget_duration"] == budget_duration
        # isinstance() is False for None, so this also covers "is not None".
        assert isinstance(team["budget_reset_at"], datetime.datetime)

    # create the team with a 30 day budget window
    response = await new_team(
        NewTeamRequest(
            team_id=_team_id,
            max_budget=20,
            budget_duration="30d",
            tpm_limit=20,
            rpm_limit=20,
        ),
        user_api_key_dict=_admin_auth(),
    )

    print("RESPONSE from new_team", response)

    _assert_team_fields(
        response,
        max_budget=20,
        tpm_limit=20,
        rpm_limit=20,
        budget_duration="30d",
    )

    # updating team budget duration and reset at
    response = await update_team(
        UpdateTeamRequest(
            team_id=_team_id,
            max_budget=30,
            budget_duration="2d",
            tpm_limit=30,
            rpm_limit=30,
        ),
        user_api_key_dict=_admin_auth(),
    )

    print("RESPONSE from update_team", response)

    # The updated row is nested under "data" in the update_team response.
    _updated_info = dict(response["data"])
    _assert_team_fields(
        _updated_info,
        max_budget=30,
        tpm_limit=30,
        rpm_limit=30,
        budget_duration="2d",
    )

    # now hit team_info
    response = await team_info(team_id=_team_id)

    print("RESPONSE from team_info", response)

    # The stored row is nested under "team_info" in the team_info response.
    _team_info = dict(response["team_info"])
    _assert_team_fields(
        _team_info,
        max_budget=30,
        tpm_limit=30,
        rpm_limit=30,
        budget_duration="2d",
    )