diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 25cf2f56d3..0177c21907 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -342,10 +342,10 @@ class LiteLLMRoutes(enum.Enum): + sso_only_routes ) - team_admin_routes: List = [ + self_managed_routes: List = [ "/team/member_add", "/team/member_delete", - ] + internal_user_routes + ] # routes that manage their own allowed/disallowed logic # class LiteLLMAllowedRoutes(LiteLLMBase): diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index d20ab54bd8..378dd84525 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -975,8 +975,6 @@ async def user_api_key_auth( if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin if is_llm_api_route(route=route): pass - elif is_llm_api_route(route=route): - pass elif ( route in LiteLLMRoutes.info_routes.value ): # check if user allowed to call an info route @@ -1053,9 +1051,8 @@ async def user_api_key_auth( ): pass elif ( - _is_user_team_admin(user_api_key_dict=valid_token) - and route in LiteLLMRoutes.team_admin_routes.value - ): + route in LiteLLMRoutes.self_managed_routes.value + ): # routes that manage their own allowed/disallowed logic pass else: user_role = "unknown" @@ -1332,13 +1329,3 @@ def get_api_key_from_custom_header( f"No LiteLLM Virtual Key pass. Please set header={custom_litellm_key_header_name}: Bearer " ) return api_key - - -def _is_user_team_admin(user_api_key_dict: UserAPIKeyAuth) -> bool: - if user_api_key_dict.team_member is None: - return False - - if user_api_key_dict.team_member.role == LiteLLMTeamRoles.TEAM_ADMIN.value: - return True - - return False diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 1758b416dd..2e16b533c8 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -849,7 +849,7 @@ async def generate_key_helper_fn( } if ( - litellm.get_secret("DISABLE_KEY_NAME", False) == True + litellm.get_secret("DISABLE_KEY_NAME", False) is True ): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much) pass else: diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 4b7af502d8..8f8ad610a8 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -30,7 +30,7 @@ from litellm.proxy._types import ( UpdateTeamRequest, UserAPIKeyAuth, ) -from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth from litellm.proxy.management_helpers.utils import ( add_new_member, management_endpoint_wrapper, @@ -39,6 +39,16 @@ from litellm.proxy.management_helpers.utils import ( router = APIRouter() +def _is_user_team_admin( + user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable +) -> bool: + for member in team_obj.members_with_roles: + if member.user_id is not None and member.user_id == user_api_key_dict.user_id: + return True + + return False + + #### TEAM MANAGEMENT #### @router.post( "/team/new", @@ -466,6 +476,23 @@ async def team_member_add( complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump()) + ## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN + + if ( + user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value + and not _is_user_team_admin( + user_api_key_dict=user_api_key_dict, team_obj=complete_team_data + ) + ): + raise HTTPException( + status_code=403, + detail={ + "error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format( + "/team/member_add", complete_team_data.team_id + ) + }, + ) + if isinstance(data.member, Member): # add to team db new_member = data.member @@ -570,6 +597,23 @@ async def team_member_delete( ) existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump()) + ## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN + + if ( + user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value + and not _is_user_team_admin( + user_api_key_dict=user_api_key_dict, team_obj=existing_team_row + ) + ): + raise HTTPException( + status_code=403, + detail={ + "error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format( + "/team/member_delete", existing_team_row.team_id + ) + }, + ) + ## DELETE MEMBER FROM TEAM new_team_members: List[Member] = [] for m in existing_team_row.members_with_roles: diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index a1d6d9dee1..79da5bf7c5 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -989,22 +989,16 @@ async def test_create_team_member_add_team_admin_user_api_key_auth( request._body = json_bytes - try: - await user_api_key_auth(request=request, api_key="Bearer " + user_key) - if team_member_role == "user": - pytest.fail( - "Expected this call to fail. User not allowed to access this route." - ) - except ProxyException: - if team_member_role == "admin": - pytest.fail( - "Expected this call to succeed. Team admin allowed to access /team/member_add" - ) + ## ALLOWED BY USER_API_KEY_AUTH + await user_api_key_auth(request=request, api_key="Bearer " + user_key) @pytest.mark.parametrize("new_member_method", ["user_id", "user_email"]) +@pytest.mark.parametrize("user_role", ["admin", "user"]) @pytest.mark.asyncio -async def test_create_team_member_add_team_admin(prisma_client, new_member_method): +async def test_create_team_member_add_team_admin( + prisma_client, new_member_method, user_role +): """ Relevant issue - https://github.com/BerriAI/litellm/issues/5300 @@ -1018,6 +1012,7 @@ async def test_create_team_member_add_team_admin(prisma_client, new_member_metho from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member from litellm.proxy.proxy_server import ( + HTTPException, ProxyException, hash_token, user_api_key_auth, @@ -1035,8 +1030,8 @@ async def test_create_team_member_add_team_admin(prisma_client, new_member_metho valid_token = UserAPIKeyAuth( team_id=_team_id, + user_id=user, token=hash_token(user_key), - team_member=Member(role="admin", user_id=user), last_refreshed_at=time.time(), ) user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token) @@ -1045,6 +1040,7 @@ async def test_create_team_member_add_team_admin(prisma_client, new_member_metho team_id=_team_id, blocked=False, last_refreshed_at=time.time(), + members_with_roles=[Member(role=user_role, user_id=user)], metadata={"guardrails": {"modify_guardrails": False}}, ) @@ -1071,13 +1067,19 @@ async def test_create_team_member_add_team_admin(prisma_client, new_member_metho mock_litellm_usertable.upsert = mock_client mock_litellm_usertable.find_many = AsyncMock(return_value=None) - await team_member_add( - data=team_member_add_request, - user_api_key_dict=valid_token, - http_request=Request( - scope={"type": "http", "path": "/user/new"}, - ), - ) + try: + await team_member_add( + data=team_member_add_request, + user_api_key_dict=valid_token, + http_request=Request( + scope={"type": "http", "path": "/user/new"}, + ), + ) + except HTTPException as e: + if user_role == "user": + assert e.status_code == 403 + else: + raise e mock_client.assert_called()