mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
Litellm dev 01 31 2025 p2 (#8164)
* docs(token_auth.md): clarify title * refactor(handle_jwt.py): add jwt auth manager + refactor to handle groups allows user to call model if user belongs to group with model access * refactor(handle_jwt.py): refactor to first check if service call then check user call * feat(handle_jwt.py): new `enforce_team_access` param only allows user to call model if a team they belong to has model access allows controlling user model access by team * fix(handle_jwt.py): fix error string, remove unecessary param * docs(token_auth.md): add controlling model access for jwt tokens via teams to docs * test: fix tests post refactor * fix: fix linting errors * fix: fix linting error * test: fix import error
This commit is contained in:
parent
2674fa4dc3
commit
a008a2d4f4
6 changed files with 447 additions and 242 deletions
|
@ -8,11 +8,12 @@ JWT token must have 'litellm_proxy_admin' in scope.
|
|||
|
||||
import json
|
||||
import os
|
||||
from typing import List, Optional, cast
|
||||
from typing import Any, List, Literal, Optional, Set, Tuple, cast
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
|
@ -21,11 +22,27 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
|||
from litellm.proxy._types import (
|
||||
RBAC_ROLES,
|
||||
JWKKeyValue,
|
||||
JWTAuthBuilderResult,
|
||||
JWTKeyItem,
|
||||
LiteLLM_EndUserTable,
|
||||
LiteLLM_JWTAuth,
|
||||
LiteLLM_OrganizationTable,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_UserTable,
|
||||
LitellmUserRoles,
|
||||
Span,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
|
||||
from .auth_checks import (
|
||||
allowed_routes_check,
|
||||
get_actual_routes,
|
||||
get_end_user_object,
|
||||
get_org_object,
|
||||
get_role_based_models,
|
||||
get_team_object,
|
||||
get_user_object,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
|
||||
class JWTHandler:
|
||||
|
@ -401,3 +418,355 @@ class JWTHandler:
|
|||
|
||||
async def close(self):
|
||||
await self.http_handler.close()
|
||||
|
||||
|
||||
class JWTAuthManager:
|
||||
"""Manages JWT authentication and authorization operations"""
|
||||
|
||||
@staticmethod
|
||||
def can_rbac_role_call_model(
|
||||
rbac_role: RBAC_ROLES,
|
||||
general_settings: dict,
|
||||
model: Optional[str],
|
||||
) -> Literal[True]:
|
||||
"""
|
||||
Checks if user is allowed to access the model, based on their role.
|
||||
"""
|
||||
role_based_models = get_role_based_models(
|
||||
rbac_role=rbac_role, general_settings=general_settings
|
||||
)
|
||||
if role_based_models is None or model is None:
|
||||
return True
|
||||
|
||||
if model not in role_based_models:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def check_rbac_role(
|
||||
jwt_handler: JWTHandler,
|
||||
jwt_valid_token: dict,
|
||||
general_settings: dict,
|
||||
request_data: dict,
|
||||
) -> None:
|
||||
"""Validate RBAC role and model access permissions"""
|
||||
if jwt_handler.litellm_jwtauth.enforce_rbac is True:
|
||||
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
|
||||
if rbac_role is None:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
|
||||
)
|
||||
JWTAuthManager.can_rbac_role_call_model(
|
||||
rbac_role=rbac_role,
|
||||
general_settings=general_settings,
|
||||
model=request_data.get("model"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def check_admin_access(
|
||||
jwt_handler: JWTHandler,
|
||||
scopes: list,
|
||||
route: str,
|
||||
user_id: Optional[str],
|
||||
org_id: Optional[str],
|
||||
api_key: str,
|
||||
) -> Optional[JWTAuthBuilderResult]:
|
||||
"""Check admin status and route access permissions"""
|
||||
if not jwt_handler.is_admin(scopes=scopes):
|
||||
return None
|
||||
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
)
|
||||
if not is_allowed:
|
||||
allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||
raise Exception(
|
||||
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||
)
|
||||
|
||||
return JWTAuthBuilderResult(
|
||||
is_proxy_admin=True,
|
||||
team_object=None,
|
||||
user_object=None,
|
||||
end_user_object=None,
|
||||
org_object=None,
|
||||
token=api_key,
|
||||
team_id=None,
|
||||
user_id=user_id,
|
||||
end_user_id=None,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def find_and_validate_specific_team_id(
|
||||
jwt_handler: JWTHandler,
|
||||
jwt_valid_token: dict,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span],
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
|
||||
"""Find and validate specific team ID"""
|
||||
individual_team_id = jwt_handler.get_team_id(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
|
||||
if not individual_team_id and jwt_handler.is_required_team_id() is True:
|
||||
raise Exception(
|
||||
f"No team id found in token. Checked team_id field '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
||||
)
|
||||
|
||||
## VALIDATE TEAM OBJECT ###
|
||||
team_object: Optional[LiteLLM_TeamTable] = None
|
||||
if individual_team_id:
|
||||
team_object = await get_team_object(
|
||||
team_id=individual_team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
return individual_team_id, team_object
|
||||
|
||||
@staticmethod
|
||||
def get_all_team_ids(jwt_handler: JWTHandler, jwt_valid_token: dict) -> Set[str]:
|
||||
"""Get combined team IDs from groups and individual team_id"""
|
||||
team_ids_from_groups = jwt_handler.get_team_ids_from_jwt(token=jwt_valid_token)
|
||||
|
||||
all_team_ids = set(team_ids_from_groups)
|
||||
|
||||
return all_team_ids
|
||||
|
||||
@staticmethod
|
||||
async def find_team_with_model_access(
|
||||
team_ids: Set[str],
|
||||
requested_model: Optional[str],
|
||||
route: str,
|
||||
jwt_handler: JWTHandler,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span],
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
|
||||
"""Find first team with access to the requested model"""
|
||||
|
||||
if not team_ids:
|
||||
return None, None
|
||||
|
||||
for team_id in team_ids:
|
||||
try:
|
||||
team_object = await get_team_object(
|
||||
team_id=team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
if team_object and team_object.models is not None:
|
||||
team_models = team_object.models
|
||||
if isinstance(team_models, list) and (
|
||||
not requested_model
|
||||
or requested_model in team_models
|
||||
or "*" in team_models
|
||||
):
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role=LitellmUserRoles.TEAM,
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
)
|
||||
if is_allowed:
|
||||
return team_id, team_object
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if requested_model:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"No team has access to the requested model: {requested_model}. Checked teams={team_ids}",
|
||||
)
|
||||
|
||||
return None, None
|
||||
|
||||
@staticmethod
|
||||
async def get_user_info(
|
||||
jwt_handler: JWTHandler,
|
||||
jwt_valid_token: dict,
|
||||
) -> Tuple[Optional[str], Optional[str], Optional[bool]]:
|
||||
"""Get user email and validation status"""
|
||||
user_email = jwt_handler.get_user_email(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
valid_user_email = None
|
||||
if jwt_handler.is_enforced_email_domain():
|
||||
valid_user_email = (
|
||||
False
|
||||
if user_email is None
|
||||
else jwt_handler.is_allowed_domain(user_email=user_email)
|
||||
)
|
||||
user_id = jwt_handler.get_user_id(
|
||||
token=jwt_valid_token, default_value=user_email
|
||||
)
|
||||
return user_id, user_email, valid_user_email
|
||||
|
||||
@staticmethod
|
||||
async def get_objects(
|
||||
user_id: Optional[str],
|
||||
org_id: Optional[str],
|
||||
end_user_id: Optional[str],
|
||||
valid_user_email: Optional[bool],
|
||||
jwt_handler: JWTHandler,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span],
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
) -> Tuple[
|
||||
Optional[LiteLLM_UserTable],
|
||||
Optional[LiteLLM_OrganizationTable],
|
||||
Optional[LiteLLM_EndUserTable],
|
||||
]:
|
||||
"""Get user, org, and end user objects"""
|
||||
org_object: Optional[LiteLLM_OrganizationTable] = None
|
||||
if org_id:
|
||||
org_object = (
|
||||
await get_org_object(
|
||||
org_id=org_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
if org_id
|
||||
else None
|
||||
)
|
||||
|
||||
user_object: Optional[LiteLLM_UserTable] = None
|
||||
if user_id:
|
||||
user_object = (
|
||||
await get_user_object(
|
||||
user_id=user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
user_id_upsert=jwt_handler.is_upsert_user_id(
|
||||
valid_user_email=valid_user_email
|
||||
),
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
if user_id
|
||||
else None
|
||||
)
|
||||
|
||||
end_user_object: Optional[LiteLLM_EndUserTable] = None
|
||||
if end_user_id:
|
||||
end_user_object = (
|
||||
await get_end_user_object(
|
||||
end_user_id=end_user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
if end_user_id
|
||||
else None
|
||||
)
|
||||
|
||||
return user_object, org_object, end_user_object
|
||||
|
||||
@staticmethod
|
||||
async def auth_builder(
|
||||
api_key: str,
|
||||
jwt_handler: JWTHandler,
|
||||
request_data: dict,
|
||||
general_settings: dict,
|
||||
route: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span],
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
) -> JWTAuthBuilderResult:
|
||||
"""Main authentication and authorization builder"""
|
||||
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
|
||||
|
||||
# Check RBAC
|
||||
await JWTAuthManager.check_rbac_role(
|
||||
jwt_handler, jwt_valid_token, general_settings, request_data
|
||||
)
|
||||
|
||||
# Get basic user info
|
||||
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||
user_id, _, valid_user_email = await JWTAuthManager.get_user_info(
|
||||
jwt_handler, jwt_valid_token
|
||||
)
|
||||
|
||||
# Get IDs
|
||||
org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
|
||||
end_user_id = jwt_handler.get_end_user_id(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
|
||||
# Check admin access
|
||||
admin_result = await JWTAuthManager.check_admin_access(
|
||||
jwt_handler, scopes, route, user_id, org_id, api_key
|
||||
)
|
||||
if admin_result:
|
||||
return admin_result
|
||||
|
||||
# Get team with model access
|
||||
## SPECIFIC TEAM ID
|
||||
team_id, team_object = await JWTAuthManager.find_and_validate_specific_team_id(
|
||||
jwt_handler,
|
||||
jwt_valid_token,
|
||||
prisma_client,
|
||||
user_api_key_cache,
|
||||
parent_otel_span,
|
||||
proxy_logging_obj,
|
||||
)
|
||||
if not team_object:
|
||||
## CHECK USER GROUP ACCESS
|
||||
all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token)
|
||||
team_id, team_object = await JWTAuthManager.find_team_with_model_access(
|
||||
team_ids=all_team_ids,
|
||||
requested_model=request_data.get("model"),
|
||||
route=route,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# Get other objects
|
||||
user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
end_user_id=end_user_id,
|
||||
valid_user_email=valid_user_email,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
return JWTAuthBuilderResult(
|
||||
is_proxy_admin=False,
|
||||
team_id=team_id,
|
||||
team_object=team_object,
|
||||
user_id=user_id,
|
||||
user_object=user_object,
|
||||
org_id=org_id,
|
||||
org_object=org_object,
|
||||
end_user_id=end_user_id,
|
||||
end_user_object=end_user_object,
|
||||
token=api_key,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue