test(test_proxy_server.py): unit testing to make sure internal user params don't impact admin

This commit is contained in:
Krrish Dholakia 2024-08-08 17:59:30 -07:00
parent c919c15c4a
commit 856ede4a05
2 changed files with 28 additions and 8 deletions

View file

@ -87,12 +87,16 @@ async def new_user(
"user" # only create a user, don't create key if 'auto_create_key' set to False "user" # only create a user, don't create key if 'auto_create_key' set to False
) )
is_internal_user = False
if data.user_role == LitellmUserRoles.INTERNAL_USER:
is_internal_user = True
if "max_budget" in data_json and data_json["max_budget"] is None: if "max_budget" in data_json and data_json["max_budget"] is None:
if litellm.max_internal_user_budget is not None: if is_internal_user and litellm.max_internal_user_budget is not None:
data_json["max_budget"] = litellm.max_internal_user_budget data_json["max_budget"] = litellm.max_internal_user_budget
if "budget_duration" in data_json and data_json["budget_duration"] is None: if "budget_duration" in data_json and data_json["budget_duration"] is None:
if litellm.internal_user_budget_duration is not None: if is_internal_user and litellm.internal_user_budget_duration is not None:
data_json["budget_duration"] = litellm.internal_user_budget_duration data_json["budget_duration"] = litellm.internal_user_budget_duration
response = await generate_key_helper_fn(request_type="user", **data_json) response = await generate_key_helper_fn(request_type="user", **data_json)

View file

@ -811,15 +811,22 @@ from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
from litellm.tests.test_key_generate_prisma import prisma_client from litellm.tests.test_key_generate_prisma import prisma_client
@pytest.mark.parametrize(
"user_role",
[LitellmUserRoles.INTERNAL_USER.value, LitellmUserRoles.PROXY_ADMIN.value],
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_user_default_budget(prisma_client): async def test_create_user_default_budget(prisma_client, user_role):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm, "max_internal_user_budget", 10) setattr(litellm, "max_internal_user_budget", 10)
setattr(litellm, "internal_user_budget_duration", "5m")
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
user = f"ishaan {uuid.uuid4().hex}" user = f"ishaan {uuid.uuid4().hex}"
request = NewUserRequest(user_id=user) # create a key with no budget request = NewUserRequest(
user_id=user, user_role=user_role
) # create a key with no budget
with patch.object( with patch.object(
litellm.proxy.proxy_server.prisma_client, "insert_data", new=AsyncMock() litellm.proxy.proxy_server.prisma_client, "insert_data", new=AsyncMock()
) as mock_client: ) as mock_client:
@ -832,7 +839,16 @@ async def test_create_user_default_budget(prisma_client):
print(f"mock_client.call_args: {mock_client.call_args}") print(f"mock_client.call_args: {mock_client.call_args}")
print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs)) print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))
assert ( if user_role == LitellmUserRoles.INTERNAL_USER.value:
mock_client.call_args.kwargs["data"]["max_budget"] assert (
== litellm.max_internal_user_budget mock_client.call_args.kwargs["data"]["max_budget"]
) == litellm.max_internal_user_budget
)
assert (
mock_client.call_args.kwargs["data"]["budget_duration"]
== litellm.internal_user_budget_duration
)
else:
assert mock_client.call_args.kwargs["data"]["max_budget"] is None
assert mock_client.call_args.kwargs["data"]["budget_duration"] is None