feat(handle_jwt.py): initial commit adding custom RBAC support on jwt… (#8037)

* feat(handle_jwt.py): initial commit adding custom RBAC support on jwt auth

allows admin to define user role field and allowed roles which map to 'internal_user' on litellm

* fix(auth_checks.py): ensure user allowed to access model, when calling via personal keys

Fixes https://github.com/BerriAI/litellm/issues/8029

* feat(handle_jwt.py): support role based access with model permission control on proxy

Allows admin to just grant users roles on IDP (e.g. Azure AD/Keycloak) and user can immediately start calling models

* docs(rbac): add docs on rbac for model access control

make it clear how admin can use roles to control model access on proxy

* fix: fix linting errors

* test(test_user_api_key_auth.py): add unit testing to ensure rbac role is correctly enforced

* test(test_user_api_key_auth.py): add more testing

* test(test_users.py): add unit testing to ensure user model access is always checked for new keys

Resolves https://github.com/BerriAI/litellm/issues/8029

* test: fix unit test

* fix(dot_notation_indexing.py): fix typing to work with python 3.8
This commit is contained in:
Krish Dholakia 2025-01-28 16:27:06 -08:00 committed by GitHub
parent 9644e197f7
commit 2eaa0079f2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 648 additions and 84 deletions

View file

@ -33,6 +33,7 @@ from litellm.proxy.auth.auth_checks import (
get_end_user_object,
get_key_object,
get_org_object,
get_role_based_models,
get_team_object,
get_user_object,
is_valid_fallback_model,
@ -281,9 +282,34 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
return LitellmUserRoles.TEAM
def can_rbac_role_call_model(
rbac_role: RBAC_ROLES,
general_settings: dict,
model: Optional[str],
) -> Literal[True]:
"""
Checks if user is allowed to access the model, based on their role.
"""
role_based_models = get_role_based_models(
rbac_role=rbac_role, general_settings=general_settings
)
if role_based_models is None or model is None:
return True
if model not in role_based_models:
raise HTTPException(
status_code=403,
detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
)
return True
async def _jwt_auth_user_api_key_auth_builder(
api_key: str,
jwt_handler: JWTHandler,
request_data: dict,
general_settings: dict,
route: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
@ -295,14 +321,20 @@ async def _jwt_auth_user_api_key_auth_builder(
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
# check if unmatched token and enforce_rbac is true
if (
jwt_handler.litellm_jwtauth.enforce_rbac is True
and jwt_handler.get_rbac_role(token=jwt_valid_token) is None
):
raise HTTPException(
status_code=403,
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
)
if jwt_handler.litellm_jwtauth.enforce_rbac is True:
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
if rbac_role is None:
raise HTTPException(
status_code=403,
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
)
else:
# run rbac validation checks
can_rbac_role_call_model(
rbac_role=rbac_role,
general_settings=general_settings,
model=request_data.get("model"),
)
# get scopes
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
@ -431,18 +463,18 @@ async def _jwt_auth_user_api_key_auth_builder(
proxy_logging_obj=proxy_logging_obj,
)
return {
"is_proxy_admin": False,
"team_id": team_id,
"team_object": team_object,
"user_id": user_id,
"user_object": user_object,
"org_id": org_id,
"org_object": org_object,
"end_user_id": end_user_id,
"end_user_object": end_user_object,
"token": api_key,
}
return JWTAuthBuilderResult(
is_proxy_admin=False,
team_id=team_id,
team_object=team_object,
user_id=user_id,
user_object=user_object,
org_id=org_id,
org_object=org_object,
end_user_id=end_user_id,
end_user_object=end_user_object,
token=api_key,
)
async def _user_api_key_auth_builder( # noqa: PLR0915
@ -581,6 +613,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
if is_jwt:
result = await _jwt_auth_user_api_key_auth_builder(
request_data=request_data,
general_settings=general_settings,
api_key=api_key,
jwt_handler=jwt_handler,
route=route,