From fa6c9bf42ef8aba7b7d31d0e2384b0fcf29407a3 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 20 Aug 2024 14:01:12 -0700 Subject: [PATCH] feat(user_api_key_auth.py): allow team admin to add new members to team --- litellm/proxy/_types.py | 10 ++ litellm/proxy/auth/user_api_key_auth.py | 18 +- .../management_endpoints/team_endpoints.py | 1 + litellm/proxy/utils.py | 29 ++++ litellm/tests/test_proxy_server.py | 163 ++++++++++++++++++ 5 files changed, 220 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 75934ee1f..fa7780936 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -21,6 +21,13 @@ else: Span = Any +class LiteLLMTeamRoles(enum.Enum): + # team admin + TEAM_ADMIN = "admin" + # team member + TEAM_MEMBER = "user" + + class LitellmUserRoles(str, enum.Enum): """ Admin Roles: @@ -335,6 +342,8 @@ class LiteLLMRoutes(enum.Enum): + sso_only_routes ) + team_admin_routes: List = ["/team/member_add"] + internal_user_routes + # class LiteLLMAllowedRoutes(LiteLLMBase): # """ @@ -1308,6 +1317,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): soft_budget: Optional[float] = None team_model_aliases: Optional[Dict] = None team_member_spend: Optional[float] = None + team_member: Optional[Member] = None team_metadata: Optional[Dict] = None # End User Params diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index c980f47b5..d20ab54bd 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -975,7 +975,7 @@ async def user_api_key_auth( if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin if is_llm_api_route(route=route): pass - elif is_llm_api_route(route=request["route"].name): + elif is_llm_api_route(route=route): pass elif ( route in LiteLLMRoutes.info_routes.value @@ -1046,11 +1046,17 @@ async def user_api_key_auth( status_code=status.HTTP_403_FORBIDDEN, detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}", ) + elif ( _user_role == LitellmUserRoles.INTERNAL_USER.value and route in LiteLLMRoutes.internal_user_routes.value ): pass + elif ( + _is_user_team_admin(user_api_key_dict=valid_token) + and route in LiteLLMRoutes.team_admin_routes.value + ): + pass else: user_role = "unknown" user_id = "unknown" @@ -1326,3 +1332,13 @@ def get_api_key_from_custom_header( f"No LiteLLM Virtual Key pass. Please set header={custom_litellm_key_header_name}: Bearer " ) return api_key + + +def _is_user_team_admin(user_api_key_dict: UserAPIKeyAuth) -> bool: + if user_api_key_dict.team_member is None: + return False + + if user_api_key_dict.team_member.role == LiteLLMTeamRoles.TEAM_ADMIN.value: + return True + + return False diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 815ab308c..4b7af502d 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -417,6 +417,7 @@ async def team_member_add( If user doesn't exist, new user row will also be added to User Table + Only proxy_admin or admin of team, allowed to access this endpoint. ``` curl -X POST 'http://0.0.0.0:4000/team/member_add' \ diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index a2b09b4e6..a77017717 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -44,6 +44,7 @@ from litellm.proxy._types import ( DynamoDBArgs, LiteLLM_VerificationTokenView, LitellmUserRoles, + Member, ResetTeamBudgetRequest, SpendLogsMetadata, SpendLogsPayload, @@ -1395,6 +1396,7 @@ class PrismaClient: t.blocked AS team_blocked, t.team_alias AS team_alias, t.metadata AS team_metadata, + t.members_with_roles AS team_members_with_roles, tm.spend AS team_member_spend, m.aliases as team_model_aliases FROM "LiteLLM_VerificationToken" AS v @@ -1412,6 +1414,33 @@ class PrismaClient: response["team_models"] = [] if response["team_blocked"] is None: response["team_blocked"] = False + + team_member: Optional[Member] = None + if ( + response["team_members_with_roles"] is not None + and response["user_id"] is not None + ): + ## find the team member corresponding to user id + """ + [ + { + "role": "admin", + "user_id": "default_user_id", + "user_email": null + }, + { + "role": "user", + "user_id": null, + "user_email": "test@email.com" + } + ] + """ + for tm in response["team_members_with_roles"]: + if tm.get("user_id") is not None and response[ + "user_id" + ] == tm.get("user_id"): + team_member = Member(**tm) + response["team_member"] = team_member response = LiteLLM_VerificationTokenView( **response, last_refreshed_at=time.time() ) diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 28f3aad63..7caf1ecbc 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -930,6 +930,169 @@ async def test_create_team_member_add(prisma_client, new_member_method): ) +@pytest.mark.parametrize("team_member_role", ["admin", "user"]) +@pytest.mark.asyncio +async def test_create_team_member_add_team_admin_user_api_key_auth( + prisma_client, team_member_role +): + import time + + from fastapi import Request + + from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member + from litellm.proxy.proxy_server import ( + ProxyException, + hash_token, + user_api_key_auth, + user_api_key_cache, + ) + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm, "max_internal_user_budget", 10) + setattr(litellm, "internal_user_budget_duration", "5m") + await litellm.proxy.proxy_server.prisma_client.connect() + user = f"ishaan {uuid.uuid4().hex}" + _team_id = "litellm-test-client-id-new" + user_key = "sk-12345678" + + valid_token = UserAPIKeyAuth( + team_id=_team_id, + token=hash_token(user_key), + team_member=Member(role=team_member_role, user_id=user), + last_refreshed_at=time.time(), + ) + user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token) + + team_obj = LiteLLM_TeamTableCachedObj( + team_id=_team_id, + blocked=False, + last_refreshed_at=time.time(), + metadata={"guardrails": {"modify_guardrails": False}}, + ) + + user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj) + + setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) + + ## TEST IF TEAM ADMIN ALLOWED TO CALL /MEMBER_ADD ENDPOINT + import json + + from starlette.datastructures import URL + + request = Request(scope={"type": "http"}) + request._url = URL(url="/team/member_add") + + body = {} + json_bytes = json.dumps(body).encode("utf-8") + + request._body = json_bytes + + try: + await user_api_key_auth(request=request, api_key="Bearer " + user_key) + if team_member_role == "user": + pytest.fail( + "Expected this call to fail. User not allowed to access this route." + ) + except ProxyException: + if team_member_role == "admin": + pytest.fail( + "Expected this call to succeed. Team admin allowed to access /team/member_add" + ) + + +@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"]) +@pytest.mark.asyncio +async def test_create_team_member_add_team_admin(prisma_client, new_member_method): + """ + Relevant issue - https://github.com/BerriAI/litellm/issues/5300 + + Allow team admins to: + - Add and remove team members + - raise error if team member not an existing 'internal_user' + """ + import time + + from fastapi import Request + + from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member + from litellm.proxy.proxy_server import ( + ProxyException, + hash_token, + user_api_key_auth, + user_api_key_cache, + ) + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm, "max_internal_user_budget", 10) + setattr(litellm, "internal_user_budget_duration", "5m") + await litellm.proxy.proxy_server.prisma_client.connect() + user = f"ishaan {uuid.uuid4().hex}" + _team_id = "litellm-test-client-id-new" + user_key = "sk-12345678" + + valid_token = UserAPIKeyAuth( + team_id=_team_id, + token=hash_token(user_key), + team_member=Member(role="admin", user_id=user), + last_refreshed_at=time.time(), + ) + user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token) + + team_obj = LiteLLM_TeamTableCachedObj( + team_id=_team_id, + blocked=False, + last_refreshed_at=time.time(), + metadata={"guardrails": {"modify_guardrails": False}}, + ) + + user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj) + + setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) + if new_member_method == "user_id": + data = { + "team_id": _team_id, + "member": [{"role": "user", "user_id": user}], + } + elif new_member_method == "user_email": + data = { + "team_id": _team_id, + "member": [{"role": "user", "user_email": user}], + } + team_member_add_request = TeamMemberAddRequest(**data) + + with patch( + "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable", + new_callable=AsyncMock, + ) as mock_litellm_usertable: + mock_client = AsyncMock() + mock_litellm_usertable.upsert = mock_client + mock_litellm_usertable.find_many = AsyncMock(return_value=None) + + await team_member_add( + data=team_member_add_request, + user_api_key_dict=valid_token, + http_request=Request( + scope={"type": "http", "path": "/user/new"}, + ), + ) + + mock_client.assert_called() + + print(f"mock_client.call_args: {mock_client.call_args}") + print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs)) + + assert ( + mock_client.call_args.kwargs["data"]["create"]["max_budget"] + == litellm.max_internal_user_budget + ) + assert ( + mock_client.call_args.kwargs["data"]["create"]["budget_duration"] + == litellm.internal_user_budget_duration + ) + + @pytest.mark.asyncio async def test_user_info_team_list(prisma_client): """Assert user_info for admin calls team_list function"""