mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
[BETA] Support OIDC role
based access to proxy (#8260)
* feat(proxy/_types.py): add new jwt field params allows users + services to auth into proxy * feat(handle_jwt.py): allow team role proxy access allows proxy admin to set allowed team roles * fix(proxy/_types.py): add 'routes' to role based permissions allow proxy admin to restrict what routes a team can access easily * feat(handle_jwt.py): support more flexible role based route access v2 on role based 'allowed_routes' * test(test_jwt.py): add unit test for rbac for proxy routes * feat(handle_jwt.py): ensure cost tracking always works for any jwt request with `enforce_rbac=True` * docs(token_auth.md): add documentation on controlling model access via OIDC Roles * test: increase time delay before retrying * test: handle model overloaded for test
This commit is contained in:
parent
1d030ebed7
commit
015b822099
10 changed files with 413 additions and 143 deletions
|
@ -35,11 +35,13 @@ from litellm.proxy._types import (
|
|||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
|
||||
from .auth_checks import (
|
||||
_allowed_routes_check,
|
||||
allowed_routes_check,
|
||||
get_actual_routes,
|
||||
get_end_user_object,
|
||||
get_org_object,
|
||||
get_role_based_models,
|
||||
get_role_based_routes,
|
||||
get_team_object,
|
||||
get_user_object,
|
||||
)
|
||||
|
@ -78,6 +80,37 @@ class JWTHandler:
|
|||
parts = token.split(".")
|
||||
return len(parts) == 3
|
||||
|
||||
def _rbac_role_from_role_mapping(self, token: dict) -> Optional[RBAC_ROLES]:
|
||||
"""
|
||||
Returns the RBAC role the token 'belongs' to based on role mappings.
|
||||
|
||||
Args:
|
||||
token (dict): The JWT token containing role information
|
||||
|
||||
Returns:
|
||||
Optional[RBAC_ROLES]: The mapped internal RBAC role if a mapping exists,
|
||||
None otherwise
|
||||
|
||||
Note:
|
||||
The function handles both single string roles and lists of roles from the JWT.
|
||||
If multiple mappings match the JWT roles, the first matching mapping is returned.
|
||||
"""
|
||||
if self.litellm_jwtauth.role_mappings is None:
|
||||
return None
|
||||
|
||||
jwt_role = self.get_jwt_role(token=token, default_value=None)
|
||||
if not jwt_role:
|
||||
return None
|
||||
|
||||
jwt_role_set = set(jwt_role)
|
||||
|
||||
for role_mapping in self.litellm_jwtauth.role_mappings:
|
||||
# Check if the mapping role matches any of the JWT roles
|
||||
if role_mapping.role in jwt_role_set:
|
||||
return role_mapping.internal_role
|
||||
|
||||
return None
|
||||
|
||||
def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]:
|
||||
"""
|
||||
Returns the RBAC role the token 'belongs' to.
|
||||
|
@ -109,6 +142,8 @@ class JWTHandler:
|
|||
user_roles=user_roles
|
||||
):
|
||||
return LitellmUserRoles.INTERNAL_USER
|
||||
elif rbac_role := self._rbac_role_from_role_mapping(token=token):
|
||||
return rbac_role
|
||||
|
||||
return None
|
||||
|
||||
|
@ -212,6 +247,29 @@ class JWTHandler:
|
|||
user_roles = default_value
|
||||
return user_roles
|
||||
|
||||
def get_jwt_role(
|
||||
self, token: dict, default_value: Optional[List[str]]
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
Generic implementation of `get_user_roles` that can be used for both user and team roles.
|
||||
|
||||
Returns the jwt role from the token.
|
||||
|
||||
Set via 'roles_jwt_field' in the config.
|
||||
"""
|
||||
try:
|
||||
if self.litellm_jwtauth.roles_jwt_field is not None:
|
||||
user_roles = get_nested_value(
|
||||
data=token,
|
||||
key_path=self.litellm_jwtauth.roles_jwt_field,
|
||||
default=default_value,
|
||||
)
|
||||
else:
|
||||
user_roles = default_value
|
||||
except KeyError:
|
||||
user_roles = default_value
|
||||
return user_roles
|
||||
|
||||
def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool:
|
||||
"""
|
||||
Returns the user role from the token.
|
||||
|
@ -240,6 +298,16 @@ class JWTHandler:
|
|||
user_email = default_value
|
||||
return user_email
|
||||
|
||||
def get_object_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
if self.litellm_jwtauth.object_id_jwt_field is not None:
|
||||
object_id = token[self.litellm_jwtauth.object_id_jwt_field]
|
||||
else:
|
||||
object_id = default_value
|
||||
except KeyError:
|
||||
object_id = default_value
|
||||
return object_id
|
||||
|
||||
def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
if self.litellm_jwtauth.org_id_jwt_field is not None:
|
||||
|
@ -423,6 +491,35 @@ class JWTHandler:
|
|||
class JWTAuthManager:
|
||||
"""Manages JWT authentication and authorization operations"""
|
||||
|
||||
@staticmethod
|
||||
def can_rbac_role_call_route(
|
||||
rbac_role: RBAC_ROLES,
|
||||
general_settings: dict,
|
||||
route: str,
|
||||
) -> Literal[True]:
|
||||
"""
|
||||
Checks if user is allowed to access the route, based on their role.
|
||||
"""
|
||||
role_based_routes = get_role_based_routes(
|
||||
rbac_role=rbac_role, general_settings=general_settings
|
||||
)
|
||||
|
||||
if role_based_routes is None or route is None:
|
||||
return True
|
||||
|
||||
is_allowed = _allowed_routes_check(
|
||||
user_route=route,
|
||||
allowed_routes=role_based_routes,
|
||||
)
|
||||
|
||||
if not is_allowed:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Role={rbac_role} not allowed to call route={route}. Allowed routes={role_based_routes}",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def can_rbac_role_call_model(
|
||||
rbac_role: RBAC_ROLES,
|
||||
|
@ -441,7 +538,7 @@ class JWTAuthManager:
|
|||
if model not in role_based_models:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
|
||||
detail=f"Role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
|
||||
)
|
||||
|
||||
return True
|
||||
|
@ -452,10 +549,11 @@ class JWTAuthManager:
|
|||
jwt_valid_token: dict,
|
||||
general_settings: dict,
|
||||
request_data: dict,
|
||||
route: str,
|
||||
rbac_role: Optional[RBAC_ROLES],
|
||||
) -> None:
|
||||
"""Validate RBAC role and model access permissions"""
|
||||
if jwt_handler.litellm_jwtauth.enforce_rbac is True:
|
||||
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
|
||||
if rbac_role is None:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
@ -466,6 +564,11 @@ class JWTAuthManager:
|
|||
general_settings=general_settings,
|
||||
model=request_data.get("model"),
|
||||
)
|
||||
JWTAuthManager.can_rbac_role_call_route(
|
||||
rbac_role=rbac_role,
|
||||
general_settings=general_settings,
|
||||
route=route,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def check_admin_access(
|
||||
|
@ -685,6 +788,21 @@ class JWTAuthManager:
|
|||
|
||||
return user_object, org_object, end_user_object
|
||||
|
||||
@staticmethod
|
||||
def validate_object_id(
|
||||
user_id: Optional[str],
|
||||
team_id: Optional[str],
|
||||
enforce_rbac: bool,
|
||||
is_proxy_admin: bool,
|
||||
) -> Literal[True]:
|
||||
"""If enforce_rbac is true, validate that a valid rbac id is returned for spend tracking"""
|
||||
if enforce_rbac and not is_proxy_admin and not user_id and not team_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="No user or team id found in token. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def auth_builder(
|
||||
api_key: str,
|
||||
|
@ -701,10 +819,18 @@ class JWTAuthManager:
|
|||
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
|
||||
|
||||
# Check RBAC
|
||||
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
|
||||
await JWTAuthManager.check_rbac_role(
|
||||
jwt_handler, jwt_valid_token, general_settings, request_data
|
||||
jwt_handler,
|
||||
jwt_valid_token,
|
||||
general_settings,
|
||||
request_data,
|
||||
route,
|
||||
rbac_role,
|
||||
)
|
||||
|
||||
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
|
||||
|
||||
# Get basic user info
|
||||
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||
user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info(
|
||||
|
@ -716,6 +842,16 @@ class JWTAuthManager:
|
|||
end_user_id = jwt_handler.get_end_user_id(
|
||||
token=jwt_valid_token, default_value=None
|
||||
)
|
||||
team_id: Optional[str] = None
|
||||
team_object: Optional[LiteLLM_TeamTable] = None
|
||||
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
|
||||
|
||||
if rbac_role and object_id:
|
||||
|
||||
if rbac_role == LitellmUserRoles.TEAM:
|
||||
team_id = object_id
|
||||
elif rbac_role == LitellmUserRoles.INTERNAL_USER:
|
||||
user_id = object_id
|
||||
|
||||
# Check admin access
|
||||
admin_result = await JWTAuthManager.check_admin_access(
|
||||
|
@ -726,15 +862,20 @@ class JWTAuthManager:
|
|||
|
||||
# Get team with model access
|
||||
## SPECIFIC TEAM ID
|
||||
team_id, team_object = await JWTAuthManager.find_and_validate_specific_team_id(
|
||||
jwt_handler,
|
||||
jwt_valid_token,
|
||||
prisma_client,
|
||||
user_api_key_cache,
|
||||
parent_otel_span,
|
||||
proxy_logging_obj,
|
||||
)
|
||||
if not team_object:
|
||||
|
||||
if not team_id:
|
||||
team_id, team_object = (
|
||||
await JWTAuthManager.find_and_validate_specific_team_id(
|
||||
jwt_handler,
|
||||
jwt_valid_token,
|
||||
prisma_client,
|
||||
user_api_key_cache,
|
||||
parent_otel_span,
|
||||
proxy_logging_obj,
|
||||
)
|
||||
)
|
||||
|
||||
if not team_object and not team_id:
|
||||
## CHECK USER GROUP ACCESS
|
||||
all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token)
|
||||
team_id, team_object = await JWTAuthManager.find_team_with_model_access(
|
||||
|
@ -762,6 +903,14 @@ class JWTAuthManager:
|
|||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# Validate that a valid rbac id is returned for spend tracking
|
||||
JWTAuthManager.validate_object_id(
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
enforce_rbac=general_settings.get("enforce_rbac", False),
|
||||
is_proxy_admin=False,
|
||||
)
|
||||
|
||||
return JWTAuthBuilderResult(
|
||||
is_proxy_admin=False,
|
||||
team_id=team_id,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue