refactor(team_endpoints.py): refactor auth checks for team member endpoints to ui team admin to manage it

This commit is contained in:
Krrish Dholakia 2024-08-20 16:57:18 -07:00
parent 19083a4d31
commit a61f3e7656
5 changed files with 72 additions and 39 deletions

View file

@ -342,10 +342,10 @@ class LiteLLMRoutes(enum.Enum):
+ sso_only_routes + sso_only_routes
) )
team_admin_routes: List = [ self_managed_routes: List = [
"/team/member_add", "/team/member_add",
"/team/member_delete", "/team/member_delete",
] + internal_user_routes ] # routes that manage their own allowed/disallowed logic
# class LiteLLMAllowedRoutes(LiteLLMBase): # class LiteLLMAllowedRoutes(LiteLLMBase):

View file

@ -975,8 +975,6 @@ async def user_api_key_auth(
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
if is_llm_api_route(route=route): if is_llm_api_route(route=route):
pass pass
elif is_llm_api_route(route=route):
pass
elif ( elif (
route in LiteLLMRoutes.info_routes.value route in LiteLLMRoutes.info_routes.value
): # check if user allowed to call an info route ): # check if user allowed to call an info route
@ -1053,9 +1051,8 @@ async def user_api_key_auth(
): ):
pass pass
elif ( elif (
_is_user_team_admin(user_api_key_dict=valid_token) route in LiteLLMRoutes.self_managed_routes.value
and route in LiteLLMRoutes.team_admin_routes.value ): # routes that manage their own allowed/disallowed logic
):
pass pass
else: else:
user_role = "unknown" user_role = "unknown"
@ -1332,13 +1329,3 @@ def get_api_key_from_custom_header(
f"No LiteLLM Virtual Key pass. Please set header={custom_litellm_key_header_name}: Bearer <api_key>" f"No LiteLLM Virtual Key pass. Please set header={custom_litellm_key_header_name}: Bearer <api_key>"
) )
return api_key return api_key
def _is_user_team_admin(user_api_key_dict: UserAPIKeyAuth) -> bool:
if user_api_key_dict.team_member is None:
return False
if user_api_key_dict.team_member.role == LiteLLMTeamRoles.TEAM_ADMIN.value:
return True
return False

View file

@ -849,7 +849,7 @@ async def generate_key_helper_fn(
} }
if ( if (
litellm.get_secret("DISABLE_KEY_NAME", False) == True litellm.get_secret("DISABLE_KEY_NAME", False) is True
): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much) ): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
pass pass
else: else:

View file

@ -30,7 +30,7 @@ from litellm.proxy._types import (
UpdateTeamRequest, UpdateTeamRequest,
UserAPIKeyAuth, UserAPIKeyAuth,
) )
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth
from litellm.proxy.management_helpers.utils import ( from litellm.proxy.management_helpers.utils import (
add_new_member, add_new_member,
management_endpoint_wrapper, management_endpoint_wrapper,
@ -39,6 +39,16 @@ from litellm.proxy.management_helpers.utils import (
router = APIRouter() router = APIRouter()
def _is_user_team_admin(
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
) -> bool:
for member in team_obj.members_with_roles:
if member.user_id is not None and member.user_id == user_api_key_dict.user_id:
return True
return False
#### TEAM MANAGEMENT #### #### TEAM MANAGEMENT ####
@router.post( @router.post(
"/team/new", "/team/new",
@ -466,6 +476,23 @@ async def team_member_add(
complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump()) complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump())
## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
if (
user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
and not _is_user_team_admin(
user_api_key_dict=user_api_key_dict, team_obj=complete_team_data
)
):
raise HTTPException(
status_code=403,
detail={
"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
"/team/member_add", complete_team_data.team_id
)
},
)
if isinstance(data.member, Member): if isinstance(data.member, Member):
# add to team db # add to team db
new_member = data.member new_member = data.member
@ -570,6 +597,23 @@ async def team_member_delete(
) )
existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump()) existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())
## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
if (
user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
and not _is_user_team_admin(
user_api_key_dict=user_api_key_dict, team_obj=existing_team_row
)
):
raise HTTPException(
status_code=403,
detail={
"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
"/team/member_delete", existing_team_row.team_id
)
},
)
## DELETE MEMBER FROM TEAM ## DELETE MEMBER FROM TEAM
new_team_members: List[Member] = [] new_team_members: List[Member] = []
for m in existing_team_row.members_with_roles: for m in existing_team_row.members_with_roles:

View file

@ -989,22 +989,16 @@ async def test_create_team_member_add_team_admin_user_api_key_auth(
request._body = json_bytes request._body = json_bytes
try: ## ALLOWED BY USER_API_KEY_AUTH
await user_api_key_auth(request=request, api_key="Bearer " + user_key) await user_api_key_auth(request=request, api_key="Bearer " + user_key)
if team_member_role == "user":
pytest.fail(
"Expected this call to fail. User not allowed to access this route."
)
except ProxyException:
if team_member_role == "admin":
pytest.fail(
"Expected this call to succeed. Team admin allowed to access /team/member_add"
)
@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"]) @pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
@pytest.mark.parametrize("user_role", ["admin", "user"])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_team_member_add_team_admin(prisma_client, new_member_method): async def test_create_team_member_add_team_admin(
prisma_client, new_member_method, user_role
):
""" """
Relevant issue - https://github.com/BerriAI/litellm/issues/5300 Relevant issue - https://github.com/BerriAI/litellm/issues/5300
@ -1018,6 +1012,7 @@ async def test_create_team_member_add_team_admin(prisma_client, new_member_metho
from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
from litellm.proxy.proxy_server import ( from litellm.proxy.proxy_server import (
HTTPException,
ProxyException, ProxyException,
hash_token, hash_token,
user_api_key_auth, user_api_key_auth,
@ -1035,8 +1030,8 @@ async def test_create_team_member_add_team_admin(prisma_client, new_member_metho
valid_token = UserAPIKeyAuth( valid_token = UserAPIKeyAuth(
team_id=_team_id, team_id=_team_id,
user_id=user,
token=hash_token(user_key), token=hash_token(user_key),
team_member=Member(role="admin", user_id=user),
last_refreshed_at=time.time(), last_refreshed_at=time.time(),
) )
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token) user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
@ -1045,6 +1040,7 @@ async def test_create_team_member_add_team_admin(prisma_client, new_member_metho
team_id=_team_id, team_id=_team_id,
blocked=False, blocked=False,
last_refreshed_at=time.time(), last_refreshed_at=time.time(),
members_with_roles=[Member(role=user_role, user_id=user)],
metadata={"guardrails": {"modify_guardrails": False}}, metadata={"guardrails": {"modify_guardrails": False}},
) )
@ -1071,13 +1067,19 @@ async def test_create_team_member_add_team_admin(prisma_client, new_member_metho
mock_litellm_usertable.upsert = mock_client mock_litellm_usertable.upsert = mock_client
mock_litellm_usertable.find_many = AsyncMock(return_value=None) mock_litellm_usertable.find_many = AsyncMock(return_value=None)
await team_member_add( try:
data=team_member_add_request, await team_member_add(
user_api_key_dict=valid_token, data=team_member_add_request,
http_request=Request( user_api_key_dict=valid_token,
scope={"type": "http", "path": "/user/new"}, http_request=Request(
), scope={"type": "http", "path": "/user/new"},
) ),
)
except HTTPException as e:
if user_role == "user":
assert e.status_code == 403
else:
raise e
mock_client.assert_called() mock_client.assert_called()