diff --git a/docs/my-website/docs/proxy/token_auth.md b/docs/my-website/docs/proxy/token_auth.md index 9df0462281..753e92c169 100644 --- a/docs/my-website/docs/proxy/token_auth.md +++ b/docs/my-website/docs/proxy/token_auth.md @@ -163,10 +163,12 @@ scope: "litellm-proxy-admin ..." ```yaml general_settings: - master_key: sk-1234 + enable_jwt_auth: True litellm_jwtauth: user_id_jwt_field: "sub" team_ids_jwt_field: "groups" + user_id_upsert: true # add user_id to the db if they don't exist + enforce_team_based_model_access: true # don't allow users to access models unless the team has access ``` This is assuming your token looks like this: diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index d7246f45b9..19cc8cdd07 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -36,7 +36,14 @@ model_list: litellm_settings: cache: true - + +general_settings: + enable_jwt_auth: True + litellm_jwtauth: + user_id_jwt_field: "sub" + team_ids_jwt_field: "groups" + user_id_upsert: true # add user_id to the db if they don't exist + enforce_team_based_model_access: true # don't allow users to access models unless the team has access router_settings: redis_host: os.environ/REDIS_HOST diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 65a6a1794d..7b2435e67c 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -260,6 +260,7 @@ class LiteLLMRoutes(enum.Enum): "/key/health", "/team/info", "/team/list", + "/organization/list", "/team/available", "/user/info", "/model/info", @@ -1100,24 +1101,6 @@ class NewOrganizationRequest(LiteLLM_BudgetTable): budget_id: Optional[str] = None -class LiteLLM_OrganizationTable(LiteLLMPydanticObjectBase): - """Represents user-controllable params for a LiteLLM_OrganizationTable record""" - - organization_id: Optional[str] = None - organization_alias: Optional[str] = None - budget_id: str - metadata: Optional[dict] = None - models: List[str] - created_by: str - updated_by: str - - -class NewOrganizationResponse(LiteLLM_OrganizationTable): - organization_id: str # type: ignore - created_at: datetime - updated_at: datetime - - class OrganizationRequest(LiteLLMPydanticObjectBase): organizations: List[str] @@ -1492,6 +1475,28 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase): model_config = ConfigDict(protected_namespaces=()) +class LiteLLM_OrganizationTable(LiteLLMPydanticObjectBase): + """Represents user-controllable params for a LiteLLM_OrganizationTable record""" + + organization_id: Optional[str] = None + organization_alias: Optional[str] = None + budget_id: str + metadata: Optional[dict] = None + models: List[str] + created_by: str + updated_by: str + + +class LiteLLM_OrganizationTableWithMembers(LiteLLM_OrganizationTable): + members: List[LiteLLM_OrganizationMembershipTable] + + +class NewOrganizationResponse(LiteLLM_OrganizationTable): + organization_id: str # type: ignore + created_at: datetime + updated_at: datetime + + class LiteLLM_UserTable(LiteLLMPydanticObjectBase): user_id: str max_budget: Optional[float] @@ -2375,6 +2380,7 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): ) scope_mappings: Optional[List[ScopeMapping]] = None enforce_scope_based_access: bool = False + enforce_team_based_model_access: bool = False def __init__(self, **kwargs: Any) -> None: # get the attribute names for this Pydantic model diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index c60d41faee..88a8144b55 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -154,7 +154,10 @@ class JWTHandler: return False def get_team_ids_from_jwt(self, token: dict) -> List[str]: - if self.litellm_jwtauth.team_ids_jwt_field is not None: + if ( + self.litellm_jwtauth.team_ids_jwt_field is not None + and token.get(self.litellm_jwtauth.team_ids_jwt_field) is not None + ): return token[self.litellm_jwtauth.team_ids_jwt_field] return [] @@ -699,6 +702,11 @@ class JWTAuthManager: """Find first team with access to the requested model""" if not team_ids: + if jwt_handler.litellm_jwtauth.enforce_team_based_model_access: + raise HTTPException( + status_code=403, + detail="No teams found in token. `enforce_team_based_model_access` is set to True. Token must belong to a team.", + ) return None, None for team_id in team_ids: @@ -731,7 +739,7 @@ class JWTAuthManager: if requested_model: raise HTTPException( status_code=403, - detail=f"No team has access to the requested model: {requested_model}. Checked teams={team_ids}", + detail=f"No team has access to the requested model: {requested_model}. Checked teams={team_ids}. Check `/models` to see all available models.", ) return None, None diff --git a/litellm/proxy/management_endpoints/organization_endpoints.py b/litellm/proxy/management_endpoints/organization_endpoints.py index b247c1cbe6..34cee0cc7e 100644 --- a/litellm/proxy/management_endpoints/organization_endpoints.py +++ b/litellm/proxy/management_endpoints/organization_endpoints.py @@ -160,6 +160,7 @@ async def new_organization( "error": f"User not allowed to give access to model={m}. Models you have access to = {user_api_key_dict.models}" }, ) + organization_row = LiteLLM_OrganizationTable( **data.json(exclude_none=True), created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, @@ -201,6 +202,7 @@ async def delete_organization(): "/organization/list", tags=["organization management"], dependencies=[Depends(user_api_key_auth)], + response_model=List[LiteLLM_OrganizationTableWithMembers], ) async def list_organization( user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), @@ -216,24 +218,34 @@ async def list_organization( if prisma_client is None: raise HTTPException(status_code=500, detail={"error": "No db connected"}) - if ( - user_api_key_dict.user_role is None - or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN - ): - raise HTTPException( - status_code=401, - detail={ - "error": f"Only admins can list orgs. Your role is = {user_api_key_dict.user_role}" - }, - ) if prisma_client is None: raise HTTPException( status_code=400, detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - response = await prisma_client.db.litellm_organizationtable.find_many( - include={"members": True} - ) + + # if proxy admin - get all orgs + if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN: + response = await prisma_client.db.litellm_organizationtable.find_many( + include={"members": True} + ) + # if internal user - get orgs they are a member of + else: + org_memberships = ( + await prisma_client.db.litellm_organizationmembership.find_many( + where={"user_id": user_api_key_dict.user_id} + ) + ) + org_objects = await prisma_client.db.litellm_organizationtable.find_many( + where={ + "organization_id": { + "in": [membership.organization_id for membership in org_memberships] + } + }, + include={"members": True}, + ) + + response = org_objects return response diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 51f235522d..da98c2540f 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1415,7 +1415,8 @@ class PrismaClient: if key_val is None: key_val = {"user_id": user_id} response = await self.db.litellm_usertable.find_unique( # type: ignore - where=key_val # type: ignore + where=key_val, # type: ignore + include={"organization_memberships": True}, ) elif query_type == "find_all" and key_val is not None: response = await self.db.litellm_usertable.find_many( diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py index 1ef7607c26..62d0a5f52e 100644 --- a/tests/local_testing/test_router.py +++ b/tests/local_testing/test_router.py @@ -1446,7 +1446,7 @@ def test_bedrock_on_router(): # test openai-compatible endpoint @pytest.mark.asyncio async def test_mistral_on_router(): - litellm.set_verbose = True + litellm._turn_on_debug() model_list = [ { "model_name": "gpt-3.5-turbo", diff --git a/ui/litellm-dashboard/src/app/page.tsx b/ui/litellm-dashboard/src/app/page.tsx index 727a9e769c..a08ae1c005 100644 --- a/ui/litellm-dashboard/src/app/page.tsx +++ b/ui/litellm-dashboard/src/app/page.tsx @@ -24,6 +24,7 @@ import Sidebar from "@/components/leftnav"; import Usage from "@/components/usage"; import CacheDashboard from "@/components/cache_dashboard"; import { setGlobalLitellmHeaderName } from "@/components/networking"; +import { Organization } from "@/components/networking"; function getCookie(name: string) { const cookieValue = document.cookie @@ -49,6 +50,8 @@ function formatUserRole(userRole: string) { return "Admin"; case "proxy_admin_viewer": return "Admin Viewer"; + case "org_admin": + return "Org Admin"; case "internal_user": return "Internal User"; case "internal_viewer": @@ -75,6 +78,8 @@ export default function CreateKeyPage() { const [userEmail, setUserEmail] = useState(null); const [teams, setTeams] = useState(null); const [keys, setKeys] = useState(null); + const [currentOrg, setCurrentOrg] = useState(null); + const [organizations, setOrganizations] = useState([]); const [proxySettings, setProxySettings] = useState({ PROXY_BASE_URL: "", PROXY_LOGOUT_URL: "", @@ -166,6 +171,20 @@ export default function CreateKeyPage() { } }, [token]); + const handleOrgChange = (org: Organization) => { + setCurrentOrg(org); + console.log(`org: ${JSON.stringify(org)}`) + if (org.members && userRole != "Admin") { // don't change user role if user is admin + for (const member of org.members) { + console.log(`member: ${JSON.stringify(member)}`) + if (member.user_id == userID) { + console.log(`member.user_role: ${member.user_role}`) + setUserRole(formatUserRole(member.user_role)); + } + } + } + } + return ( Loading...}> @@ -181,6 +200,7 @@ export default function CreateKeyPage() { setUserEmail={setUserEmail} setTeams={setTeams} setKeys={setKeys} + setOrganizations={setOrganizations} /> ) : (
@@ -191,6 +211,9 @@ export default function CreateKeyPage() { premiumUser={premiumUser} setProxySettings={setProxySettings} proxySettings={proxySettings} + currentOrg={currentOrg} + organizations={organizations} + onOrgChange={handleOrgChange} />
@@ -213,6 +236,7 @@ export default function CreateKeyPage() { setUserEmail={setUserEmail} setTeams={setTeams} setKeys={setKeys} + setOrganizations={setOrganizations} /> ) : page == "models" ? ( ) : page == "organizations" ? ( >; proxySettings: any; + currentOrg: Organization | null; + onOrgChange?: (org: Organization) => void; + onNewOrg?: () => void; + organizations?: Organization[]; } + const Navbar: React.FC = ({ userID, userRole, @@ -34,101 +25,163 @@ const Navbar: React.FC = ({ premiumUser, setProxySettings, proxySettings, + currentOrg = null, + onOrgChange = () => {}, + onNewOrg = () => {}, + organizations = [] }) => { - console.log("User ID:", userID); - console.log("userEmail:", userEmail); - console.log("premiumUser:", premiumUser); - - // const userColors = require('./ui_colors.json') || {}; + console.log(`currentOrg: ${JSON.stringify(currentOrg)}`) + console.log(`organizations: ${JSON.stringify(organizations)}`) const isLocal = process.env.NODE_ENV === "development"; - if (isLocal != true) { - console.log = function() {}; - } - const proxyBaseUrl = isLocal ? "http://localhost:4000" : null; const imageUrl = isLocal ? "http://localhost:4000/get_image" : "/get_image"; - let logoutUrl = ""; - - console.log("PROXY_settings=", proxySettings); - - if (proxySettings) { - if (proxySettings.PROXY_LOGOUT_URL && proxySettings.PROXY_LOGOUT_URL !== undefined) { - logoutUrl = proxySettings.PROXY_LOGOUT_URL; - } - } - - console.log("logoutUrl=", logoutUrl); + let logoutUrl = proxySettings?.PROXY_LOGOUT_URL || ""; const handleLogout = () => { - // Clear cookies document.cookie = "token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=/;"; window.location.href = logoutUrl; - } - + }; - const items: MenuProps["items"] = [ + const userItems: MenuProps["items"] = [ { key: "1", label: ( - <> -

Role: {userRole}

-

ID: {userID}

-

Premium User: {String(premiumUser)}

- +
+

Role: {userRole}

+

ID: {userID}

+

Premium User: {String(premiumUser)}

+
), }, { key: "2", - label:

Logout

, + label:

Logout

, } ]; - return ( - <> - ); }; -export default Navbar; +export default Navbar; \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index e9f543a87b..58bbd9885c 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -15,6 +15,24 @@ export interface Model { model_info: Object | null; } +export interface Organization { + organization_id: string; + organization_alias: string; + budget_id: string; + metadata: Record; + models: string[]; + spend: number; + model_spend: Record; + created_at: string; + created_by: string; + updated_at: string; + updated_by: string; + litellm_budget_table: any; // Simplified to any since we don't need the detailed structure + teams: any[] | null; + users: any[] | null; + members: any[] | null; +} + const baseUrl = "/"; // Assuming the base URL is the root diff --git a/ui/litellm-dashboard/src/components/teams.tsx b/ui/litellm-dashboard/src/components/teams.tsx index 6b0583bef8..433987eb1a 100644 --- a/ui/litellm-dashboard/src/components/teams.tsx +++ b/ui/litellm-dashboard/src/components/teams.tsx @@ -1,7 +1,7 @@ import React, { useState, useEffect } from "react"; import Link from "next/link"; import { Typography } from "antd"; -import { teamDeleteCall, teamUpdateCall, teamInfoCall } from "./networking"; +import { teamDeleteCall, teamUpdateCall, teamInfoCall, Organization } from "./networking"; import TeamMemberModal from "@/components/team/edit_membership"; import { InformationCircleIcon, @@ -64,6 +64,7 @@ interface TeamProps { setTeams: React.Dispatch>; userID: string | null; userRole: string | null; + currentOrg: Organization | null; } interface EditTeamModalProps { @@ -90,6 +91,7 @@ const Team: React.FC = ({ setTeams, userID, userRole, + currentOrg }) => { const [lastRefreshed, setLastRefreshed] = useState(""); @@ -285,7 +287,7 @@ const Team: React.FC = ({ if (accessToken != null) { const newTeamAlias = formValues?.team_alias; const existingTeamAliases = teams?.map((t) => t.team_alias) ?? []; - let organizationId = formValues?.organization_id; + let organizationId = formValues?.organization_id || currentOrg?.organization_id; if (organizationId === "" || typeof organizationId !== 'string') { formValues.organization_id = null; } else { @@ -618,7 +620,7 @@ const Team: React.FC = ({ )} - {userRole == "Admin"? ( + {userRole == "Admin" || userRole == "Org Admin"? (