forked from phoenix/litellm-mirror
(feat proxy) [beta] add support for organization role based access controls (#6112)
* track LiteLLM_OrganizationMembership * add add_internal_user_to_organization * add org membership to schema * read organization membership when reading user info in auth checks * add check for valid organization_id * add test for test_create_new_user_in_organization * test test_create_new_user_in_organization * add new ADMIN role * add test for org admins creating teams * add test for test_org_admin_create_user_permissions * test_org_admin_create_user_team_wrong_org_permissions * test_org_admin_create_user_team_wrong_org_permissions * fix organization_role_based_access_check * fix getting user members * fix TeamBase * fix types used for use role * fix type checks * sync prisma schema * docs - organization admins * fix use organization_endpoints for /organization management * add types for org member endpoints * fix role name for org admin * add type for member add response * add organization/member_add * add error handling for adding members to an org * add nice doc string for oranization/member_add * fix test_create_new_user_in_organization * linting fix * use simple route changes * fix types * add organization member roles * add org admin auth checks * add auth checks for orgs * test for creating teams as org admin * simplify org id usage * fix typo * test test_org_admin_create_user_team_wrong_org_permissions * fix type check issue * code quality fix * fix schema.prisma
This commit is contained in:
parent
945267a511
commit
1fd437e263
14 changed files with 1474 additions and 261 deletions
439
tests/proxy_admin_ui_tests/test_role_based_access.py
Normal file
439
tests/proxy_admin_ui_tests/test_role_based_access.py
Normal file
|
@ -0,0 +1,439 @@
|
|||
"""
|
||||
RBAC tests
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Request
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
|
||||
# this file is to test litellm/proxy
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.auth.auth_checks import get_user_object
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
delete_key_fn,
|
||||
generate_key_fn,
|
||||
generate_key_helper_fn,
|
||||
info_key_fn,
|
||||
regenerate_key_fn,
|
||||
update_key_fn,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
|
||||
from litellm.proxy.management_endpoints.organization_endpoints import (
|
||||
new_organization,
|
||||
organization_member_add,
|
||||
)
|
||||
|
||||
from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
new_team,
|
||||
team_info,
|
||||
update_team,
|
||||
)
|
||||
from litellm.proxy.proxy_server import (
|
||||
LitellmUserRoles,
|
||||
audio_transcriptions,
|
||||
chat_completion,
|
||||
completion,
|
||||
embeddings,
|
||||
image_generation,
|
||||
model_list,
|
||||
moderations,
|
||||
new_end_user,
|
||||
user_api_key_auth,
|
||||
)
|
||||
from litellm.proxy.spend_tracking.spend_management_endpoints import (
|
||||
global_spend,
|
||||
global_spend_logs,
|
||||
global_spend_models,
|
||||
global_spend_keys,
|
||||
spend_key_fn,
|
||||
spend_user_fn,
|
||||
view_spend_logs,
|
||||
)
|
||||
from starlette.datastructures import URL
|
||||
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
|
||||
|
||||
verbose_proxy_logger.setLevel(level=logging.DEBUG)
|
||||
|
||||
from starlette.datastructures import URL
|
||||
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import *
|
||||
|
||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prisma_client():
|
||||
from litellm.proxy.proxy_cli import append_query_params
|
||||
|
||||
### add connection pool + pool timeout args
|
||||
params = {"connection_limit": 100, "pool_timeout": 60}
|
||||
database_url = os.getenv("DATABASE_URL")
|
||||
modified_url = append_query_params(database_url, params)
|
||||
os.environ["DATABASE_URL"] = modified_url
|
||||
|
||||
# Assuming PrismaClient is a class that needs to be instantiated
|
||||
prisma_client = PrismaClient(
|
||||
database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
# Reset litellm.proxy.proxy_server.prisma_client to None
|
||||
litellm.proxy.proxy_server.litellm_proxy_budget_name = (
|
||||
f"litellm-proxy-budget-{time.time()}"
|
||||
)
|
||||
litellm.proxy.proxy_server.user_custom_key_generate = None
|
||||
|
||||
return prisma_client
|
||||
|
||||
|
||||
"""
|
||||
RBAC Tests
|
||||
|
||||
1. Add a user to an organization
|
||||
- test 1 - if organization_id does exist expect to create a new user and user, organization relation
|
||||
|
||||
2. org admin creates team in his org → success
|
||||
|
||||
3. org admin adds new internal user to his org → success
|
||||
|
||||
4. org admin creates team and internal user not in his org → fail both
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"user_role",
|
||||
[
|
||||
LitellmUserRoles.ORG_ADMIN,
|
||||
LitellmUserRoles.INTERNAL_USER,
|
||||
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
|
||||
],
|
||||
)
|
||||
async def test_create_new_user_in_organization(prisma_client, user_role):
|
||||
"""
|
||||
|
||||
Add a member to an organization and assert the user object is created with the correct organization memberships / roles
|
||||
"""
|
||||
master_key = "sk-1234"
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", master_key)
|
||||
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
created_user_id = f"new-user-{uuid.uuid4()}"
|
||||
|
||||
response = await new_organization(
|
||||
data=NewOrganizationRequest(
|
||||
organization_alias=f"new-org-{uuid.uuid4()}",
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
),
|
||||
)
|
||||
|
||||
org_id = response.organization_id
|
||||
|
||||
response = await organization_member_add(
|
||||
data=OrganizationMemberAddRequest(
|
||||
organization_id=org_id,
|
||||
member=Member(role=user_role, user_id=created_user_id),
|
||||
),
|
||||
http_request=None,
|
||||
)
|
||||
|
||||
print("new user response", response)
|
||||
|
||||
# call get_user_object
|
||||
|
||||
user_object = await get_user_object(
|
||||
user_id=created_user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=DualCache(),
|
||||
user_id_upsert=False,
|
||||
)
|
||||
|
||||
print("user object", user_object)
|
||||
|
||||
assert user_object.organization_memberships is not None
|
||||
|
||||
_membership = user_object.organization_memberships[0]
|
||||
|
||||
assert _membership.user_id == created_user_id
|
||||
assert _membership.organization_id == org_id
|
||||
|
||||
if user_role != None:
|
||||
assert _membership.user_role == user_role
|
||||
else:
|
||||
assert _membership.user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_org_admin_create_team_permissions(prisma_client):
|
||||
"""
|
||||
Create a new org admin
|
||||
|
||||
org admin creates a new team in their org -> success
|
||||
"""
|
||||
import json
|
||||
|
||||
master_key = "sk-1234"
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", master_key)
|
||||
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
response = await new_organization(
|
||||
data=NewOrganizationRequest(
|
||||
organization_alias=f"new-org-{uuid.uuid4()}",
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
),
|
||||
)
|
||||
|
||||
org_id = response.organization_id
|
||||
created_user_id = f"new-user-{uuid.uuid4()}"
|
||||
response = await organization_member_add(
|
||||
data=OrganizationMemberAddRequest(
|
||||
organization_id=org_id,
|
||||
member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
|
||||
),
|
||||
http_request=None,
|
||||
)
|
||||
|
||||
# create key with the response["user_id"]
|
||||
# proxy admin will generate key for org admin
|
||||
_new_key = await generate_key_fn(
|
||||
data=GenerateKeyRequest(user_id=created_user_id),
|
||||
user_api_key_dict=UserAPIKeyAuth(user_id=created_user_id),
|
||||
)
|
||||
|
||||
new_key = _new_key.key
|
||||
|
||||
print("user api key auth response", response)
|
||||
|
||||
# Create /team/new request -> expect auth to pass
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/team/new")
|
||||
|
||||
async def return_body():
|
||||
body = {"organization_id": org_id}
|
||||
return bytes(json.dumps(body), "utf-8")
|
||||
|
||||
request.body = return_body
|
||||
response = await user_api_key_auth(request=request, api_key="Bearer " + new_key)
|
||||
|
||||
# after auth - actually create team now
|
||||
response = await new_team(
|
||||
data=NewTeamRequest(
|
||||
organization_id=org_id,
|
||||
),
|
||||
http_request=request,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_id=response.user_id,
|
||||
),
|
||||
)
|
||||
|
||||
print("response from new team")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_org_admin_create_user_permissions(prisma_client):
|
||||
"""
|
||||
1. Create a new org admin
|
||||
|
||||
2. org admin adds a new member to their org -> success (using using /organization/member_add)
|
||||
|
||||
"""
|
||||
import json
|
||||
|
||||
master_key = "sk-1234"
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", master_key)
|
||||
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
# create new org
|
||||
response = await new_organization(
|
||||
data=NewOrganizationRequest(
|
||||
organization_alias=f"new-org-{uuid.uuid4()}",
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
),
|
||||
)
|
||||
# Create Org Admin
|
||||
org_id = response.organization_id
|
||||
created_user_id = f"new-user-{uuid.uuid4()}"
|
||||
response = await organization_member_add(
|
||||
data=OrganizationMemberAddRequest(
|
||||
organization_id=org_id,
|
||||
member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
|
||||
),
|
||||
http_request=None,
|
||||
)
|
||||
|
||||
# create key with for Org Admin
|
||||
_new_key = await generate_key_fn(
|
||||
data=GenerateKeyRequest(user_id=created_user_id),
|
||||
user_api_key_dict=UserAPIKeyAuth(user_id=created_user_id),
|
||||
)
|
||||
|
||||
new_key = _new_key.key
|
||||
|
||||
print("user api key auth response", response)
|
||||
|
||||
# Create /organization/member_add request -> expect auth to pass
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/organization/member_add")
|
||||
|
||||
async def return_body():
|
||||
body = {"organization_id": org_id}
|
||||
return bytes(json.dumps(body), "utf-8")
|
||||
|
||||
request.body = return_body
|
||||
response = await user_api_key_auth(request=request, api_key="Bearer " + new_key)
|
||||
|
||||
# after auth - actually actually add new user to organization
|
||||
new_internal_user_for_org = f"new-org-user-{uuid.uuid4()}"
|
||||
response = await organization_member_add(
|
||||
data=OrganizationMemberAddRequest(
|
||||
organization_id=org_id,
|
||||
member=Member(
|
||||
role=LitellmUserRoles.INTERNAL_USER, user_id=new_internal_user_for_org
|
||||
),
|
||||
),
|
||||
http_request=request,
|
||||
)
|
||||
|
||||
print("response from new team")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_org_admin_create_user_team_wrong_org_permissions(prisma_client):
|
||||
"""
|
||||
Create a new org admin
|
||||
|
||||
org admin creates a new user and new team in orgs they are not part of -> expect error
|
||||
"""
|
||||
import json
|
||||
|
||||
master_key = "sk-1234"
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", master_key)
|
||||
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
created_user_id = f"new-user-{uuid.uuid4()}"
|
||||
response = await new_organization(
|
||||
data=NewOrganizationRequest(
|
||||
organization_alias=f"new-org-{uuid.uuid4()}",
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
),
|
||||
)
|
||||
|
||||
response2 = await new_organization(
|
||||
data=NewOrganizationRequest(
|
||||
organization_alias=f"new-org-{uuid.uuid4()}",
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
),
|
||||
)
|
||||
|
||||
org1_id = response.organization_id # has an admin
|
||||
|
||||
org2_id = response2.organization_id # does not have an org admin
|
||||
|
||||
# Create Org Admin for Org1
|
||||
created_user_id = f"new-user-{uuid.uuid4()}"
|
||||
response = await organization_member_add(
|
||||
data=OrganizationMemberAddRequest(
|
||||
organization_id=org1_id,
|
||||
member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
|
||||
),
|
||||
http_request=None,
|
||||
)
|
||||
|
||||
_new_key = await generate_key_fn(
|
||||
data=GenerateKeyRequest(
|
||||
user_id=created_user_id,
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.ORG_ADMIN,
|
||||
user_id=created_user_id,
|
||||
),
|
||||
)
|
||||
|
||||
new_key = _new_key.key
|
||||
|
||||
print("user api key auth response", response)
|
||||
|
||||
# Add a new request in organization=org_without_admins -> expect fail (organization/member_add)
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/organization/member_add")
|
||||
|
||||
async def return_body():
|
||||
body = {"organization_id": org2_id}
|
||||
return bytes(json.dumps(body), "utf-8")
|
||||
|
||||
request.body = return_body
|
||||
|
||||
try:
|
||||
response = await user_api_key_auth(request=request, api_key="Bearer " + new_key)
|
||||
pytest.fail(
|
||||
f"This should have failed!. creating a user in an org without admins"
|
||||
)
|
||||
except Exception as e:
|
||||
print("got exception", e)
|
||||
print("exception.message", e.message)
|
||||
assert (
|
||||
"You do not have a role within the selected organization. Passed organization_id"
|
||||
in e.message
|
||||
)
|
||||
|
||||
# Create /team/new request in organization=org_without_admins -> expect fail
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/team/new")
|
||||
|
||||
async def return_body():
|
||||
body = {"organization_id": org2_id}
|
||||
return bytes(json.dumps(body), "utf-8")
|
||||
|
||||
request.body = return_body
|
||||
|
||||
try:
|
||||
response = await user_api_key_auth(request=request, api_key="Bearer " + new_key)
|
||||
pytest.fail(
|
||||
f"This should have failed!. Org Admin creating a team in an org where they are not an admin"
|
||||
)
|
||||
except Exception as e:
|
||||
print("got exception", e)
|
||||
print("exception.message", e.message)
|
||||
assert (
|
||||
"You do not have the required role to call" in e.message
|
||||
and org2_id in e.message
|
||||
)
|
|
@ -33,7 +33,7 @@ async def new_organization(session, i, organization_alias, max_budget=None):
|
|||
@pytest.mark.asyncio
|
||||
async def test_organization_new():
|
||||
"""
|
||||
Make 20 parallel calls to /user/new. Assert all worked.
|
||||
Make 20 parallel calls to /organization/new. Assert all worked.
|
||||
"""
|
||||
organization_alias = f"Organization: {uuid.uuid4()}"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue