From 575afa8029b1c475874a5ea52f7938537960fc35 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 10 Aug 2024 16:36:43 -0700 Subject: [PATCH] fix(internal_user_endpoints.py): return all teams if user is admin --- .../internal_user_endpoints.py | 8 ++------ tests/test_team.py | 15 +++++++++++++++ tests/test_users.py | 16 ++++++++++++++-- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 22faec3be..bced1851e 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -316,7 +316,7 @@ async def user_info( ## GET USER ROW ## if user_id is not None: user_info = await prisma_client.get_data(user_id=user_id) - elif view_all == True: + elif view_all is True: if page is None: page = 0 if page_size is None: @@ -364,11 +364,7 @@ async def user_info( getattr(caller_user_info, "user_role", None) == LitellmUserRoles.PROXY_ADMIN ): - teams_2 = await prisma_client.get_data( - table_name="team", - query_type="find_all", - team_id_list=None, - ) + teams_2 = await prisma_client.db.litellm_teamtable.find_many() else: teams_2 = await prisma_client.get_data( team_id_list=caller_user_info.teams, diff --git a/tests/test_team.py b/tests/test_team.py index 544273c2e..45bd7138f 100644 --- a/tests/test_team.py +++ b/tests/test_team.py @@ -251,6 +251,21 @@ async def delete_team( return await response.json() +async def list_teams( + session, + i, +): + url = "http://0.0.0.0:4000/team/list" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + + async with session.get(url, headers=headers) as response: + status = response.status + if status != 200: + raise Exception(f"Request {i} did not return a 200 status code: {status}") + + return await response.json() + + @pytest.mark.asyncio async def test_team_new(): """ diff --git a/tests/test_users.py b/tests/test_users.py index a73d7163e..632dd8f36 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -5,6 +5,8 @@ import asyncio import aiohttp import time from openai import AsyncOpenAI +from test_team import list_teams +from typing import Optional async def new_user(session, i, user_id=None, budget=None, budget_duration=None): @@ -45,11 +47,14 @@ async def test_user_new(): await asyncio.gather(*tasks) -async def get_user_info(session, get_user, call_user): +async def get_user_info(session, get_user, call_user, view_all: Optional[bool] = None): """ Make sure only models user has access to are returned """ - url = f"http://0.0.0.0:4000/user/info?user_id={get_user}" + if view_all is True: + url = "http://0.0.0.0:4000/user/info" + else: + url = f"http://0.0.0.0:4000/user/info?user_id={get_user}" headers = { "Authorization": f"Bearer {call_user}", "Content-Type": "application/json", @@ -94,6 +99,13 @@ async def test_user_info(): ) assert status == 403 + ## check if returned teams as admin == all teams ## + admin_info = await get_user_info( + session=session, get_user="", call_user="sk-1234", view_all=True + ) + all_teams = await list_teams(session=session, i=0) + assert len(admin_info["teams"]) == len(all_teams) + @pytest.mark.asyncio async def test_user_update():