[BETA] Support OIDC role based access to proxy (#8260)

* feat(proxy/_types.py): add new jwt field params

allows users + services to auth into proxy

* feat(handle_jwt.py): allow team role proxy access

allows proxy admin to set allowed team roles

* fix(proxy/_types.py): add 'routes' to role based permissions

allow proxy admin to restrict what routes a team can access easily

* feat(handle_jwt.py): support more flexible role based route access

v2 on role based 'allowed_routes'

* test(test_jwt.py): add unit test for rbac for proxy routes

* feat(handle_jwt.py): ensure cost tracking always works for any jwt request with `enforce_rbac=True`

* docs(token_auth.md): add documentation on controlling model access via OIDC Roles

* test: increase time delay before retrying

* test: handle model overloaded for test
This commit is contained in:
Krish Dholakia 2025-02-04 21:59:39 -08:00 committed by GitHub
parent 1d030ebed7
commit 015b822099
10 changed files with 413 additions and 143 deletions

View file

@ -35,11 +35,13 @@ from litellm.proxy._types import (
from litellm.proxy.utils import PrismaClient, ProxyLogging
from .auth_checks import (
_allowed_routes_check,
allowed_routes_check,
get_actual_routes,
get_end_user_object,
get_org_object,
get_role_based_models,
get_role_based_routes,
get_team_object,
get_user_object,
)
@ -78,6 +80,37 @@ class JWTHandler:
parts = token.split(".")
return len(parts) == 3
def _rbac_role_from_role_mapping(self, token: dict) -> Optional[RBAC_ROLES]:
"""
Returns the RBAC role the token 'belongs' to based on role mappings.
Args:
token (dict): The JWT token containing role information
Returns:
Optional[RBAC_ROLES]: The mapped internal RBAC role if a mapping exists,
None otherwise
Note:
The function handles both single string roles and lists of roles from the JWT.
If multiple mappings match the JWT roles, the first matching mapping is returned.
"""
if self.litellm_jwtauth.role_mappings is None:
return None
jwt_role = self.get_jwt_role(token=token, default_value=None)
if not jwt_role:
return None
jwt_role_set = set(jwt_role)
for role_mapping in self.litellm_jwtauth.role_mappings:
# Check if the mapping role matches any of the JWT roles
if role_mapping.role in jwt_role_set:
return role_mapping.internal_role
return None
def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]:
"""
Returns the RBAC role the token 'belongs' to.
@ -109,6 +142,8 @@ class JWTHandler:
user_roles=user_roles
):
return LitellmUserRoles.INTERNAL_USER
elif rbac_role := self._rbac_role_from_role_mapping(token=token):
return rbac_role
return None
@ -212,6 +247,29 @@ class JWTHandler:
user_roles = default_value
return user_roles
def get_jwt_role(
self, token: dict, default_value: Optional[List[str]]
) -> Optional[List[str]]:
"""
Generic implementation of `get_user_roles` that can be used for both user and team roles.
Returns the jwt role from the token.
Set via 'roles_jwt_field' in the config.
"""
try:
if self.litellm_jwtauth.roles_jwt_field is not None:
user_roles = get_nested_value(
data=token,
key_path=self.litellm_jwtauth.roles_jwt_field,
default=default_value,
)
else:
user_roles = default_value
except KeyError:
user_roles = default_value
return user_roles
def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool:
"""
Returns the user role from the token.
@ -240,6 +298,16 @@ class JWTHandler:
user_email = default_value
return user_email
def get_object_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
if self.litellm_jwtauth.object_id_jwt_field is not None:
object_id = token[self.litellm_jwtauth.object_id_jwt_field]
else:
object_id = default_value
except KeyError:
object_id = default_value
return object_id
def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
if self.litellm_jwtauth.org_id_jwt_field is not None:
@ -423,6 +491,35 @@ class JWTHandler:
class JWTAuthManager:
"""Manages JWT authentication and authorization operations"""
@staticmethod
def can_rbac_role_call_route(
rbac_role: RBAC_ROLES,
general_settings: dict,
route: str,
) -> Literal[True]:
"""
Checks if user is allowed to access the route, based on their role.
"""
role_based_routes = get_role_based_routes(
rbac_role=rbac_role, general_settings=general_settings
)
if role_based_routes is None or route is None:
return True
is_allowed = _allowed_routes_check(
user_route=route,
allowed_routes=role_based_routes,
)
if not is_allowed:
raise HTTPException(
status_code=403,
detail=f"Role={rbac_role} not allowed to call route={route}. Allowed routes={role_based_routes}",
)
return True
@staticmethod
def can_rbac_role_call_model(
rbac_role: RBAC_ROLES,
@ -441,7 +538,7 @@ class JWTAuthManager:
if model not in role_based_models:
raise HTTPException(
status_code=403,
detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
detail=f"Role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
)
return True
@ -452,10 +549,11 @@ class JWTAuthManager:
jwt_valid_token: dict,
general_settings: dict,
request_data: dict,
route: str,
rbac_role: Optional[RBAC_ROLES],
) -> None:
"""Validate RBAC role and model access permissions"""
if jwt_handler.litellm_jwtauth.enforce_rbac is True:
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
if rbac_role is None:
raise HTTPException(
status_code=403,
@ -466,6 +564,11 @@ class JWTAuthManager:
general_settings=general_settings,
model=request_data.get("model"),
)
JWTAuthManager.can_rbac_role_call_route(
rbac_role=rbac_role,
general_settings=general_settings,
route=route,
)
@staticmethod
async def check_admin_access(
@ -685,6 +788,21 @@ class JWTAuthManager:
return user_object, org_object, end_user_object
@staticmethod
def validate_object_id(
user_id: Optional[str],
team_id: Optional[str],
enforce_rbac: bool,
is_proxy_admin: bool,
) -> Literal[True]:
"""If enforce_rbac is true, validate that a valid rbac id is returned for spend tracking"""
if enforce_rbac and not is_proxy_admin and not user_id and not team_id:
raise HTTPException(
status_code=403,
detail="No user or team id found in token. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
)
return True
@staticmethod
async def auth_builder(
api_key: str,
@ -701,10 +819,18 @@ class JWTAuthManager:
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
# Check RBAC
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
await JWTAuthManager.check_rbac_role(
jwt_handler, jwt_valid_token, general_settings, request_data
jwt_handler,
jwt_valid_token,
general_settings,
request_data,
route,
rbac_role,
)
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
# Get basic user info
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info(
@ -716,6 +842,16 @@ class JWTAuthManager:
end_user_id = jwt_handler.get_end_user_id(
token=jwt_valid_token, default_value=None
)
team_id: Optional[str] = None
team_object: Optional[LiteLLM_TeamTable] = None
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
if rbac_role and object_id:
if rbac_role == LitellmUserRoles.TEAM:
team_id = object_id
elif rbac_role == LitellmUserRoles.INTERNAL_USER:
user_id = object_id
# Check admin access
admin_result = await JWTAuthManager.check_admin_access(
@ -726,15 +862,20 @@ class JWTAuthManager:
# Get team with model access
## SPECIFIC TEAM ID
team_id, team_object = await JWTAuthManager.find_and_validate_specific_team_id(
jwt_handler,
jwt_valid_token,
prisma_client,
user_api_key_cache,
parent_otel_span,
proxy_logging_obj,
)
if not team_object:
if not team_id:
team_id, team_object = (
await JWTAuthManager.find_and_validate_specific_team_id(
jwt_handler,
jwt_valid_token,
prisma_client,
user_api_key_cache,
parent_otel_span,
proxy_logging_obj,
)
)
if not team_object and not team_id:
## CHECK USER GROUP ACCESS
all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token)
team_id, team_object = await JWTAuthManager.find_team_with_model_access(
@ -762,6 +903,14 @@ class JWTAuthManager:
proxy_logging_obj=proxy_logging_obj,
)
# Validate that a valid rbac id is returned for spend tracking
JWTAuthManager.validate_object_id(
user_id=user_id,
team_id=team_id,
enforce_rbac=general_settings.get("enforce_rbac", False),
is_proxy_admin=False,
)
return JWTAuthBuilderResult(
is_proxy_admin=False,
team_id=team_id,