feat(user_api_key_auth.py): allow team admin to add new members to team

This commit is contained in:
Krrish Dholakia 2024-08-20 14:01:12 -07:00
parent 16d09b1dd3
commit fa6c9bf42e
5 changed files with 220 additions and 1 deletions

View file

@ -21,6 +21,13 @@ else:
Span = Any
class LiteLLMTeamRoles(enum.Enum):
# team admin
TEAM_ADMIN = "admin"
# team member
TEAM_MEMBER = "user"
class LitellmUserRoles(str, enum.Enum):
"""
Admin Roles:
@ -335,6 +342,8 @@ class LiteLLMRoutes(enum.Enum):
+ sso_only_routes
)
team_admin_routes: List = ["/team/member_add"] + internal_user_routes
# class LiteLLMAllowedRoutes(LiteLLMBase):
# """
@ -1308,6 +1317,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
soft_budget: Optional[float] = None
team_model_aliases: Optional[Dict] = None
team_member_spend: Optional[float] = None
team_member: Optional[Member] = None
team_metadata: Optional[Dict] = None
# End User Params

View file

@ -975,7 +975,7 @@ async def user_api_key_auth(
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
if is_llm_api_route(route=route):
pass
elif is_llm_api_route(route=request["route"].name):
elif is_llm_api_route(route=route):
pass
elif (
route in LiteLLMRoutes.info_routes.value
@ -1046,11 +1046,17 @@ async def user_api_key_auth(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
)
elif (
_user_role == LitellmUserRoles.INTERNAL_USER.value
and route in LiteLLMRoutes.internal_user_routes.value
):
pass
elif (
_is_user_team_admin(user_api_key_dict=valid_token)
and route in LiteLLMRoutes.team_admin_routes.value
):
pass
else:
user_role = "unknown"
user_id = "unknown"
@ -1326,3 +1332,13 @@ def get_api_key_from_custom_header(
f"No LiteLLM Virtual Key pass. Please set header={custom_litellm_key_header_name}: Bearer <api_key>"
)
return api_key
def _is_user_team_admin(user_api_key_dict: UserAPIKeyAuth) -> bool:
if user_api_key_dict.team_member is None:
return False
if user_api_key_dict.team_member.role == LiteLLMTeamRoles.TEAM_ADMIN.value:
return True
return False

View file

@ -417,6 +417,7 @@ async def team_member_add(
If user doesn't exist, new user row will also be added to User Table
Only proxy_admin or admin of team, allowed to access this endpoint.
```
curl -X POST 'http://0.0.0.0:4000/team/member_add' \

View file

@ -44,6 +44,7 @@ from litellm.proxy._types import (
DynamoDBArgs,
LiteLLM_VerificationTokenView,
LitellmUserRoles,
Member,
ResetTeamBudgetRequest,
SpendLogsMetadata,
SpendLogsPayload,
@ -1395,6 +1396,7 @@ class PrismaClient:
t.blocked AS team_blocked,
t.team_alias AS team_alias,
t.metadata AS team_metadata,
t.members_with_roles AS team_members_with_roles,
tm.spend AS team_member_spend,
m.aliases as team_model_aliases
FROM "LiteLLM_VerificationToken" AS v
@ -1412,6 +1414,33 @@ class PrismaClient:
response["team_models"] = []
if response["team_blocked"] is None:
response["team_blocked"] = False
team_member: Optional[Member] = None
if (
response["team_members_with_roles"] is not None
and response["user_id"] is not None
):
## find the team member corresponding to user id
"""
[
{
"role": "admin",
"user_id": "default_user_id",
"user_email": null
},
{
"role": "user",
"user_id": null,
"user_email": "test@email.com"
}
]
"""
for tm in response["team_members_with_roles"]:
if tm.get("user_id") is not None and response[
"user_id"
] == tm.get("user_id"):
team_member = Member(**tm)
response["team_member"] = team_member
response = LiteLLM_VerificationTokenView(
**response, last_refreshed_at=time.time()
)

View file

@ -930,6 +930,169 @@ async def test_create_team_member_add(prisma_client, new_member_method):
)
@pytest.mark.parametrize("team_member_role", ["admin", "user"])
@pytest.mark.asyncio
async def test_create_team_member_add_team_admin_user_api_key_auth(
prisma_client, team_member_role
):
import time
from fastapi import Request
from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
from litellm.proxy.proxy_server import (
ProxyException,
hash_token,
user_api_key_auth,
user_api_key_cache,
)
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm, "max_internal_user_budget", 10)
setattr(litellm, "internal_user_budget_duration", "5m")
await litellm.proxy.proxy_server.prisma_client.connect()
user = f"ishaan {uuid.uuid4().hex}"
_team_id = "litellm-test-client-id-new"
user_key = "sk-12345678"
valid_token = UserAPIKeyAuth(
team_id=_team_id,
token=hash_token(user_key),
team_member=Member(role=team_member_role, user_id=user),
last_refreshed_at=time.time(),
)
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
team_obj = LiteLLM_TeamTableCachedObj(
team_id=_team_id,
blocked=False,
last_refreshed_at=time.time(),
metadata={"guardrails": {"modify_guardrails": False}},
)
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
## TEST IF TEAM ADMIN ALLOWED TO CALL /MEMBER_ADD ENDPOINT
import json
from starlette.datastructures import URL
request = Request(scope={"type": "http"})
request._url = URL(url="/team/member_add")
body = {}
json_bytes = json.dumps(body).encode("utf-8")
request._body = json_bytes
try:
await user_api_key_auth(request=request, api_key="Bearer " + user_key)
if team_member_role == "user":
pytest.fail(
"Expected this call to fail. User not allowed to access this route."
)
except ProxyException:
if team_member_role == "admin":
pytest.fail(
"Expected this call to succeed. Team admin allowed to access /team/member_add"
)
@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
@pytest.mark.asyncio
async def test_create_team_member_add_team_admin(prisma_client, new_member_method):
"""
Relevant issue - https://github.com/BerriAI/litellm/issues/5300
Allow team admins to:
- Add and remove team members
- raise error if team member not an existing 'internal_user'
"""
import time
from fastapi import Request
from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
from litellm.proxy.proxy_server import (
ProxyException,
hash_token,
user_api_key_auth,
user_api_key_cache,
)
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm, "max_internal_user_budget", 10)
setattr(litellm, "internal_user_budget_duration", "5m")
await litellm.proxy.proxy_server.prisma_client.connect()
user = f"ishaan {uuid.uuid4().hex}"
_team_id = "litellm-test-client-id-new"
user_key = "sk-12345678"
valid_token = UserAPIKeyAuth(
team_id=_team_id,
token=hash_token(user_key),
team_member=Member(role="admin", user_id=user),
last_refreshed_at=time.time(),
)
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
team_obj = LiteLLM_TeamTableCachedObj(
team_id=_team_id,
blocked=False,
last_refreshed_at=time.time(),
metadata={"guardrails": {"modify_guardrails": False}},
)
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
if new_member_method == "user_id":
data = {
"team_id": _team_id,
"member": [{"role": "user", "user_id": user}],
}
elif new_member_method == "user_email":
data = {
"team_id": _team_id,
"member": [{"role": "user", "user_email": user}],
}
team_member_add_request = TeamMemberAddRequest(**data)
with patch(
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
new_callable=AsyncMock,
) as mock_litellm_usertable:
mock_client = AsyncMock()
mock_litellm_usertable.upsert = mock_client
mock_litellm_usertable.find_many = AsyncMock(return_value=None)
await team_member_add(
data=team_member_add_request,
user_api_key_dict=valid_token,
http_request=Request(
scope={"type": "http", "path": "/user/new"},
),
)
mock_client.assert_called()
print(f"mock_client.call_args: {mock_client.call_args}")
print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))
assert (
mock_client.call_args.kwargs["data"]["create"]["max_budget"]
== litellm.max_internal_user_budget
)
assert (
mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
== litellm.internal_user_budget_duration
)
@pytest.mark.asyncio
async def test_user_info_team_list(prisma_client):
"""Assert user_info for admin calls team_list function"""