mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge pull request #5308 from BerriAI/litellm_team_admin_permissions
feat(user_api_key_auth.py): allow team admin to add new members to team
This commit is contained in:
commit
509ae0ca71
7 changed files with 263 additions and 5 deletions
|
@ -21,6 +21,13 @@ else:
|
||||||
Span = Any
|
Span = Any
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLMTeamRoles(enum.Enum):
|
||||||
|
# team admin
|
||||||
|
TEAM_ADMIN = "admin"
|
||||||
|
# team member
|
||||||
|
TEAM_MEMBER = "user"
|
||||||
|
|
||||||
|
|
||||||
class LitellmUserRoles(str, enum.Enum):
|
class LitellmUserRoles(str, enum.Enum):
|
||||||
"""
|
"""
|
||||||
Admin Roles:
|
Admin Roles:
|
||||||
|
@ -335,6 +342,11 @@ class LiteLLMRoutes(enum.Enum):
|
||||||
+ sso_only_routes
|
+ sso_only_routes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self_managed_routes: List = [
|
||||||
|
"/team/member_add",
|
||||||
|
"/team/member_delete",
|
||||||
|
] # routes that manage their own allowed/disallowed logic
|
||||||
|
|
||||||
|
|
||||||
# class LiteLLMAllowedRoutes(LiteLLMBase):
|
# class LiteLLMAllowedRoutes(LiteLLMBase):
|
||||||
# """
|
# """
|
||||||
|
@ -1308,6 +1320,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
||||||
soft_budget: Optional[float] = None
|
soft_budget: Optional[float] = None
|
||||||
team_model_aliases: Optional[Dict] = None
|
team_model_aliases: Optional[Dict] = None
|
||||||
team_member_spend: Optional[float] = None
|
team_member_spend: Optional[float] = None
|
||||||
|
team_member: Optional[Member] = None
|
||||||
team_metadata: Optional[Dict] = None
|
team_metadata: Optional[Dict] = None
|
||||||
|
|
||||||
# End User Params
|
# End User Params
|
||||||
|
|
|
@ -975,8 +975,6 @@ async def user_api_key_auth(
|
||||||
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
||||||
if is_llm_api_route(route=route):
|
if is_llm_api_route(route=route):
|
||||||
pass
|
pass
|
||||||
elif is_llm_api_route(route=request["route"].name):
|
|
||||||
pass
|
|
||||||
elif (
|
elif (
|
||||||
route in LiteLLMRoutes.info_routes.value
|
route in LiteLLMRoutes.info_routes.value
|
||||||
): # check if user allowed to call an info route
|
): # check if user allowed to call an info route
|
||||||
|
@ -1046,11 +1044,16 @@ async def user_api_key_auth(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
|
detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
|
||||||
)
|
)
|
||||||
|
|
||||||
elif (
|
elif (
|
||||||
_user_role == LitellmUserRoles.INTERNAL_USER.value
|
_user_role == LitellmUserRoles.INTERNAL_USER.value
|
||||||
and route in LiteLLMRoutes.internal_user_routes.value
|
and route in LiteLLMRoutes.internal_user_routes.value
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
elif (
|
||||||
|
route in LiteLLMRoutes.self_managed_routes.value
|
||||||
|
): # routes that manage their own allowed/disallowed logic
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
user_role = "unknown"
|
user_role = "unknown"
|
||||||
user_id = "unknown"
|
user_id = "unknown"
|
||||||
|
|
|
@ -119,6 +119,7 @@ async def new_user(
|
||||||
http_request=Request(
|
http_request=Request(
|
||||||
scope={"type": "http", "path": "/user/new"},
|
scope={"type": "http", "path": "/user/new"},
|
||||||
),
|
),
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
if data.send_invite_email is True:
|
if data.send_invite_email is True:
|
||||||
|
|
|
@ -849,7 +849,7 @@ async def generate_key_helper_fn(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
litellm.get_secret("DISABLE_KEY_NAME", False) == True
|
litellm.get_secret("DISABLE_KEY_NAME", False) is True
|
||||||
): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
|
): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -30,7 +30,7 @@ from litellm.proxy._types import (
|
||||||
UpdateTeamRequest,
|
UpdateTeamRequest,
|
||||||
UserAPIKeyAuth,
|
UserAPIKeyAuth,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth
|
||||||
from litellm.proxy.management_helpers.utils import (
|
from litellm.proxy.management_helpers.utils import (
|
||||||
add_new_member,
|
add_new_member,
|
||||||
management_endpoint_wrapper,
|
management_endpoint_wrapper,
|
||||||
|
@ -39,6 +39,16 @@ from litellm.proxy.management_helpers.utils import (
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_user_team_admin(
|
||||||
|
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
|
||||||
|
) -> bool:
|
||||||
|
for member in team_obj.members_with_roles:
|
||||||
|
if member.user_id is not None and member.user_id == user_api_key_dict.user_id:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
#### TEAM MANAGEMENT ####
|
#### TEAM MANAGEMENT ####
|
||||||
@router.post(
|
@router.post(
|
||||||
"/team/new",
|
"/team/new",
|
||||||
|
@ -417,6 +427,7 @@ async def team_member_add(
|
||||||
|
|
||||||
If user doesn't exist, new user row will also be added to User Table
|
If user doesn't exist, new user row will also be added to User Table
|
||||||
|
|
||||||
|
Only proxy_admin or admin of team, allowed to access this endpoint.
|
||||||
```
|
```
|
||||||
|
|
||||||
curl -X POST 'http://0.0.0.0:4000/team/member_add' \
|
curl -X POST 'http://0.0.0.0:4000/team/member_add' \
|
||||||
|
@ -465,6 +476,24 @@ async def team_member_add(
|
||||||
|
|
||||||
complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump())
|
complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump())
|
||||||
|
|
||||||
|
## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
|
||||||
|
|
||||||
|
if (
|
||||||
|
user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
|
||||||
|
and not _is_user_team_admin(
|
||||||
|
user_api_key_dict=user_api_key_dict, team_obj=complete_team_data
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail={
|
||||||
|
"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
|
||||||
|
"/team/member_add",
|
||||||
|
complete_team_data.team_id,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(data.member, Member):
|
if isinstance(data.member, Member):
|
||||||
# add to team db
|
# add to team db
|
||||||
new_member = data.member
|
new_member = data.member
|
||||||
|
@ -569,6 +598,23 @@ async def team_member_delete(
|
||||||
)
|
)
|
||||||
existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())
|
existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())
|
||||||
|
|
||||||
|
## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
|
||||||
|
|
||||||
|
if (
|
||||||
|
user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
|
||||||
|
and not _is_user_team_admin(
|
||||||
|
user_api_key_dict=user_api_key_dict, team_obj=existing_team_row
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail={
|
||||||
|
"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
|
||||||
|
"/team/member_delete", existing_team_row.team_id
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
## DELETE MEMBER FROM TEAM
|
## DELETE MEMBER FROM TEAM
|
||||||
new_team_members: List[Member] = []
|
new_team_members: List[Member] = []
|
||||||
for m in existing_team_row.members_with_roles:
|
for m in existing_team_row.members_with_roles:
|
||||||
|
|
|
@ -44,6 +44,7 @@ from litellm.proxy._types import (
|
||||||
DynamoDBArgs,
|
DynamoDBArgs,
|
||||||
LiteLLM_VerificationTokenView,
|
LiteLLM_VerificationTokenView,
|
||||||
LitellmUserRoles,
|
LitellmUserRoles,
|
||||||
|
Member,
|
||||||
ResetTeamBudgetRequest,
|
ResetTeamBudgetRequest,
|
||||||
SpendLogsMetadata,
|
SpendLogsMetadata,
|
||||||
SpendLogsPayload,
|
SpendLogsPayload,
|
||||||
|
@ -1395,6 +1396,7 @@ class PrismaClient:
|
||||||
t.blocked AS team_blocked,
|
t.blocked AS team_blocked,
|
||||||
t.team_alias AS team_alias,
|
t.team_alias AS team_alias,
|
||||||
t.metadata AS team_metadata,
|
t.metadata AS team_metadata,
|
||||||
|
t.members_with_roles AS team_members_with_roles,
|
||||||
tm.spend AS team_member_spend,
|
tm.spend AS team_member_spend,
|
||||||
m.aliases as team_model_aliases
|
m.aliases as team_model_aliases
|
||||||
FROM "LiteLLM_VerificationToken" AS v
|
FROM "LiteLLM_VerificationToken" AS v
|
||||||
|
@ -1412,6 +1414,33 @@ class PrismaClient:
|
||||||
response["team_models"] = []
|
response["team_models"] = []
|
||||||
if response["team_blocked"] is None:
|
if response["team_blocked"] is None:
|
||||||
response["team_blocked"] = False
|
response["team_blocked"] = False
|
||||||
|
|
||||||
|
team_member: Optional[Member] = None
|
||||||
|
if (
|
||||||
|
response["team_members_with_roles"] is not None
|
||||||
|
and response["user_id"] is not None
|
||||||
|
):
|
||||||
|
## find the team member corresponding to user id
|
||||||
|
"""
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "admin",
|
||||||
|
"user_id": "default_user_id",
|
||||||
|
"user_email": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"user_id": null,
|
||||||
|
"user_email": "test@email.com"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
for tm in response["team_members_with_roles"]:
|
||||||
|
if tm.get("user_id") is not None and response[
|
||||||
|
"user_id"
|
||||||
|
] == tm.get("user_id"):
|
||||||
|
team_member = Member(**tm)
|
||||||
|
response["team_member"] = team_member
|
||||||
response = LiteLLM_VerificationTokenView(
|
response = LiteLLM_VerificationTokenView(
|
||||||
**response, last_refreshed_at=time.time()
|
**response, last_refreshed_at=time.time()
|
||||||
)
|
)
|
||||||
|
|
|
@ -909,7 +909,7 @@ async def test_create_team_member_add(prisma_client, new_member_method):
|
||||||
|
|
||||||
await team_member_add(
|
await team_member_add(
|
||||||
data=team_member_add_request,
|
data=team_member_add_request,
|
||||||
user_api_key_dict=UserAPIKeyAuth(),
|
user_api_key_dict=UserAPIKeyAuth(user_role="proxy_admin"),
|
||||||
http_request=Request(
|
http_request=Request(
|
||||||
scope={"type": "http", "path": "/user/new"},
|
scope={"type": "http", "path": "/user/new"},
|
||||||
),
|
),
|
||||||
|
@ -930,6 +930,172 @@ async def test_create_team_member_add(prisma_client, new_member_method):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("team_member_role", ["admin", "user"])
|
||||||
|
@pytest.mark.parametrize("team_route", ["/team/member_add", "/team/member_delete"])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_team_member_add_team_admin_user_api_key_auth(
|
||||||
|
prisma_client, team_member_role, team_route
|
||||||
|
):
|
||||||
|
import time
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
|
||||||
|
from litellm.proxy.proxy_server import (
|
||||||
|
ProxyException,
|
||||||
|
hash_token,
|
||||||
|
user_api_key_auth,
|
||||||
|
user_api_key_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
setattr(litellm, "max_internal_user_budget", 10)
|
||||||
|
setattr(litellm, "internal_user_budget_duration", "5m")
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
user = f"ishaan {uuid.uuid4().hex}"
|
||||||
|
_team_id = "litellm-test-client-id-new"
|
||||||
|
user_key = "sk-12345678"
|
||||||
|
|
||||||
|
valid_token = UserAPIKeyAuth(
|
||||||
|
team_id=_team_id,
|
||||||
|
token=hash_token(user_key),
|
||||||
|
team_member=Member(role=team_member_role, user_id=user),
|
||||||
|
last_refreshed_at=time.time(),
|
||||||
|
)
|
||||||
|
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
|
||||||
|
|
||||||
|
team_obj = LiteLLM_TeamTableCachedObj(
|
||||||
|
team_id=_team_id,
|
||||||
|
blocked=False,
|
||||||
|
last_refreshed_at=time.time(),
|
||||||
|
metadata={"guardrails": {"modify_guardrails": False}},
|
||||||
|
)
|
||||||
|
|
||||||
|
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
|
||||||
|
|
||||||
|
## TEST IF TEAM ADMIN ALLOWED TO CALL /MEMBER_ADD ENDPOINT
|
||||||
|
import json
|
||||||
|
|
||||||
|
from starlette.datastructures import URL
|
||||||
|
|
||||||
|
request = Request(scope={"type": "http"})
|
||||||
|
request._url = URL(url=team_route)
|
||||||
|
|
||||||
|
body = {}
|
||||||
|
json_bytes = json.dumps(body).encode("utf-8")
|
||||||
|
|
||||||
|
request._body = json_bytes
|
||||||
|
|
||||||
|
## ALLOWED BY USER_API_KEY_AUTH
|
||||||
|
await user_api_key_auth(request=request, api_key="Bearer " + user_key)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
|
||||||
|
@pytest.mark.parametrize("user_role", ["admin", "user"])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_team_member_add_team_admin(
|
||||||
|
prisma_client, new_member_method, user_role
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Relevant issue - https://github.com/BerriAI/litellm/issues/5300
|
||||||
|
|
||||||
|
Allow team admins to:
|
||||||
|
- Add and remove team members
|
||||||
|
- raise error if team member not an existing 'internal_user'
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
|
||||||
|
from litellm.proxy.proxy_server import (
|
||||||
|
HTTPException,
|
||||||
|
ProxyException,
|
||||||
|
hash_token,
|
||||||
|
user_api_key_auth,
|
||||||
|
user_api_key_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
setattr(litellm, "max_internal_user_budget", 10)
|
||||||
|
setattr(litellm, "internal_user_budget_duration", "5m")
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
user = f"ishaan {uuid.uuid4().hex}"
|
||||||
|
_team_id = "litellm-test-client-id-new"
|
||||||
|
user_key = "sk-12345678"
|
||||||
|
|
||||||
|
valid_token = UserAPIKeyAuth(
|
||||||
|
team_id=_team_id,
|
||||||
|
user_id=user,
|
||||||
|
token=hash_token(user_key),
|
||||||
|
last_refreshed_at=time.time(),
|
||||||
|
)
|
||||||
|
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
|
||||||
|
|
||||||
|
team_obj = LiteLLM_TeamTableCachedObj(
|
||||||
|
team_id=_team_id,
|
||||||
|
blocked=False,
|
||||||
|
last_refreshed_at=time.time(),
|
||||||
|
members_with_roles=[Member(role=user_role, user_id=user)],
|
||||||
|
metadata={"guardrails": {"modify_guardrails": False}},
|
||||||
|
)
|
||||||
|
|
||||||
|
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
|
||||||
|
if new_member_method == "user_id":
|
||||||
|
data = {
|
||||||
|
"team_id": _team_id,
|
||||||
|
"member": [{"role": "user", "user_id": user}],
|
||||||
|
}
|
||||||
|
elif new_member_method == "user_email":
|
||||||
|
data = {
|
||||||
|
"team_id": _team_id,
|
||||||
|
"member": [{"role": "user", "user_email": user}],
|
||||||
|
}
|
||||||
|
team_member_add_request = TeamMemberAddRequest(**data)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_litellm_usertable:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_litellm_usertable.upsert = mock_client
|
||||||
|
mock_litellm_usertable.find_many = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await team_member_add(
|
||||||
|
data=team_member_add_request,
|
||||||
|
user_api_key_dict=valid_token,
|
||||||
|
http_request=Request(
|
||||||
|
scope={"type": "http", "path": "/user/new"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except HTTPException as e:
|
||||||
|
if user_role == "user":
|
||||||
|
assert e.status_code == 403
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
mock_client.assert_called()
|
||||||
|
|
||||||
|
print(f"mock_client.call_args: {mock_client.call_args}")
|
||||||
|
print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))
|
||||||
|
|
||||||
|
assert (
|
||||||
|
mock_client.call_args.kwargs["data"]["create"]["max_budget"]
|
||||||
|
== litellm.max_internal_user_budget
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
|
||||||
|
== litellm.internal_user_budget_duration
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_user_info_team_list(prisma_client):
|
async def test_user_info_team_list(prisma_client):
|
||||||
"""Assert user_info for admin calls team_list function"""
|
"""Assert user_info for admin calls team_list function"""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue