fix(management_helpers/utils.py): use user_default max_budget, budget duration on new user upsert during team member add

Fixes https://github.com/BerriAI/litellm/issues/5106
This commit is contained in:
Krrish Dholakia 2024-08-08 19:14:43 -07:00
parent 6af9d9d2b3
commit 1d39c0fb7d
3 changed files with 109 additions and 8 deletions

View file

@ -804,10 +804,16 @@ async def test_get_team_redis(client_no_auth):
import random
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
from litellm.proxy._types import LitellmUserRoles, NewUserRequest, UserAPIKeyAuth
from litellm.proxy._types import (
LitellmUserRoles,
NewUserRequest,
TeamMemberAddRequest,
UserAPIKeyAuth,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
from litellm.proxy.management_endpoints.team_endpoints import team_member_add
from litellm.tests.test_key_generate_prisma import prisma_client
@ -852,3 +858,67 @@ async def test_create_user_default_budget(prisma_client, user_role):
else:
assert mock_client.call_args.kwargs["data"]["max_budget"] is None
assert mock_client.call_args.kwargs["data"]["budget_duration"] is None
@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
@pytest.mark.asyncio
async def test_create_team_member_add(prisma_client, new_member_method):
import time
from litellm.proxy._types import LiteLLM_TeamTableCachedObj
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm, "max_internal_user_budget", 10)
setattr(litellm, "internal_user_budget_duration", "5m")
await litellm.proxy.proxy_server.prisma_client.connect()
user = f"ishaan {uuid.uuid4().hex}"
_team_id = "litellm-test-client-id-new"
team_obj = LiteLLM_TeamTableCachedObj(
team_id=_team_id,
blocked=False,
last_refreshed_at=time.time(),
metadata={"guardrails": {"modify_guardrails": False}},
)
# user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
if new_member_method == "user_id":
data = {
"team_id": _team_id,
"member": [{"role": "user", "user_id": user}],
}
elif new_member_method == "user_email":
data = {
"team_id": _team_id,
"member": [{"role": "user", "user_email": user}],
}
team_member_add_request = TeamMemberAddRequest(**data)
with patch(
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
new_callable=AsyncMock,
) as mock_litellm_usertable:
mock_client = AsyncMock()
mock_litellm_usertable.upsert = mock_client
mock_litellm_usertable.find_many = AsyncMock(return_value=None)
await team_member_add(
data=team_member_add_request, user_api_key_dict=UserAPIKeyAuth()
)
mock_client.assert_called()
print(f"mock_client.call_args: {mock_client.call_args}")
print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))
assert (
mock_client.call_args.kwargs["data"]["create"]["max_budget"]
== litellm.max_internal_user_budget
)
assert (
mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
== litellm.internal_user_budget_duration
)