fix(internal_user_endpoints.py): return all teams if user is admin

This commit is contained in:
Krrish Dholakia 2024-08-10 16:36:43 -07:00
parent e67a239520
commit 575afa8029
3 changed files with 31 additions and 8 deletions

View file

@ -316,7 +316,7 @@ async def user_info(
## GET USER ROW ##
if user_id is not None:
user_info = await prisma_client.get_data(user_id=user_id)
elif view_all == True:
elif view_all is True:
if page is None:
page = 0
if page_size is None:
@ -364,11 +364,7 @@ async def user_info(
getattr(caller_user_info, "user_role", None)
== LitellmUserRoles.PROXY_ADMIN
):
teams_2 = await prisma_client.get_data(
table_name="team",
query_type="find_all",
team_id_list=None,
)
teams_2 = await prisma_client.db.litellm_teamtable.find_many()
else:
teams_2 = await prisma_client.get_data(
team_id_list=caller_user_info.teams,

View file

@ -251,6 +251,21 @@ async def delete_team(
return await response.json()
async def list_teams(
session,
i,
):
url = "http://0.0.0.0:4000/team/list"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
async with session.get(url, headers=headers) as response:
status = response.status
if status != 200:
raise Exception(f"Request {i} did not return a 200 status code: {status}")
return await response.json()
@pytest.mark.asyncio
async def test_team_new():
"""

View file

@ -5,6 +5,8 @@ import asyncio
import aiohttp
import time
from openai import AsyncOpenAI
from test_team import list_teams
from typing import Optional
async def new_user(session, i, user_id=None, budget=None, budget_duration=None):
@ -45,11 +47,14 @@ async def test_user_new():
await asyncio.gather(*tasks)
async def get_user_info(session, get_user, call_user):
async def get_user_info(session, get_user, call_user, view_all: Optional[bool] = None):
"""
Make sure only models user has access to are returned
"""
url = f"http://0.0.0.0:4000/user/info?user_id={get_user}"
if view_all is True:
url = "http://0.0.0.0:4000/user/info"
else:
url = f"http://0.0.0.0:4000/user/info?user_id={get_user}"
headers = {
"Authorization": f"Bearer {call_user}",
"Content-Type": "application/json",
@ -94,6 +99,13 @@ async def test_user_info():
)
assert status == 403
## check if returned teams as admin == all teams ##
admin_info = await get_user_info(
session=session, get_user="", call_user="sk-1234", view_all=True
)
all_teams = await list_teams(session=session, i=0)
assert len(admin_info["teams"]) == len(all_teams)
@pytest.mark.asyncio
async def test_user_update():