From 1fd437e263ae7e189de8424c16f326f972923d98 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 9 Oct 2024 15:18:18 +0530 Subject: [PATCH] (feat proxy) [beta] add support for organization role based access controls (#6112) * track LiteLLM_OrganizationMembership * add add_internal_user_to_organization * add org membership to schema * read organization membership when reading user info in auth checks * add check for valid organization_id * add test for test_create_new_user_in_organization * test test_create_new_user_in_organization * add new ADMIN role * add test for org admins creating teams * add test for test_org_admin_create_user_permissions * test_org_admin_create_user_team_wrong_org_permissions * test_org_admin_create_user_team_wrong_org_permissions * fix organization_role_based_access_check * fix getting user members * fix TeamBase * fix types used for use role * fix type checks * sync prisma schema * docs - organization admins * fix use organization_endpoints for /organization management * add types for org member endpoints * fix role name for org admin * add type for member add response * add organization/member_add * add error handling for adding members to an org * add nice doc string for oranization/member_add * fix test_create_new_user_in_organization * linting fix * use simple route changes * fix types * add organization member roles * add org admin auth checks * add auth checks for orgs * test for creating teams as org admin * simplify org id usage * fix typo * test test_org_admin_create_user_team_wrong_org_permissions * fix type check issue * code quality fix * fix schema.prisma --- docs/my-website/docs/proxy/access_control.md | 145 ++++++ docs/my-website/sidebars.js | 7 +- litellm/proxy/_types.py | 189 +++++--- litellm/proxy/auth/auth_checks.py | 31 +- .../proxy/auth/auth_checks_organization.py | 161 +++++++ litellm/proxy/auth/route_checks.py | 10 +- litellm/proxy/auth/user_api_key_auth.py | 4 +- .../internal_user_endpoints.py | 60 ++- .../organization_endpoints.py | 433 +++++++++++++++++ litellm/proxy/proxy_server.py | 206 +------- litellm/proxy/schema.prisma | 24 +- schema.prisma | 24 +- .../test_role_based_access.py | 439 ++++++++++++++++++ tests/test_organizations.py | 2 +- 14 files changed, 1474 insertions(+), 261 deletions(-) create mode 100644 docs/my-website/docs/proxy/access_control.md create mode 100644 litellm/proxy/auth/auth_checks_organization.py create mode 100644 litellm/proxy/management_endpoints/organization_endpoints.py create mode 100644 tests/proxy_admin_ui_tests/test_role_based_access.py diff --git a/docs/my-website/docs/proxy/access_control.md b/docs/my-website/docs/proxy/access_control.md new file mode 100644 index 000000000..5ffcfad5d --- /dev/null +++ b/docs/my-website/docs/proxy/access_control.md @@ -0,0 +1,145 @@ +# Role-based Access Controls (RBAC) + +Role-based access control (RBAC) is based on Organizations, Teams and Internal User Roles + +- `Organizations` are the top-level entities that contain Teams. +- `Team` - A Team is a collection of multiple `Internal Users` +- `Internal Users` - users that can create keys, make LLM API calls, view usage on LiteLLM +- `Roles` define the permissions of an `Internal User` +- `Virtual Keys` - Keys are used for authentication to the LiteLLM API. Keys are tied to a `Internal User` and `Team` + +## Roles + +**Admin Roles** + - `proxy_admin`: admin over the platform + - `proxy_admin_viewer`: can login, view all keys, view all spend. **Cannot** create/delete keys, add new users. + +**Organization Roles** + - `organization_admin`: admin over the organization. Can create teams and users within their organization + +**Internal User Roles** + - `internal_user`: can login, view/create/delete their own keys, view their spend. **Cannot** add new users. + - `internal_user_viewer`: can login, view their own keys, view their own spend. **Cannot** create/delete keys, add new users. + + +## Managing Organizations + +### 1. Creating a new Organization + +Any user with role=`proxy_admin` can create a new organization + +**Usage** + +[**API Reference for /organization/new**](https://litellm-api.up.railway.app/#/organization%20management/new_organization_organization_new_post) + +```shell +curl --location 'http://0.0.0.0:4000/organization/new' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "organization_alias": "marketing_department", + "models": ["gpt-4"], + "max_budget": 20 + }' +``` + +Expected Response + +```json +{ + "organization_id": "ad15e8ca-12ae-46f4-8659-d02debef1b23", + "organization_alias": "marketing_department", + "budget_id": "98754244-3a9c-4b31-b2e9-c63edc8fd7eb", + "metadata": {}, + "models": [ + "gpt-4" + ], + "created_by": "109010464461339474872", + "updated_by": "109010464461339474872", + "created_at": "2024-10-08T18:30:24.637000Z", + "updated_at": "2024-10-08T18:30:24.637000Z" +} +``` + + +### 2. Adding an `organization_admin` to an Organization + +Create a user (ishaan@berri.ai) as an `organization_admin` for the `marketing_department` Organization (from [step 1](#1-creating-a-new-organization)) + +Users with the following roles can call `/organization/member_add` +- `proxy_admin` +- `organization_admin` only within their own organization + +```shell +curl -X POST 'http://0.0.0.0:4000/organization/member_add' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{"organization_id": "ad15e8ca-12ae-46f4-8659-d02debef1b23", "member": {"role": "organization_admin", "user_id": "ishaan@berri.ai"}}' +``` + +Now a user with user_id = `ishaan@berri.ai` and role = `organization_admin` has been created in the `marketing_department` Organization + +Create a Virtual Key for user_id = `ishaan@berri.ai`. The User can then use the Virtual key for their Organization Admin Operations + +```shell +curl --location 'http://0.0.0.0:4000/key/generate' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "user_id": "ishaan@berri.ai" + }' +``` + +Expected Response + +```json +{ + "models": [], + "user_id": "ishaan@berri.ai", + "key": "sk-7shH8TGMAofR4zQpAAo6kQ", + "key_name": "sk-...o6kQ", +} +``` + +### 3. `Organization Admin` - Create a Team + +The organization admin will use the virtual key created in [step 2](#2-adding-an-organization_admin-to-an-organization) to create a `Team` within the `marketing_department` Organization + +```shell +curl --location 'http://0.0.0.0:4000/team/new' \ + --header 'Authorization: Bearer sk-7shH8TGMAofR4zQpAAo6kQ' \ + --header 'Content-Type: application/json' \ + --data '{ + "team_alias": "engineering_team", + "organization_id": "ad15e8ca-12ae-46f4-8659-d02debef1b23", + }' +``` + +This will create the team `engineering_team` within the `marketing_department` Organization + +Expected Response + +```json +{ + "team_alias": "engineering_team", + "team_id": "01044ee8-441b-45f4-be7d-c70e002722d8", + "organization_id": "ad15e8ca-12ae-46f4-8659-d02debef1b23", +} +``` + + +### `Organization Admin` - Add an `Internal User` + +The organization admin will use the virtual key created in [step 2](#2-adding-an-organization_admin-to-an-organization) to add an Internal User to the `engineering_team` Team. + +- We will assign role=`internal_user` so the user can create Virtual Keys for themselves +- `team_id` is from [step 3](#3-organization-admin---create-a-team) + +```shell +curl -X POST 'http://0.0.0.0:4000/team/member_add' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{"team_id": "01044ee8-441b-45f4-be7d-c70e002722d8",, "member": {"role": "internal_user", "user_id": "krrish@berri.ai"}}' + +``` + diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 6c58c5002..12967b573 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -78,7 +78,12 @@ const sidebars = { { type: "category", label: "Admin UI", - items: ["proxy/ui", "proxy/self_serve", "proxy/custom_sso"], + items: [ + "proxy/ui", + "proxy/self_serve", + "proxy/access_control", + "proxy/custom_sso" + ], }, { type: "category", diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 4d9188f9e..efc3542b1 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -34,6 +34,7 @@ class LitellmUserRoles(str, enum.Enum): Admin Roles: PROXY_ADMIN: admin over the platform PROXY_ADMIN_VIEW_ONLY: can login, view all own keys, view all spend + ORG_ADMIN: admin over a specific organization, can create teams, users only within their organization Internal User Roles: INTERNAL_USER: can login, view/create/delete their own keys, view their spend @@ -53,6 +54,9 @@ class LitellmUserRoles(str, enum.Enum): PROXY_ADMIN = "proxy_admin" PROXY_ADMIN_VIEW_ONLY = "proxy_admin_viewer" + # Organization admins + ORG_ADMIN = "org_admin" + # Internal User Roles INTERNAL_USER = "internal_user" INTERNAL_USER_VIEW_ONLY = "internal_user_viewer" @@ -359,6 +363,20 @@ class LiteLLMRoutes(enum.Enum): "/team/member_delete", ] # routes that manage their own allowed/disallowed logic + ## Org Admin Routes ## + + # Routes only an Org Admin Can Access + org_admin_only_routes = [ + "/organization/info", + "/organization/delete", + "/organization/member_add", + ] + + # All routes accesible by an Org Admin + org_admin_allowed_routes = ( + org_admin_only_routes + management_routes + self_managed_routes + ) + # class LiteLLMAllowedRoutes(LiteLLMBase): # """ @@ -695,12 +713,9 @@ class NewUserRequest(_GenerateKeyRequest): LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, LitellmUserRoles.INTERNAL_USER, LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, - LitellmUserRoles.TEAM, - LitellmUserRoles.CUSTOMER, ] ] = None teams: Optional[list] = None - organization_id: Optional[str] = None auto_create_key: bool = ( True # flag used for returning a key as part of the /user/new response ) @@ -716,12 +731,9 @@ class NewUserResponse(GenerateKeyResponse): LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, LitellmUserRoles.INTERNAL_USER, LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, - LitellmUserRoles.TEAM, - LitellmUserRoles.CUSTOMER, ] ] = None teams: Optional[list] = None - organization_id: Optional[str] = None user_alias: Optional[str] = None @@ -739,8 +751,6 @@ class UpdateUserRequest(GenerateRequestBase): LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, LitellmUserRoles.INTERNAL_USER, LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, - LitellmUserRoles.TEAM, - LitellmUserRoles.CUSTOMER, ] ] = None max_budget: Optional[float] = None @@ -811,7 +821,14 @@ class DeleteCustomerRequest(LiteLLMBase): class Member(LiteLLMBase): - role: Literal["admin", "user"] + role: Literal[ + LitellmUserRoles.ORG_ADMIN, + LitellmUserRoles.INTERNAL_USER, + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, + # older Member roles + "admin", + "user", + ] user_id: Optional[str] = None user_email: Optional[str] = None @@ -857,51 +874,6 @@ class GlobalEndUsersSpend(LiteLLMBase): endTime: Optional[datetime] = None -class TeamMemberAddRequest(LiteLLMBase): - team_id: str - member: Union[List[Member], Member] - max_budget_in_team: Optional[float] = None # Users max budget within the team - - def __init__(self, **data): - member_data = data.get("member") - if isinstance(member_data, list): - # If member is a list of dictionaries, convert each dictionary to a Member object - members = [Member(**item) for item in member_data] - # Replace member_data with the list of Member objects - data["member"] = members - elif isinstance(member_data, dict): - # If member is a dictionary, convert it to a single Member object - member = Member(**member_data) - # Replace member_data with the single Member object - data["member"] = member - # Call the superclass __init__ method to initialize the object - super().__init__(**data) - - -class TeamMemberDeleteRequest(LiteLLMBase): - team_id: str - user_id: Optional[str] = None - user_email: Optional[str] = None - - @model_validator(mode="before") - @classmethod - def check_user_info(cls, values): - if values.get("user_id") is None and values.get("user_email") is None: - raise ValueError("Either user id or user email must be provided") - return values - - -class TeamMemberUpdateRequest(TeamMemberDeleteRequest): - max_budget_in_team: float - - -class TeamMemberUpdateResponse(LiteLLMBase): - team_id: str - user_id: str - user_email: Optional[str] = None - max_budget_in_team: float - - class UpdateTeamRequest(LiteLLMBase): """ UpdateTeamRequest, used by /team/update when you need to update a team @@ -1444,6 +1416,26 @@ class LiteLLM_Config(LiteLLMBase): param_value: Dict +class LiteLLM_OrganizationMembershipTable(LiteLLMBase): + """ + This is the table that track what organizations a user belongs to and users spend within the organization + """ + + user_id: str + organization_id: str + user_role: Optional[str] = None + spend: float = 0.0 + budget_id: Optional[str] = None + created_at: datetime + updated_at: datetime + user: Optional[Any] = ( + None # You might want to replace 'Any' with a more specific type if available + ) + litellm_budget_table: Optional[LiteLLM_BudgetTable] = None + + model_config = ConfigDict(protected_namespaces=()) + + class LiteLLM_UserTable(LiteLLMBase): user_id: str max_budget: Optional[float] @@ -1455,6 +1447,7 @@ class LiteLLM_UserTable(LiteLLMBase): tpm_limit: Optional[int] = None rpm_limit: Optional[int] = None user_role: Optional[str] = None + organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None @model_validator(mode="before") @classmethod @@ -1907,11 +1900,99 @@ class LiteLLM_TeamMembership(LiteLLMBase): litellm_budget_table: Optional[LiteLLM_BudgetTable] +#### Organization / Team Member Requests #### + + +class MemberAddRequest(LiteLLMBase): + member: Union[List[Member], Member] + + def __init__(self, **data): + member_data = data.get("member") + if isinstance(member_data, list): + # If member is a list of dictionaries, convert each dictionary to a Member object + members = [Member(**item) for item in member_data] + # Replace member_data with the list of Member objects + data["member"] = members + elif isinstance(member_data, dict): + # If member is a dictionary, convert it to a single Member object + member = Member(**member_data) + # Replace member_data with the single Member object + data["member"] = member + # Call the superclass __init__ method to initialize the object + super().__init__(**data) + + class TeamAddMemberResponse(LiteLLM_TeamTable): updated_users: List[LiteLLM_UserTable] updated_team_memberships: List[LiteLLM_TeamMembership] +class OrganizationAddMemberResponse(LiteLLMBase): + organization_id: str + updated_users: List[LiteLLM_UserTable] + updated_organization_memberships: List[LiteLLM_OrganizationMembershipTable] + + +class MemberDeleteRequest(LiteLLMBase): + user_id: Optional[str] = None + user_email: Optional[str] = None + + @model_validator(mode="before") + @classmethod + def check_user_info(cls, values): + if values.get("user_id") is None and values.get("user_email") is None: + raise ValueError("Either user id or user email must be provided") + return values + + +class MemberUpdateResponse(LiteLLMBase): + user_id: str + user_email: Optional[str] = None + + +# Team Member Requests +class TeamMemberAddRequest(MemberAddRequest): + team_id: str + max_budget_in_team: Optional[float] = None # Users max budget within the team + + +class TeamMemberDeleteRequest(MemberDeleteRequest): + team_id: str + + +class TeamMemberUpdateRequest(TeamMemberDeleteRequest): + max_budget_in_team: float + + +class TeamMemberUpdateResponse(MemberUpdateResponse): + team_id: str + max_budget_in_team: float + + +# Organization Member Requests +class OrganizationMemberAddRequest(MemberAddRequest): + organization_id: str + max_budget_in_organization: Optional[float] = ( + None # Users max budget within the organization + ) + + +class OrganizationMemberDeleteRequest(MemberDeleteRequest): + organization_id: str + + +class OrganizationMemberUpdateRequest(OrganizationMemberDeleteRequest): + max_budget_in_organization: float + + +class OrganizationMemberUpdateResponse(MemberUpdateResponse): + organization_id: str + max_budget_in_organization: float + + +########################################## + + class TeamInfoResponseObject(TypedDict): team_id: str team_info: LiteLLM_TeamTable diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 7da1caa5d..49f2953c1 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -32,6 +32,8 @@ from litellm.proxy.auth.route_checks import is_llm_api_route from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry from litellm.types.services import ServiceLoggerPayload, ServiceTypes +from .auth_checks_organization import organization_role_based_access_check + if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -63,6 +65,7 @@ def common_checks( 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget 8. [OPTIONAL] If guardrails modified - is request allowed to change this 9. Check if request body is safe + 10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks """ _model = request_body.get("model", None) if team_object is not None and team_object.blocked is True: @@ -73,6 +76,7 @@ def common_checks( if ( _model is not None and team_object is not None + and team_object.models is not None and len(team_object.models) > 0 and _model not in team_object.models ): @@ -202,6 +206,12 @@ def common_checks( "error": "Your team does not have permission to modify guardrails." }, ) + + # 10 [OPTIONAL] Organization RBAC checks + organization_role_based_access_check( + user_object=user_object, route=route, request_body=request_body + ) + return True @@ -403,17 +413,30 @@ async def get_user_object( try: response = await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_id} + where={"user_id": user_id}, include={"organization_memberships": True} ) if response is None: if user_id_upsert: response = await prisma_client.db.litellm_usertable.create( - data={"user_id": user_id} + data={"user_id": user_id}, + include={"organization_memberships": True}, ) else: raise Exception + if ( + response.organization_memberships is not None + and len(response.organization_memberships) > 0 + ): + # dump each organization membership to type LiteLLM_OrganizationMembershipTable + _dumped_memberships = [ + membership.model_dump() + for membership in response.organization_memberships + if membership is not None + ] + response.organization_memberships = _dumped_memberships + _response = LiteLLM_UserTable(**dict(response)) response_dict = _response.model_dump() @@ -421,9 +444,9 @@ async def get_user_object( await user_api_key_cache.async_set_cache(key=user_id, value=response_dict) return _response - except Exception: # if user not in db + except Exception as e: # if user not in db raise ValueError( - f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call." + f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call. Got error - {e}" ) diff --git a/litellm/proxy/auth/auth_checks_organization.py b/litellm/proxy/auth/auth_checks_organization.py new file mode 100644 index 000000000..3da3d8ddd --- /dev/null +++ b/litellm/proxy/auth/auth_checks_organization.py @@ -0,0 +1,161 @@ +""" +Auth Checks for Organizations +""" + +from typing import Dict, List, Optional, Tuple + +from fastapi import status + +from litellm.proxy._types import * + + +def organization_role_based_access_check( + request_body: dict, + user_object: Optional[LiteLLM_UserTable], + route: str, +): + """ + Role based access control checks only run if a user is part of an Organization + + Organization Checks: + ONLY RUN IF user_object.organization_memberships is not None + + 1. Only Proxy Admins can access /organization/new + 2. IF route is a LiteLLMRoutes.org_admin_only_routes, then check if user is an Org Admin for that organization + + """ + + if user_object is None: + return + + passed_organization_id: Optional[str] = request_body.get("organization_id", None) + + if route == "/organization/new": + if user_object.user_role != LitellmUserRoles.PROXY_ADMIN.value: + raise ProxyException( + message=f"Only proxy admins can create new organizations. You are {user_object.user_role}", + type=ProxyErrorTypes.auth_error.value, + param="user_role", + code=status.HTTP_401_UNAUTHORIZED, + ) + + if user_object.user_role == LitellmUserRoles.PROXY_ADMIN.value: + return + + # Checks if route is an Org Admin Only Route + if route in LiteLLMRoutes.org_admin_only_routes.value: + _user_organizations, _user_organization_role_mapping = ( + get_user_organization_info(user_object) + ) + + if user_object.organization_memberships is None: + raise ProxyException( + message=f"Tried to access route={route} but you are not a member of any organization. Please contact the proxy admin to request access.", + type=ProxyErrorTypes.auth_error.value, + param="organization_id", + code=status.HTTP_401_UNAUTHORIZED, + ) + + if passed_organization_id is None: + raise ProxyException( + message="Passed organization_id is None, please pass an organization_id in your request", + type=ProxyErrorTypes.auth_error.value, + param="organization_id", + code=status.HTTP_401_UNAUTHORIZED, + ) + + user_role: Optional[LitellmUserRoles] = _user_organization_role_mapping.get( + passed_organization_id + ) + if user_role is None: + raise ProxyException( + message=f"You do not have a role within the selected organization. Passed organization_id: {passed_organization_id}. Please contact the organization admin to request access.", + type=ProxyErrorTypes.auth_error.value, + param="organization_id", + code=status.HTTP_401_UNAUTHORIZED, + ) + + if user_role != LitellmUserRoles.ORG_ADMIN.value: + raise ProxyException( + message=f"You do not have the required role to perform {route} in Organization {passed_organization_id}. Your role is {user_role} in Organization {passed_organization_id}", + type=ProxyErrorTypes.auth_error.value, + param="user_role", + code=status.HTTP_401_UNAUTHORIZED, + ) + elif route == "/team/new": + # if user is part of multiple teams, then they need to specify the organization_id + _user_organizations, _user_organization_role_mapping = ( + get_user_organization_info(user_object) + ) + if ( + user_object.organization_memberships is not None + and len(user_object.organization_memberships) > 0 + ): + if passed_organization_id is None: + raise ProxyException( + message=f"Passed organization_id is None, please specify the organization_id in your request. You are part of multiple organizations: {_user_organizations}", + type=ProxyErrorTypes.auth_error.value, + param="organization_id", + code=status.HTTP_401_UNAUTHORIZED, + ) + + _user_role_in_passed_org = _user_organization_role_mapping.get( + passed_organization_id + ) + if _user_role_in_passed_org != LitellmUserRoles.ORG_ADMIN.value: + raise ProxyException( + message=f"You do not have the required role to call {route}. Your role is {_user_role_in_passed_org} in Organization {passed_organization_id}", + type=ProxyErrorTypes.auth_error.value, + param="user_role", + code=status.HTTP_401_UNAUTHORIZED, + ) + + +def get_user_organization_info( + user_object: LiteLLM_UserTable, +) -> Tuple[List[str], Dict[str, Optional[LitellmUserRoles]]]: + """ + Helper function to extract user organization information. + + Args: + user_object (LiteLLM_UserTable): The user object containing organization memberships. + + Returns: + Tuple[List[str], Dict[str, Optional[LitellmUserRoles]]]: A tuple containing: + - List of organization IDs the user is a member of + - Dictionary mapping organization IDs to user roles + """ + _user_organizations: List[str] = [] + _user_organization_role_mapping: Dict[str, Optional[LitellmUserRoles]] = {} + + if user_object.organization_memberships is not None: + for _membership in user_object.organization_memberships: + if _membership.organization_id is not None: + _user_organizations.append(_membership.organization_id) + _user_organization_role_mapping[_membership.organization_id] = _membership.user_role # type: ignore + + return _user_organizations, _user_organization_role_mapping + + +def _user_is_org_admin( + request_data: dict, + user_object: Optional[LiteLLM_UserTable] = None, +) -> bool: + """ + Helper function to check if user is an org admin for the passed organization_id + """ + if request_data.get("organization_id", None) is None: + return False + + if user_object is None: + return False + + if user_object.organization_memberships is None: + return False + + for _membership in user_object.organization_memberships: + if _membership.organization_id == request_data.get("organization_id", None): + if _membership.user_role == LitellmUserRoles.ORG_ADMIN.value: + return True + + return False diff --git a/litellm/proxy/auth/route_checks.py b/litellm/proxy/auth/route_checks.py index 5a370d8c8..cc8fd3113 100644 --- a/litellm/proxy/auth/route_checks.py +++ b/litellm/proxy/auth/route_checks.py @@ -13,10 +13,11 @@ from litellm.proxy._types import ( ) from litellm.proxy.utils import hash_token +from .auth_checks_organization import _user_is_org_admin from .auth_utils import _has_user_setup_sso -def non_admin_allowed_routes_check( +def non_proxy_admin_allowed_routes_check( user_obj: Optional[LiteLLM_UserTable], _user_role: Optional[LitellmUserRoles], route: str, @@ -26,7 +27,7 @@ def non_admin_allowed_routes_check( request_data: dict, ): """ - Checks if Non-Admin User is allowed to access the route + Checks if Non Proxy Admin User is allowed to access the route """ # Check user has defined custom admin routes @@ -106,6 +107,11 @@ def non_admin_allowed_routes_check( and route in LiteLLMRoutes.internal_user_routes.value ): pass + elif ( + _user_is_org_admin(request_data=request_data, user_object=user_obj) + and route in LiteLLMRoutes.org_admin_allowed_routes.value + ): + pass elif ( _user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value and route in LiteLLMRoutes.internal_user_view_only_routes.value diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 7985ed389..6eecb5980 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -69,7 +69,7 @@ from litellm.proxy.auth.auth_utils import ( ) from litellm.proxy.auth.oauth2_check import check_oauth2_token from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request -from litellm.proxy.auth.route_checks import non_admin_allowed_routes_check +from litellm.proxy.auth.route_checks import non_proxy_admin_allowed_routes_check from litellm.proxy.auth.service_account_checks import service_account_checks from litellm.proxy.common_utils.http_parsing_utils import _read_request_body from litellm.proxy.utils import _to_ns @@ -1042,7 +1042,7 @@ async def user_api_key_auth( _user_role = _get_user_role(user_obj=user_obj) if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin - non_admin_allowed_routes_check( + non_proxy_admin_allowed_routes_check( user_obj=user_obj, _user_role=_user_role, route=route, diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 85bc7493b..bea27ece1 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -63,7 +63,6 @@ async def new_user( - user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated. - user_alias: Optional[str] - A descriptive name for you to know who this user id refers to. - teams: Optional[list] - specify a list of team id's a user belongs to. - - organization_id: Optional[str] - specify the org a user belongs to. - user_email: Optional[str] - Specify a user email. - send_invite_email: Optional[bool] - Specify if an invite email should be sent. - user_role: Optional[str] - Specify a user role - "proxy_admin", "proxy_admin_viewer", "internal_user", "internal_user_viewer", "team", "customer". Info about each role here: `https://github.com/BerriAI/litellm/litellm/proxy/_types.py#L20` @@ -79,6 +78,18 @@ async def new_user( - expires: (datetime) Datetime object for when key expires. - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. - max_budget: (float|None) Max budget for given user. + + Usage Example + + ```shell + curl -X POST "http://localhost:4000/user/new" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{ + "username": "new_user", + "email": "new_user@example.com" + }' + ``` """ from litellm.proxy.proxy_server import general_settings, proxy_logging_obj @@ -106,6 +117,7 @@ async def new_user( response = await generate_key_helper_fn(request_type="user", **data_json) # Admin UI Logic + # Add User to Team and Organization # if team_id passed add this user to the team if data_json.get("team_id", None) is not None: from litellm.proxy.management_endpoints.team_endpoints import team_member_add @@ -888,3 +900,49 @@ async def delete_user( ) return deleted_users + + +async def add_internal_user_to_organization( + user_id: str, + organization_id: str, + user_role: LitellmUserRoles, +): + """ + Helper function to add an internal user to an organization + + Adds the user to LiteLLM_OrganizationMembership table + + - Checks if organization_id exists + + Raises: + - Exception if database not connected + - Exception if user_id or organization_id not found + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise Exception("Database not connected") + + try: + # Check if organization_id exists + organization_row = await prisma_client.db.litellm_organizationtable.find_unique( + where={"organization_id": organization_id} + ) + if organization_row is None: + raise Exception( + f"Organization not found, passed organization_id={organization_id}" + ) + + # Create a new organization membership entry + new_membership = await prisma_client.db.litellm_organizationmembership.create( + data={ + "user_id": user_id, + "organization_id": organization_id, + "user_role": user_role, + # Note: You can also set budget within an organization if needed + } + ) + + return new_membership + except Exception as e: + raise Exception(f"Failed to add user to organization: {str(e)}") diff --git a/litellm/proxy/management_endpoints/organization_endpoints.py b/litellm/proxy/management_endpoints/organization_endpoints.py new file mode 100644 index 000000000..f448d2fad --- /dev/null +++ b/litellm/proxy/management_endpoints/organization_endpoints.py @@ -0,0 +1,433 @@ +""" +Endpoints for /organization operations + +/organization/new +/organization/update +/organization/delete +/organization/info +""" + +#### ORGANIZATION MANAGEMENT #### + +import asyncio +import copy +import json +import re +import secrets +import traceback +import uuid +from datetime import datetime, timedelta, timezone +from typing import List, Optional, Tuple + +import fastapi +from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import * +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.management_helpers.utils import ( + get_new_internal_user_defaults, + management_endpoint_wrapper, +) +from litellm.proxy.utils import PrismaClient +from litellm.secret_managers.main import get_secret + +router = APIRouter() + + +@router.post( + "/organization/new", + tags=["organization management"], + dependencies=[Depends(user_api_key_auth)], + response_model=NewOrganizationResponse, +) +async def new_organization( + data: NewOrganizationRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Allow orgs to own teams + + Set org level budgets + model access. + + Only admins can create orgs. + + # Parameters + + - `organization_alias`: *str* = The name of the organization. + - `models`: *List* = The models the organization has access to. + - `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization. + ### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ### + - `max_budget`: *Optional[float]* = Max budget for org + - `tpm_limit`: *Optional[int]* = Max tpm limit for org + - `rpm_limit`: *Optional[int]* = Max rpm limit for org + - `model_max_budget`: *Optional[dict]* = Max budget for a specific model + - `budget_duration`: *Optional[str]* = Frequency of reseting org budget + + Case 1: Create new org **without** a budget_id + + ```bash + curl --location 'http://0.0.0.0:4000/organization/new' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data '{ + "organization_alias": "my-secret-org", + "models": ["model1", "model2"], + "max_budget": 100 + }' + + + ``` + + Case 2: Create new org **with** a budget_id + + ```bash + curl --location 'http://0.0.0.0:4000/organization/new' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data '{ + "organization_alias": "my-secret-org", + "models": ["model1", "model2"], + "budget_id": "428eeaa8-f3ac-4e85-a8fb-7dc8d7aa8689" + }' + ``` + """ + from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if ( + user_api_key_dict.user_role is None + or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN + ): + raise HTTPException( + status_code=401, + detail={ + "error": f"Only admins can create orgs. Your role is = {user_api_key_dict.user_role}" + }, + ) + + if data.budget_id is None: + """ + Every organization needs a budget attached. + + If none provided, create one based on provided values + """ + budget_params = LiteLLM_BudgetTable.model_fields.keys() + + # Only include Budget Params when creating an entry in litellm_budgettable + _json_data = data.json(exclude_none=True) + _budget_data = {k: v for k, v in _json_data.items() if k in budget_params} + budget_row = LiteLLM_BudgetTable(**_budget_data) + + new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True)) + + _budget = await prisma_client.db.litellm_budgettable.create( + data={ + **new_budget, # type: ignore + "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, + } + ) # type: ignore + + data.budget_id = _budget.budget_id + + """ + Ensure only models that user has access to, are given to org + """ + if len(user_api_key_dict.models) == 0: # user has access to all models + pass + else: + if len(data.models) == 0: + raise HTTPException( + status_code=400, + detail={ + "error": "User not allowed to give access to all models. Select models you want org to have access to." + }, + ) + for m in data.models: + if m not in user_api_key_dict.models: + raise HTTPException( + status_code=400, + detail={ + "error": f"User not allowed to give access to model={m}. Models you have access to = {user_api_key_dict.models}" + }, + ) + organization_row = LiteLLM_OrganizationTable( + **data.json(exclude_none=True), + created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + ) + new_organization_row = prisma_client.jsonify_object( + organization_row.json(exclude_none=True) + ) + response = await prisma_client.db.litellm_organizationtable.create( + data={ + **new_organization_row, # type: ignore + } + ) + + return response + + +@router.post( + "/organization/update", + tags=["organization management"], + dependencies=[Depends(user_api_key_auth)], +) +async def update_organization(): + """[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues""" + pass + + +@router.post( + "/organization/delete", + tags=["organization management"], + dependencies=[Depends(user_api_key_auth)], +) +async def delete_organization(): + """[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues""" + pass + + +@router.post( + "/organization/info", + tags=["organization management"], + dependencies=[Depends(user_api_key_auth)], +) +async def info_organization(data: OrganizationRequest): + """ + Get the org specific information + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if len(data.organizations) == 0: + raise HTTPException( + status_code=400, + detail={ + "error": f"Specify list of organization id's to query. Passed in={data.organizations}" + }, + ) + response = await prisma_client.db.litellm_organizationtable.find_many( + where={"organization_id": {"in": data.organizations}}, + include={"litellm_budget_table": True}, + ) + + return response + + +@router.post( + "/organization/member_add", + tags=["organization management"], + dependencies=[Depends(user_api_key_auth)], + response_model=OrganizationAddMemberResponse, +) +@management_endpoint_wrapper +async def organization_member_add( + data: OrganizationMemberAddRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +) -> OrganizationAddMemberResponse: + """ + [BETA] + + Add new members (either via user_email or user_id) to an organization + + If user doesn't exist, new user row will also be added to User Table + + Only proxy_admin or org_admin of organization, allowed to access this endpoint. + + # Parameters: + + - organization_id: str (required) + - member: Union[List[Member], Member] (required) + - role: Literal[LitellmUserRoles] (required) + - user_id: Optional[str] + - user_email: Optional[str] + + Note: Either user_id or user_email must be provided for each member. + + Example: + ``` + curl -X POST 'http://0.0.0.0:4000/organization/member_add' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{ + "organization_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849", + "member": { + "role": "internal_user", + "user_id": "krrish247652@berri.ai" + }, + "max_budget_in_organization": 100.0 + }' + ``` + + The following is executed in this function: + + 1. Check if organization exists + 2. Creates a new Internal User if the user_id or user_email is not found in LiteLLM_UserTable + 3. Add Internal User to the `LiteLLM_OrganizationMembership` table + """ + try: + from litellm.proxy.proxy_server import ( + litellm_proxy_admin_name, + prisma_client, + proxy_logging_obj, + user_api_key_cache, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + # Check if organization exists + existing_organization_row = ( + await prisma_client.db.litellm_organizationtable.find_unique( + where={"organization_id": data.organization_id} + ) + ) + if existing_organization_row is None: + raise HTTPException( + status_code=404, + detail={ + "error": f"Organization not found for organization_id={getattr(data, 'organization_id', None)}" + }, + ) + + members: List[Member] + if isinstance(data.member, List): + members = data.member + else: + members = [data.member] + + updated_users: List[LiteLLM_UserTable] = [] + updated_organization_memberships: List[LiteLLM_OrganizationMembershipTable] = [] + + for member in members: + updated_user, updated_organization_membership = ( + await add_member_to_organization( + member=member, + organization_id=data.organization_id, + prisma_client=prisma_client, + ) + ) + + updated_users.append(updated_user) + updated_organization_memberships.append(updated_organization_membership) + + return OrganizationAddMemberResponse( + organization_id=data.organization_id, + updated_users=updated_users, + updated_organization_memberships=updated_organization_memberships, + ) + except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type=ProxyErrorTypes.auth_error, + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type=ProxyErrorTypes.auth_error, + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +async def add_member_to_organization( + member: Member, + organization_id: str, + prisma_client: PrismaClient, +) -> Tuple[LiteLLM_UserTable, LiteLLM_OrganizationMembershipTable]: + """ + Add a member to an organization + + - Checks if member.user_id or member.user_email is in LiteLLM_UserTable + - If not found, create a new user in LiteLLM_UserTable + - Add user to organization in LiteLLM_OrganizationMembership + """ + + try: + user_object: Optional[LiteLLM_UserTable] = None + existing_user_id_row = None + existing_user_email_row = None + ## Check if user exists in LiteLLM_UserTable - user exists - either the user_id or user_email is in LiteLLM_UserTable + if member.user_id is not None: + existing_user_id_row = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": member.user_id} + ) + + if member.user_email is not None: + existing_user_email_row = ( + await prisma_client.db.litellm_usertable.find_unique( + where={"user_email": member.user_email} + ) + ) + + ## If user does not exist, create a new user + if existing_user_id_row is None and existing_user_email_row is None: + # Create a new user - since user does not exist + user_id: str = member.user_id or str(uuid.uuid4()) + new_user_defaults = get_new_internal_user_defaults( + user_id=user_id, + user_email=member.user_email, + ) + + _returned_user = await prisma_client.insert_data(data=new_user_defaults, table_name="user") # type: ignore + if _returned_user is not None: + user_object = LiteLLM_UserTable(**_returned_user.model_dump()) + elif existing_user_email_row is not None and len(existing_user_email_row) > 1: + raise HTTPException( + status_code=400, + detail={ + "error": "Multiple users with this email found in db. Please use 'user_id' instead." + }, + ) + elif existing_user_email_row is not None: + user_object = LiteLLM_UserTable(**existing_user_email_row.model_dump()) + elif existing_user_id_row is not None: + user_object = LiteLLM_UserTable(**existing_user_id_row.model_dump()) + else: + raise HTTPException( + status_code=404, + detail={ + "error": f"User not found for user_id={member.user_id} and user_email={member.user_email}" + }, + ) + + if user_object is None: + raise ValueError( + f"User does not exist in LiteLLM_UserTable. user_id={member.user_id} and user_email={member.user_email}" + ) + + # Add user to organization + _organization_membership = ( + await prisma_client.db.litellm_organizationmembership.create( + data={ + "organization_id": organization_id, + "user_id": user_object.user_id, + "user_role": member.role, + } + ) + ) + organization_membership = LiteLLM_OrganizationMembershipTable( + **_organization_membership.model_dump() + ) + return user_object, organization_membership + + except Exception as e: + raise ValueError(f"Error adding member to organization: {e}") diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 7d71de385..1383e6794 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -185,6 +185,9 @@ from litellm.proxy.management_endpoints.key_management_endpoints import ( from litellm.proxy.management_endpoints.key_management_endpoints import ( router as key_management_router, ) +from litellm.proxy.management_endpoints.organization_endpoints import ( + router as organization_router, +) from litellm.proxy.management_endpoints.team_callback_endpoints import ( router as team_callback_router, ) @@ -6313,200 +6316,6 @@ async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs): return -#### ORGANIZATION MANAGEMENT #### - - -@router.post( - "/organization/new", - tags=["organization management"], - dependencies=[Depends(user_api_key_auth)], - response_model=NewOrganizationResponse, -) -async def new_organization( - data: NewOrganizationRequest, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Allow orgs to own teams - - Set org level budgets + model access. - - Only admins can create orgs. - - # Parameters - - - `organization_alias`: *str* = The name of the organization. - - `models`: *List* = The models the organization has access to. - - `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization. - ### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ### - - `max_budget`: *Optional[float]* = Max budget for org - - `tpm_limit`: *Optional[int]* = Max tpm limit for org - - `rpm_limit`: *Optional[int]* = Max rpm limit for org - - `model_max_budget`: *Optional[dict]* = Max budget for a specific model - - `budget_duration`: *Optional[str]* = Frequency of reseting org budget - - Case 1: Create new org **without** a budget_id - - ```bash - curl --location 'http://0.0.0.0:4000/organization/new' \ - - --header 'Authorization: Bearer sk-1234' \ - - --header 'Content-Type: application/json' \ - - --data '{ - "organization_alias": "my-secret-org", - "models": ["model1", "model2"], - "max_budget": 100 - }' - - - ``` - - Case 2: Create new org **with** a budget_id - - ```bash - curl --location 'http://0.0.0.0:4000/organization/new' \ - - --header 'Authorization: Bearer sk-1234' \ - - --header 'Content-Type: application/json' \ - - --data '{ - "organization_alias": "my-secret-org", - "models": ["model1", "model2"], - "budget_id": "428eeaa8-f3ac-4e85-a8fb-7dc8d7aa8689" - }' - ``` - """ - global prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - if ( - user_api_key_dict.user_role is None - or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN - ): - raise HTTPException( - status_code=401, - detail={ - "error": f"Only admins can create orgs. Your role is = {user_api_key_dict.user_role}" - }, - ) - - if data.budget_id is None: - """ - Every organization needs a budget attached. - - If none provided, create one based on provided values - """ - budget_params = LiteLLM_BudgetTable.model_fields.keys() - - # Only include Budget Params when creating an entry in litellm_budgettable - _json_data = data.json(exclude_none=True) - _budget_data = {k: v for k, v in _json_data.items() if k in budget_params} - budget_row = LiteLLM_BudgetTable(**_budget_data) - - new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True)) - - _budget = await prisma_client.db.litellm_budgettable.create( - data={ - **new_budget, # type: ignore - "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - } - ) # type: ignore - - data.budget_id = _budget.budget_id - - """ - Ensure only models that user has access to, are given to org - """ - if len(user_api_key_dict.models) == 0: # user has access to all models - pass - else: - if len(data.models) == 0: - raise HTTPException( - status_code=400, - detail={ - "error": "User not allowed to give access to all models. Select models you want org to have access to." - }, - ) - for m in data.models: - if m not in user_api_key_dict.models: - raise HTTPException( - status_code=400, - detail={ - "error": f"User not allowed to give access to model={m}. Models you have access to = {user_api_key_dict.models}" - }, - ) - organization_row = LiteLLM_OrganizationTable( - **data.json(exclude_none=True), - created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, - updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, - ) - new_organization_row = prisma_client.jsonify_object( - organization_row.json(exclude_none=True) - ) - response = await prisma_client.db.litellm_organizationtable.create( - data={ - **new_organization_row, # type: ignore - } - ) - - return response - - -@router.post( - "/organization/update", - tags=["organization management"], - dependencies=[Depends(user_api_key_auth)], -) -async def update_organization(): - """[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues""" - pass - - -@router.post( - "/organization/delete", - tags=["organization management"], - dependencies=[Depends(user_api_key_auth)], -) -async def delete_organization(): - """[TODO] Not Implemented yet. Let us know if you need this - https://github.com/BerriAI/litellm/issues""" - pass - - -@router.post( - "/organization/info", - tags=["organization management"], - dependencies=[Depends(user_api_key_auth)], -) -async def info_organization(data: OrganizationRequest): - """ - Get the org specific information - """ - global prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - if len(data.organizations) == 0: - raise HTTPException( - status_code=400, - detail={ - "error": f"Specify list of organization id's to query. Passed in={data.organizations}" - }, - ) - response = await prisma_client.db.litellm_organizationtable.find_many( - where={"organization_id": {"in": data.organizations}}, - include={"litellm_budget_table": True}, - ) - - return response - - #### BUDGET TABLE MANAGEMENT #### @@ -8181,6 +7990,14 @@ async def login(request: Request): # check if we can find the `username` in the db. on the ui, users can enter username=their email _user_row = None + user_role: Optional[ + Literal[ + LitellmUserRoles.PROXY_ADMIN, + LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, + LitellmUserRoles.INTERNAL_USER, + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, + ] + ] = None if prisma_client is not None: _user_row = await prisma_client.db.litellm_usertable.find_first( where={"user_email": {"equals": username}} @@ -9654,6 +9471,7 @@ app.include_router(key_management_router) app.include_router(internal_user_router) app.include_router(team_router) app.include_router(ui_sso_router) +app.include_router(organization_router) app.include_router(spend_management_router) app.include_router(caching_router) app.include_router(analytics_router) diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index ff2fc68c4..e5a9a0ab8 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -26,6 +26,7 @@ model LiteLLM_BudgetTable { keys LiteLLM_VerificationToken[] // multiple keys can have the same budget end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team + organization_membership LiteLLM_OrganizationMembership[] // budgets of Users within a Organization } // Models on proxy @@ -118,7 +119,10 @@ model LiteLLM_UserTable { allowed_cache_controls String[] @default([]) model_spend Json @default("{}") model_max_budget Json @default("{}") - litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) + + // relations + litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) + organization_memberships LiteLLM_OrganizationMembership[] invitations_created LiteLLM_InvitationLink[] @relation("CreatedBy") invitations_updated LiteLLM_InvitationLink[] @relation("UpdatedBy") invitations_user LiteLLM_InvitationLink[] @relation("UserId") @@ -232,6 +236,24 @@ model LiteLLM_TeamMembership { @@id([user_id, team_id]) } +model LiteLLM_OrganizationMembership { + // Use this table to track Internal User and Organization membership. Helps tracking a users role within an Organization + user_id String? + organization_id String? + user_role String? + spend Float? @default(0.0) + budget_id String? + created_at DateTime? @default(now()) @map("created_at") + updated_at DateTime? @default(now()) @updatedAt @map("updated_at") + + // relations + user LiteLLM_UserTable @relation(fields: [user_id], references: [user_id]) + litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) + + @@id([user_id, organization_id]) + @@unique([user_id, organization_id]) +} + model LiteLLM_InvitationLink { // use this table to track invite links sent by admin for people to join the proxy id String @id @default(uuid()) diff --git a/schema.prisma b/schema.prisma index ff2fc68c4..e5a9a0ab8 100644 --- a/schema.prisma +++ b/schema.prisma @@ -26,6 +26,7 @@ model LiteLLM_BudgetTable { keys LiteLLM_VerificationToken[] // multiple keys can have the same budget end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team + organization_membership LiteLLM_OrganizationMembership[] // budgets of Users within a Organization } // Models on proxy @@ -118,7 +119,10 @@ model LiteLLM_UserTable { allowed_cache_controls String[] @default([]) model_spend Json @default("{}") model_max_budget Json @default("{}") - litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) + + // relations + litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) + organization_memberships LiteLLM_OrganizationMembership[] invitations_created LiteLLM_InvitationLink[] @relation("CreatedBy") invitations_updated LiteLLM_InvitationLink[] @relation("UpdatedBy") invitations_user LiteLLM_InvitationLink[] @relation("UserId") @@ -232,6 +236,24 @@ model LiteLLM_TeamMembership { @@id([user_id, team_id]) } +model LiteLLM_OrganizationMembership { + // Use this table to track Internal User and Organization membership. Helps tracking a users role within an Organization + user_id String? + organization_id String? + user_role String? + spend Float? @default(0.0) + budget_id String? + created_at DateTime? @default(now()) @map("created_at") + updated_at DateTime? @default(now()) @updatedAt @map("updated_at") + + // relations + user LiteLLM_UserTable @relation(fields: [user_id], references: [user_id]) + litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) + + @@id([user_id, organization_id]) + @@unique([user_id, organization_id]) +} + model LiteLLM_InvitationLink { // use this table to track invite links sent by admin for people to join the proxy id String @id @default(uuid()) diff --git a/tests/proxy_admin_ui_tests/test_role_based_access.py b/tests/proxy_admin_ui_tests/test_role_based_access.py new file mode 100644 index 000000000..d851ca568 --- /dev/null +++ b/tests/proxy_admin_ui_tests/test_role_based_access.py @@ -0,0 +1,439 @@ +""" +RBAC tests +""" + +import os +import sys +import traceback +import uuid +from datetime import datetime + +from dotenv import load_dotenv +from fastapi import Request +from fastapi.routing import APIRoute + +load_dotenv() +import io +import os +import time + +# this file is to test litellm/proxy + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import asyncio +import logging + +import pytest + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy.auth.auth_checks import get_user_object +from litellm.proxy.management_endpoints.key_management_endpoints import ( + delete_key_fn, + generate_key_fn, + generate_key_helper_fn, + info_key_fn, + regenerate_key_fn, + update_key_fn, +) +from litellm.proxy.management_endpoints.internal_user_endpoints import new_user +from litellm.proxy.management_endpoints.organization_endpoints import ( + new_organization, + organization_member_add, +) + +from litellm.proxy.management_endpoints.team_endpoints import ( + new_team, + team_info, + update_team, +) +from litellm.proxy.proxy_server import ( + LitellmUserRoles, + audio_transcriptions, + chat_completion, + completion, + embeddings, + image_generation, + model_list, + moderations, + new_end_user, + user_api_key_auth, +) +from litellm.proxy.spend_tracking.spend_management_endpoints import ( + global_spend, + global_spend_logs, + global_spend_models, + global_spend_keys, + spend_key_fn, + spend_user_fn, + view_spend_logs, +) +from starlette.datastructures import URL + +from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend + +verbose_proxy_logger.setLevel(level=logging.DEBUG) + +from starlette.datastructures import URL + +from litellm.caching import DualCache +from litellm.proxy._types import * + +proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) + + +@pytest.fixture +def prisma_client(): + from litellm.proxy.proxy_cli import append_query_params + + ### add connection pool + pool timeout args + params = {"connection_limit": 100, "pool_timeout": 60} + database_url = os.getenv("DATABASE_URL") + modified_url = append_query_params(database_url, params) + os.environ["DATABASE_URL"] = modified_url + + # Assuming PrismaClient is a class that needs to be instantiated + prisma_client = PrismaClient( + database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj + ) + + # Reset litellm.proxy.proxy_server.prisma_client to None + litellm.proxy.proxy_server.litellm_proxy_budget_name = ( + f"litellm-proxy-budget-{time.time()}" + ) + litellm.proxy.proxy_server.user_custom_key_generate = None + + return prisma_client + + +""" +RBAC Tests + +1. Add a user to an organization + - test 1 - if organization_id does exist expect to create a new user and user, organization relation + +2. org admin creates team in his org → success + +3. org admin adds new internal user to his org → success + +4. org admin creates team and internal user not in his org → fail both +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "user_role", + [ + LitellmUserRoles.ORG_ADMIN, + LitellmUserRoles.INTERNAL_USER, + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, + ], +) +async def test_create_new_user_in_organization(prisma_client, user_role): + """ + + Add a member to an organization and assert the user object is created with the correct organization memberships / roles + """ + master_key = "sk-1234" + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", master_key) + + await litellm.proxy.proxy_server.prisma_client.connect() + + created_user_id = f"new-user-{uuid.uuid4()}" + + response = await new_organization( + data=NewOrganizationRequest( + organization_alias=f"new-org-{uuid.uuid4()}", + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + ), + ) + + org_id = response.organization_id + + response = await organization_member_add( + data=OrganizationMemberAddRequest( + organization_id=org_id, + member=Member(role=user_role, user_id=created_user_id), + ), + http_request=None, + ) + + print("new user response", response) + + # call get_user_object + + user_object = await get_user_object( + user_id=created_user_id, + prisma_client=prisma_client, + user_api_key_cache=DualCache(), + user_id_upsert=False, + ) + + print("user object", user_object) + + assert user_object.organization_memberships is not None + + _membership = user_object.organization_memberships[0] + + assert _membership.user_id == created_user_id + assert _membership.organization_id == org_id + + if user_role != None: + assert _membership.user_role == user_role + else: + assert _membership.user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY + + +@pytest.mark.asyncio +async def test_org_admin_create_team_permissions(prisma_client): + """ + Create a new org admin + + org admin creates a new team in their org -> success + """ + import json + + master_key = "sk-1234" + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", master_key) + + await litellm.proxy.proxy_server.prisma_client.connect() + + response = await new_organization( + data=NewOrganizationRequest( + organization_alias=f"new-org-{uuid.uuid4()}", + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + ), + ) + + org_id = response.organization_id + created_user_id = f"new-user-{uuid.uuid4()}" + response = await organization_member_add( + data=OrganizationMemberAddRequest( + organization_id=org_id, + member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), + ), + http_request=None, + ) + + # create key with the response["user_id"] + # proxy admin will generate key for org admin + _new_key = await generate_key_fn( + data=GenerateKeyRequest(user_id=created_user_id), + user_api_key_dict=UserAPIKeyAuth(user_id=created_user_id), + ) + + new_key = _new_key.key + + print("user api key auth response", response) + + # Create /team/new request -> expect auth to pass + request = Request(scope={"type": "http"}) + request._url = URL(url="/team/new") + + async def return_body(): + body = {"organization_id": org_id} + return bytes(json.dumps(body), "utf-8") + + request.body = return_body + response = await user_api_key_auth(request=request, api_key="Bearer " + new_key) + + # after auth - actually create team now + response = await new_team( + data=NewTeamRequest( + organization_id=org_id, + ), + http_request=request, + user_api_key_dict=UserAPIKeyAuth( + user_id=response.user_id, + ), + ) + + print("response from new team") + + +@pytest.mark.asyncio +async def test_org_admin_create_user_permissions(prisma_client): + """ + 1. Create a new org admin + + 2. org admin adds a new member to their org -> success (using using /organization/member_add) + + """ + import json + + master_key = "sk-1234" + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", master_key) + + await litellm.proxy.proxy_server.prisma_client.connect() + + # create new org + response = await new_organization( + data=NewOrganizationRequest( + organization_alias=f"new-org-{uuid.uuid4()}", + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + ), + ) + # Create Org Admin + org_id = response.organization_id + created_user_id = f"new-user-{uuid.uuid4()}" + response = await organization_member_add( + data=OrganizationMemberAddRequest( + organization_id=org_id, + member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), + ), + http_request=None, + ) + + # create key with for Org Admin + _new_key = await generate_key_fn( + data=GenerateKeyRequest(user_id=created_user_id), + user_api_key_dict=UserAPIKeyAuth(user_id=created_user_id), + ) + + new_key = _new_key.key + + print("user api key auth response", response) + + # Create /organization/member_add request -> expect auth to pass + request = Request(scope={"type": "http"}) + request._url = URL(url="/organization/member_add") + + async def return_body(): + body = {"organization_id": org_id} + return bytes(json.dumps(body), "utf-8") + + request.body = return_body + response = await user_api_key_auth(request=request, api_key="Bearer " + new_key) + + # after auth - actually actually add new user to organization + new_internal_user_for_org = f"new-org-user-{uuid.uuid4()}" + response = await organization_member_add( + data=OrganizationMemberAddRequest( + organization_id=org_id, + member=Member( + role=LitellmUserRoles.INTERNAL_USER, user_id=new_internal_user_for_org + ), + ), + http_request=request, + ) + + print("response from new team") + + +@pytest.mark.asyncio +async def test_org_admin_create_user_team_wrong_org_permissions(prisma_client): + """ + Create a new org admin + + org admin creates a new user and new team in orgs they are not part of -> expect error + """ + import json + + master_key = "sk-1234" + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", master_key) + + await litellm.proxy.proxy_server.prisma_client.connect() + created_user_id = f"new-user-{uuid.uuid4()}" + response = await new_organization( + data=NewOrganizationRequest( + organization_alias=f"new-org-{uuid.uuid4()}", + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + ), + ) + + response2 = await new_organization( + data=NewOrganizationRequest( + organization_alias=f"new-org-{uuid.uuid4()}", + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + ), + ) + + org1_id = response.organization_id # has an admin + + org2_id = response2.organization_id # does not have an org admin + + # Create Org Admin for Org1 + created_user_id = f"new-user-{uuid.uuid4()}" + response = await organization_member_add( + data=OrganizationMemberAddRequest( + organization_id=org1_id, + member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), + ), + http_request=None, + ) + + _new_key = await generate_key_fn( + data=GenerateKeyRequest( + user_id=created_user_id, + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.ORG_ADMIN, + user_id=created_user_id, + ), + ) + + new_key = _new_key.key + + print("user api key auth response", response) + + # Add a new request in organization=org_without_admins -> expect fail (organization/member_add) + request = Request(scope={"type": "http"}) + request._url = URL(url="/organization/member_add") + + async def return_body(): + body = {"organization_id": org2_id} + return bytes(json.dumps(body), "utf-8") + + request.body = return_body + + try: + response = await user_api_key_auth(request=request, api_key="Bearer " + new_key) + pytest.fail( + f"This should have failed!. creating a user in an org without admins" + ) + except Exception as e: + print("got exception", e) + print("exception.message", e.message) + assert ( + "You do not have a role within the selected organization. Passed organization_id" + in e.message + ) + + # Create /team/new request in organization=org_without_admins -> expect fail + request = Request(scope={"type": "http"}) + request._url = URL(url="/team/new") + + async def return_body(): + body = {"organization_id": org2_id} + return bytes(json.dumps(body), "utf-8") + + request.body = return_body + + try: + response = await user_api_key_auth(request=request, api_key="Bearer " + new_key) + pytest.fail( + f"This should have failed!. Org Admin creating a team in an org where they are not an admin" + ) + except Exception as e: + print("got exception", e) + print("exception.message", e.message) + assert ( + "You do not have the required role to call" in e.message + and org2_id in e.message + ) diff --git a/tests/test_organizations.py b/tests/test_organizations.py index 00e99cb66..5d9eb2e27 100644 --- a/tests/test_organizations.py +++ b/tests/test_organizations.py @@ -33,7 +33,7 @@ async def new_organization(session, i, organization_alias, max_budget=None): @pytest.mark.asyncio async def test_organization_new(): """ - Make 20 parallel calls to /user/new. Assert all worked. + Make 20 parallel calls to /organization/new. Assert all worked. """ organization_alias = f"Organization: {uuid.uuid4()}" async with aiohttp.ClientSession() as session: