refactor(test_users.py): refactor test for user info to use mock endpoints

This commit is contained in:
Krrish Dholakia 2024-08-12 18:47:25 -07:00
parent 66c0d32b1d
commit d1d28487f7
3 changed files with 47 additions and 9 deletions

View file

@ -312,7 +312,7 @@ async def user_info(
try:
if prisma_client is None:
raise Exception(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
## GET USER ROW ##
if user_id is not None:
@ -365,7 +365,14 @@ async def user_info(
getattr(caller_user_info, "user_role", None)
== LitellmUserRoles.PROXY_ADMIN
):
teams_2 = await prisma_client.db.litellm_teamtable.find_many()
from litellm.proxy.management_endpoints.team_endpoints import list_team
teams_2 = await list_team(
http_request=Request(
scope={"type": "http", "path": "/user/info"},
),
user_api_key_dict=user_api_key_dict,
)
else:
teams_2 = await prisma_client.get_data(
team_id_list=caller_user_info.teams,

View file

@ -928,3 +928,41 @@ async def test_create_team_member_add(prisma_client, new_member_method):
mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
== litellm.internal_user_budget_duration
)
@pytest.mark.asyncio
async def test_user_info_team_list(prisma_client):
"""Assert user_info for admin calls team_list function"""
from litellm.proxy._types import LiteLLM_UserTable
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
from litellm.proxy.management_endpoints.internal_user_endpoints import user_info
with patch(
"litellm.proxy.management_endpoints.team_endpoints.list_team",
new_callable=AsyncMock,
) as mock_client:
prisma_client.get_data = AsyncMock(
return_value=LiteLLM_UserTable(
user_role="proxy_admin",
user_id="default_user_id",
max_budget=None,
user_email="",
)
)
try:
await user_info(
user_id=None,
user_api_key_dict=UserAPIKeyAuth(
api_key="sk-1234", user_id="default_user_id"
),
)
except Exception:
pass
mock_client.assert_called()

View file

@ -99,13 +99,6 @@ async def test_user_info():
)
assert status == 403
## check if returned teams as admin == all teams ##
admin_info = await get_user_info(
session=session, get_user="", call_user="sk-1234", view_all=True
)
all_teams = await list_teams(session=session, i=0)
assert len(admin_info["teams"]) == len(all_teams)
@pytest.mark.asyncio
async def test_user_update():