mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
[BETA] Support OIDC role
based access to proxy (#8260)
* feat(proxy/_types.py): add new jwt field params allows users + services to auth into proxy * feat(handle_jwt.py): allow team role proxy access allows proxy admin to set allowed team roles * fix(proxy/_types.py): add 'routes' to role based permissions allow proxy admin to restrict what routes a team can access easily * feat(handle_jwt.py): support more flexible role based route access v2 on role based 'allowed_routes' * test(test_jwt.py): add unit test for rbac for proxy routes * feat(handle_jwt.py): ensure cost tracking always works for any jwt request with `enforce_rbac=True` * docs(token_auth.md): add documentation on controlling model access via OIDC Roles * test: increase time delay before retrying * test: handle model overloaded for test
This commit is contained in:
parent
1d030ebed7
commit
015b822099
10 changed files with 413 additions and 143 deletions
|
@ -3,7 +3,7 @@ import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# OIDC - JWT-based Auth
|
# OIDC - JWT-based Auth
|
||||||
|
|
||||||
Use JWT's to auth admins / projects into the proxy.
|
Use JWT's to auth admins / users / projects into the proxy.
|
||||||
|
|
||||||
:::info
|
:::info
|
||||||
|
|
||||||
|
@ -156,27 +156,6 @@ scope: ["litellm-proxy-admin",...]
|
||||||
scope: "litellm-proxy-admin ..."
|
scope: "litellm-proxy-admin ..."
|
||||||
```
|
```
|
||||||
|
|
||||||
## Control Model Access with Roles
|
|
||||||
|
|
||||||
Reject a JWT token if it's valid but doesn't have the required scopes / fields.
|
|
||||||
|
|
||||||
Only tokens which with valid Admin (`admin_jwt_scope`), User (`user_id_jwt_field`), Team (`team_id_jwt_field`) are allowed.
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
general_settings:
|
|
||||||
enable_jwt_auth: True
|
|
||||||
litellm_jwtauth:
|
|
||||||
user_roles_jwt_field: "resource_access.litellm-test-client-id.roles"
|
|
||||||
user_allowed_roles: ["basic_user"] # roles that map to an 'internal_user' role on LiteLLM
|
|
||||||
enforce_rbac: true # if true, will check if the user has the correct role to access the model + endpoint
|
|
||||||
|
|
||||||
role_permissions: # control what models + endpointsare allowed for each role
|
|
||||||
- role: internal_user
|
|
||||||
models: ["anthropic-claude"]
|
|
||||||
```
|
|
||||||
|
|
||||||
**[Architecture Diagram (Control Model Access)](./jwt_auth_arch)**
|
|
||||||
|
|
||||||
## Control model access with Teams
|
## Control model access with Teams
|
||||||
|
|
||||||
|
|
||||||
|
@ -331,3 +310,64 @@ general_settings:
|
||||||
user_allowed_email_domain: "my-co.com" # allows user@my-co.com to call proxy
|
user_allowed_email_domain: "my-co.com" # allows user@my-co.com to call proxy
|
||||||
user_id_upsert: true # 👈 upserts the user to db, if valid email but not in db
|
user_id_upsert: true # 👈 upserts the user to db, if valid email but not in db
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## [BETA] Control Access with OIDC Roles
|
||||||
|
|
||||||
|
Allow JWT tokens with supported roles to access the proxy.
|
||||||
|
|
||||||
|
Let users and teams access the proxy, without needing to add them to the DB.
|
||||||
|
|
||||||
|
|
||||||
|
Very important, set `enforce_rbac: true` to ensure that the RBAC system is enabled.
|
||||||
|
|
||||||
|
**Note:** This is in beta and might change unexpectedly.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
general_settings:
|
||||||
|
enable_jwt_auth: True
|
||||||
|
litellm_jwtauth:
|
||||||
|
object_id_jwt_field: "oid" # can be either user / team, inferred from the role mapping
|
||||||
|
roles_jwt_field: "roles"
|
||||||
|
role_mappings:
|
||||||
|
- role: litellm.api.consumer
|
||||||
|
internal_role: "team"
|
||||||
|
enforce_rbac: true # 👈 VERY IMPORTANT
|
||||||
|
|
||||||
|
role_permissions: # default model + endpoint permissions for a role.
|
||||||
|
- role: team
|
||||||
|
models: ["anthropic-claude"]
|
||||||
|
routes: ["/v1/chat/completions"]
|
||||||
|
|
||||||
|
environment_variables:
|
||||||
|
JWT_AUDIENCE: "api://LiteLLM_Proxy" # ensures audience is validated
|
||||||
|
```
|
||||||
|
|
||||||
|
- `object_id_jwt_field`: The field in the JWT token that contains the object id. This id can be either a user id or a team id. Use this instead of `user_id_jwt_field` and `team_id_jwt_field`. If the same field could be both.
|
||||||
|
|
||||||
|
- `roles_jwt_field`: The field in the JWT token that contains the roles. This field is a list of roles that the user has. To index into a nested field, use dot notation - eg. `resource_access.litellm-test-client-id.roles`.
|
||||||
|
|
||||||
|
- `role_mappings`: A list of role mappings. Map the received role in the JWT token to an internal role on LiteLLM.
|
||||||
|
|
||||||
|
- `JWT_AUDIENCE`: The audience of the JWT token. This is used to validate the audience of the JWT token. Set via an environment variable.
|
||||||
|
|
||||||
|
### Example Token
|
||||||
|
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"aud": "api://LiteLLM_Proxy",
|
||||||
|
"oid": "eec236bd-0135-4b28-9354-8fc4032d543e",
|
||||||
|
"roles": ["litellm.api.consumer"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Role Mapping Spec
|
||||||
|
|
||||||
|
- `role`: The expected role in the JWT token.
|
||||||
|
- `internal_role`: The internal role on LiteLLM that will be used to control access.
|
||||||
|
|
||||||
|
Supported internal roles:
|
||||||
|
- `team`: Team object will be used for RBAC spend tracking. Use this for tracking spend for a 'use case'.
|
||||||
|
- `internal_user`: User object will be used for RBAC spend tracking. Use this for tracking spend for an 'individual user'.
|
||||||
|
- `proxy_admin`: Proxy admin will be used for RBAC spend tracking. Use this for granting admin access to a token.
|
||||||
|
|
||||||
|
### [Architecture Diagram (Control Model Access)](./jwt_auth_arch)
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -35,6 +35,15 @@ litellm_settings:
|
||||||
general_settings:
|
general_settings:
|
||||||
enable_jwt_auth: True
|
enable_jwt_auth: True
|
||||||
litellm_jwtauth:
|
litellm_jwtauth:
|
||||||
user_id_jwt_field: "sub"
|
object_id_jwt_field: "client_id" # can be either user / team, inferred from the role mapping
|
||||||
user_email_jwt_field: "email"
|
roles_jwt_field: "resource_access.litellm-test-client-id.roles"
|
||||||
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
|
role_mappings:
|
||||||
|
- role: litellm.api.consumer
|
||||||
|
internal_role: "team"
|
||||||
|
enforce_rbac: true
|
||||||
|
role_permissions: # default model + endpoint permissions for a role.
|
||||||
|
- role: team
|
||||||
|
models: ["anthropic-claude"]
|
||||||
|
routes: ["openai_routes"]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -397,92 +397,6 @@ class LiteLLMRoutes(enum.Enum):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# class LiteLLMAllowedRoutes(LiteLLMPydanticObjectBase):
|
|
||||||
# """
|
|
||||||
# Defines allowed routes based on key type.
|
|
||||||
|
|
||||||
# Types = ["admin", "team", "user", "unmapped"]
|
|
||||||
# """
|
|
||||||
|
|
||||||
# admin_allowed_routes: List[
|
|
||||||
# Literal["openai_routes", "info_routes", "management_routes", "spend_tracking_routes", "global_spend_tracking_routes"]
|
|
||||||
# ] = ["management_routes"]
|
|
||||||
|
|
||||||
|
|
||||||
class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
|
||||||
"""
|
|
||||||
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
- admin_jwt_scope: The JWT scope required for proxy admin roles.
|
|
||||||
- admin_allowed_routes: list of allowed routes for proxy admin roles.
|
|
||||||
- team_jwt_scope: The JWT scope required for proxy team roles.
|
|
||||||
- team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
|
|
||||||
- team_allowed_routes: list of allowed routes for proxy team roles.
|
|
||||||
- user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLMUserTable`). Use this for internal employees.
|
|
||||||
- user_email_jwt_field: The field in the JWT token that stores the user email (maps to `LiteLLMUserTable`). Use this for internal employees.
|
|
||||||
- user_allowed_email_subdomain: If specified, only emails from specified subdomain will be allowed to access proxy.
|
|
||||||
- end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
|
|
||||||
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
|
|
||||||
- public_allowed_routes: list of allowed routes for authenticated but unknown litellm role jwt tokens.
|
|
||||||
- enforce_rbac: If true, enforce RBAC for all routes.
|
|
||||||
|
|
||||||
See `auth_checks.py` for the specific routes
|
|
||||||
"""
|
|
||||||
|
|
||||||
admin_jwt_scope: str = "litellm_proxy_admin"
|
|
||||||
admin_allowed_routes: List[str] = [
|
|
||||||
"management_routes",
|
|
||||||
"spend_tracking_routes",
|
|
||||||
"global_spend_tracking_routes",
|
|
||||||
"info_routes",
|
|
||||||
]
|
|
||||||
team_id_jwt_field: Optional[str] = None
|
|
||||||
team_ids_jwt_field: Optional[str] = None
|
|
||||||
upsert_sso_user_to_team: bool = False
|
|
||||||
team_allowed_routes: List[
|
|
||||||
Literal["openai_routes", "info_routes", "management_routes"]
|
|
||||||
] = ["openai_routes", "info_routes"]
|
|
||||||
team_id_default: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="If no team_id given, default permissions/spend-tracking to this team.s",
|
|
||||||
)
|
|
||||||
org_id_jwt_field: Optional[str] = None
|
|
||||||
user_id_jwt_field: Optional[str] = None
|
|
||||||
user_email_jwt_field: Optional[str] = None
|
|
||||||
user_allowed_email_domain: Optional[str] = None
|
|
||||||
user_roles_jwt_field: Optional[str] = None
|
|
||||||
user_allowed_roles: Optional[List[str]] = None
|
|
||||||
user_id_upsert: bool = Field(
|
|
||||||
default=False, description="If user doesn't exist, upsert them into the db."
|
|
||||||
)
|
|
||||||
end_user_id_jwt_field: Optional[str] = None
|
|
||||||
public_key_ttl: float = 600
|
|
||||||
public_allowed_routes: List[str] = ["public_routes"]
|
|
||||||
enforce_rbac: bool = False
|
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
|
||||||
# get the attribute names for this Pydantic model
|
|
||||||
allowed_keys = self.__annotations__.keys()
|
|
||||||
|
|
||||||
invalid_keys = set(kwargs.keys()) - allowed_keys
|
|
||||||
user_roles_jwt_field = kwargs.get("user_roles_jwt_field")
|
|
||||||
user_allowed_roles = kwargs.get("user_allowed_roles")
|
|
||||||
|
|
||||||
if invalid_keys:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}."
|
|
||||||
)
|
|
||||||
if (user_roles_jwt_field is not None and user_allowed_roles is None) or (
|
|
||||||
user_roles_jwt_field is None and user_allowed_roles is not None
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"user_allowed_roles must be provided if user_roles_jwt_field is set."
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMPromptInjectionParams(LiteLLMPydanticObjectBase):
|
class LiteLLMPromptInjectionParams(LiteLLMPydanticObjectBase):
|
||||||
heuristics_check: bool = False
|
heuristics_check: bool = False
|
||||||
vector_db_check: bool = False
|
vector_db_check: bool = False
|
||||||
|
@ -2364,6 +2278,103 @@ RBAC_ROLES = Literal[
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class RoleBasedPermissions(TypedDict):
|
class RoleBasedPermissions(LiteLLMPydanticObjectBase):
|
||||||
role: Required[RBAC_ROLES]
|
role: RBAC_ROLES
|
||||||
models: Required[List[str]]
|
models: Optional[List[str]] = None
|
||||||
|
routes: Optional[List[str]] = None
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"extra": "forbid",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RoleMapping(BaseModel):
|
||||||
|
role: str
|
||||||
|
internal_role: RBAC_ROLES
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||||
|
"""
|
||||||
|
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
- admin_jwt_scope: The JWT scope required for proxy admin roles.
|
||||||
|
- admin_allowed_routes: list of allowed routes for proxy admin roles.
|
||||||
|
- team_jwt_scope: The JWT scope required for proxy team roles.
|
||||||
|
- team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
|
||||||
|
- team_allowed_routes: list of allowed routes for proxy team roles.
|
||||||
|
- user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLMUserTable`). Use this for internal employees.
|
||||||
|
- user_email_jwt_field: The field in the JWT token that stores the user email (maps to `LiteLLMUserTable`). Use this for internal employees.
|
||||||
|
- user_allowed_email_subdomain: If specified, only emails from specified subdomain will be allowed to access proxy.
|
||||||
|
- end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
|
||||||
|
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
|
||||||
|
- public_allowed_routes: list of allowed routes for authenticated but unknown litellm role jwt tokens.
|
||||||
|
- enforce_rbac: If true, enforce RBAC for all routes.
|
||||||
|
|
||||||
|
See `auth_checks.py` for the specific routes
|
||||||
|
"""
|
||||||
|
|
||||||
|
admin_jwt_scope: str = "litellm_proxy_admin"
|
||||||
|
admin_allowed_routes: List[str] = [
|
||||||
|
"management_routes",
|
||||||
|
"spend_tracking_routes",
|
||||||
|
"global_spend_tracking_routes",
|
||||||
|
"info_routes",
|
||||||
|
]
|
||||||
|
team_id_jwt_field: Optional[str] = None
|
||||||
|
team_ids_jwt_field: Optional[str] = None
|
||||||
|
upsert_sso_user_to_team: bool = False
|
||||||
|
team_allowed_routes: List[
|
||||||
|
Literal["openai_routes", "info_routes", "management_routes"]
|
||||||
|
] = ["openai_routes", "info_routes"]
|
||||||
|
team_id_default: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="If no team_id given, default permissions/spend-tracking to this team.s",
|
||||||
|
)
|
||||||
|
|
||||||
|
org_id_jwt_field: Optional[str] = None
|
||||||
|
user_id_jwt_field: Optional[str] = None
|
||||||
|
user_email_jwt_field: Optional[str] = None
|
||||||
|
user_allowed_email_domain: Optional[str] = None
|
||||||
|
user_roles_jwt_field: Optional[str] = None
|
||||||
|
user_allowed_roles: Optional[List[str]] = None
|
||||||
|
user_id_upsert: bool = Field(
|
||||||
|
default=False, description="If user doesn't exist, upsert them into the db."
|
||||||
|
)
|
||||||
|
end_user_id_jwt_field: Optional[str] = None
|
||||||
|
public_key_ttl: float = 600
|
||||||
|
public_allowed_routes: List[str] = ["public_routes"]
|
||||||
|
enforce_rbac: bool = False
|
||||||
|
roles_jwt_field: Optional[str] = None # v2 on role mappings
|
||||||
|
role_mappings: Optional[List[RoleMapping]] = None
|
||||||
|
object_id_jwt_field: Optional[str] = (
|
||||||
|
None # can be either user / team, inferred from the role mapping
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
# get the attribute names for this Pydantic model
|
||||||
|
allowed_keys = self.__annotations__.keys()
|
||||||
|
|
||||||
|
invalid_keys = set(kwargs.keys()) - allowed_keys
|
||||||
|
user_roles_jwt_field = kwargs.get("user_roles_jwt_field")
|
||||||
|
user_allowed_roles = kwargs.get("user_allowed_roles")
|
||||||
|
object_id_jwt_field = kwargs.get("object_id_jwt_field")
|
||||||
|
role_mappings = kwargs.get("role_mappings")
|
||||||
|
|
||||||
|
if invalid_keys:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}."
|
||||||
|
)
|
||||||
|
if (user_roles_jwt_field is not None and user_allowed_roles is None) or (
|
||||||
|
user_roles_jwt_field is None and user_allowed_roles is not None
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"user_allowed_roles must be provided if user_roles_jwt_field is set."
|
||||||
|
)
|
||||||
|
|
||||||
|
if object_id_jwt_field is not None and role_mappings is None:
|
||||||
|
raise ValueError(
|
||||||
|
"if object_id_jwt_field is set, role_mappings must also be set. Needed to infer if the caller is a user or team."
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
|
@ -200,6 +200,7 @@ def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
||||||
- user_route: str - the route the user is trying to call
|
- user_route: str - the route the user is trying to call
|
||||||
- allowed_routes: List[str|LiteLLMRoutes] - the list of allowed routes for the user.
|
- allowed_routes: List[str|LiteLLMRoutes] - the list of allowed routes for the user.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for allowed_route in allowed_routes:
|
for allowed_route in allowed_routes:
|
||||||
if (
|
if (
|
||||||
allowed_route in LiteLLMRoutes.__members__
|
allowed_route in LiteLLMRoutes.__members__
|
||||||
|
@ -402,6 +403,29 @@ def _update_last_db_access_time(
|
||||||
last_db_access_time[key] = (value, time.time())
|
last_db_access_time[key] = (value, time.time())
|
||||||
|
|
||||||
|
|
||||||
|
def _get_role_based_permissions(
|
||||||
|
rbac_role: RBAC_ROLES,
|
||||||
|
general_settings: dict,
|
||||||
|
key: Literal["models", "routes"],
|
||||||
|
) -> Optional[List[str]]:
|
||||||
|
"""
|
||||||
|
Get the role based permissions from the general settings.
|
||||||
|
"""
|
||||||
|
role_based_permissions = cast(
|
||||||
|
Optional[List[RoleBasedPermissions]],
|
||||||
|
general_settings.get("role_permissions", []),
|
||||||
|
)
|
||||||
|
if role_based_permissions is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for role_based_permission in role_based_permissions:
|
||||||
|
|
||||||
|
if role_based_permission.role == rbac_role:
|
||||||
|
return getattr(role_based_permission, key)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_role_based_models(
|
def get_role_based_models(
|
||||||
rbac_role: RBAC_ROLES,
|
rbac_role: RBAC_ROLES,
|
||||||
general_settings: dict,
|
general_settings: dict,
|
||||||
|
@ -412,18 +436,26 @@ def get_role_based_models(
|
||||||
Used by JWT Auth.
|
Used by JWT Auth.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
role_based_permissions = cast(
|
return _get_role_based_permissions(
|
||||||
Optional[List[RoleBasedPermissions]],
|
rbac_role=rbac_role,
|
||||||
general_settings.get("role_permissions", []),
|
general_settings=general_settings,
|
||||||
|
key="models",
|
||||||
)
|
)
|
||||||
if role_based_permissions is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
for role_based_permission in role_based_permissions:
|
|
||||||
if role_based_permission["role"] == rbac_role:
|
|
||||||
return role_based_permission["models"]
|
|
||||||
|
|
||||||
return None
|
def get_role_based_routes(
|
||||||
|
rbac_role: RBAC_ROLES,
|
||||||
|
general_settings: dict,
|
||||||
|
) -> Optional[List[str]]:
|
||||||
|
"""
|
||||||
|
Get the routes allowed for a user role.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return _get_role_based_permissions(
|
||||||
|
rbac_role=rbac_role,
|
||||||
|
general_settings=general_settings,
|
||||||
|
key="routes",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _get_fuzzy_user_object(
|
async def _get_fuzzy_user_object(
|
||||||
|
|
|
@ -35,11 +35,13 @@ from litellm.proxy._types import (
|
||||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||||
|
|
||||||
from .auth_checks import (
|
from .auth_checks import (
|
||||||
|
_allowed_routes_check,
|
||||||
allowed_routes_check,
|
allowed_routes_check,
|
||||||
get_actual_routes,
|
get_actual_routes,
|
||||||
get_end_user_object,
|
get_end_user_object,
|
||||||
get_org_object,
|
get_org_object,
|
||||||
get_role_based_models,
|
get_role_based_models,
|
||||||
|
get_role_based_routes,
|
||||||
get_team_object,
|
get_team_object,
|
||||||
get_user_object,
|
get_user_object,
|
||||||
)
|
)
|
||||||
|
@ -78,6 +80,37 @@ class JWTHandler:
|
||||||
parts = token.split(".")
|
parts = token.split(".")
|
||||||
return len(parts) == 3
|
return len(parts) == 3
|
||||||
|
|
||||||
|
def _rbac_role_from_role_mapping(self, token: dict) -> Optional[RBAC_ROLES]:
|
||||||
|
"""
|
||||||
|
Returns the RBAC role the token 'belongs' to based on role mappings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token (dict): The JWT token containing role information
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[RBAC_ROLES]: The mapped internal RBAC role if a mapping exists,
|
||||||
|
None otherwise
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The function handles both single string roles and lists of roles from the JWT.
|
||||||
|
If multiple mappings match the JWT roles, the first matching mapping is returned.
|
||||||
|
"""
|
||||||
|
if self.litellm_jwtauth.role_mappings is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
jwt_role = self.get_jwt_role(token=token, default_value=None)
|
||||||
|
if not jwt_role:
|
||||||
|
return None
|
||||||
|
|
||||||
|
jwt_role_set = set(jwt_role)
|
||||||
|
|
||||||
|
for role_mapping in self.litellm_jwtauth.role_mappings:
|
||||||
|
# Check if the mapping role matches any of the JWT roles
|
||||||
|
if role_mapping.role in jwt_role_set:
|
||||||
|
return role_mapping.internal_role
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]:
|
def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]:
|
||||||
"""
|
"""
|
||||||
Returns the RBAC role the token 'belongs' to.
|
Returns the RBAC role the token 'belongs' to.
|
||||||
|
@ -109,6 +142,8 @@ class JWTHandler:
|
||||||
user_roles=user_roles
|
user_roles=user_roles
|
||||||
):
|
):
|
||||||
return LitellmUserRoles.INTERNAL_USER
|
return LitellmUserRoles.INTERNAL_USER
|
||||||
|
elif rbac_role := self._rbac_role_from_role_mapping(token=token):
|
||||||
|
return rbac_role
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -212,6 +247,29 @@ class JWTHandler:
|
||||||
user_roles = default_value
|
user_roles = default_value
|
||||||
return user_roles
|
return user_roles
|
||||||
|
|
||||||
|
def get_jwt_role(
|
||||||
|
self, token: dict, default_value: Optional[List[str]]
|
||||||
|
) -> Optional[List[str]]:
|
||||||
|
"""
|
||||||
|
Generic implementation of `get_user_roles` that can be used for both user and team roles.
|
||||||
|
|
||||||
|
Returns the jwt role from the token.
|
||||||
|
|
||||||
|
Set via 'roles_jwt_field' in the config.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.litellm_jwtauth.roles_jwt_field is not None:
|
||||||
|
user_roles = get_nested_value(
|
||||||
|
data=token,
|
||||||
|
key_path=self.litellm_jwtauth.roles_jwt_field,
|
||||||
|
default=default_value,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
user_roles = default_value
|
||||||
|
except KeyError:
|
||||||
|
user_roles = default_value
|
||||||
|
return user_roles
|
||||||
|
|
||||||
def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool:
|
def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns the user role from the token.
|
Returns the user role from the token.
|
||||||
|
@ -240,6 +298,16 @@ class JWTHandler:
|
||||||
user_email = default_value
|
user_email = default_value
|
||||||
return user_email
|
return user_email
|
||||||
|
|
||||||
|
def get_object_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||||
|
try:
|
||||||
|
if self.litellm_jwtauth.object_id_jwt_field is not None:
|
||||||
|
object_id = token[self.litellm_jwtauth.object_id_jwt_field]
|
||||||
|
else:
|
||||||
|
object_id = default_value
|
||||||
|
except KeyError:
|
||||||
|
object_id = default_value
|
||||||
|
return object_id
|
||||||
|
|
||||||
def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
if self.litellm_jwtauth.org_id_jwt_field is not None:
|
if self.litellm_jwtauth.org_id_jwt_field is not None:
|
||||||
|
@ -423,6 +491,35 @@ class JWTHandler:
|
||||||
class JWTAuthManager:
|
class JWTAuthManager:
|
||||||
"""Manages JWT authentication and authorization operations"""
|
"""Manages JWT authentication and authorization operations"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def can_rbac_role_call_route(
|
||||||
|
rbac_role: RBAC_ROLES,
|
||||||
|
general_settings: dict,
|
||||||
|
route: str,
|
||||||
|
) -> Literal[True]:
|
||||||
|
"""
|
||||||
|
Checks if user is allowed to access the route, based on their role.
|
||||||
|
"""
|
||||||
|
role_based_routes = get_role_based_routes(
|
||||||
|
rbac_role=rbac_role, general_settings=general_settings
|
||||||
|
)
|
||||||
|
|
||||||
|
if role_based_routes is None or route is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
is_allowed = _allowed_routes_check(
|
||||||
|
user_route=route,
|
||||||
|
allowed_routes=role_based_routes,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_allowed:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail=f"Role={rbac_role} not allowed to call route={route}. Allowed routes={role_based_routes}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def can_rbac_role_call_model(
|
def can_rbac_role_call_model(
|
||||||
rbac_role: RBAC_ROLES,
|
rbac_role: RBAC_ROLES,
|
||||||
|
@ -441,7 +538,7 @@ class JWTAuthManager:
|
||||||
if model not in role_based_models:
|
if model not in role_based_models:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail=f"User role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
|
detail=f"Role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
|
||||||
)
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
@ -452,10 +549,11 @@ class JWTAuthManager:
|
||||||
jwt_valid_token: dict,
|
jwt_valid_token: dict,
|
||||||
general_settings: dict,
|
general_settings: dict,
|
||||||
request_data: dict,
|
request_data: dict,
|
||||||
|
route: str,
|
||||||
|
rbac_role: Optional[RBAC_ROLES],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Validate RBAC role and model access permissions"""
|
"""Validate RBAC role and model access permissions"""
|
||||||
if jwt_handler.litellm_jwtauth.enforce_rbac is True:
|
if jwt_handler.litellm_jwtauth.enforce_rbac is True:
|
||||||
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
|
|
||||||
if rbac_role is None:
|
if rbac_role is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
|
@ -466,6 +564,11 @@ class JWTAuthManager:
|
||||||
general_settings=general_settings,
|
general_settings=general_settings,
|
||||||
model=request_data.get("model"),
|
model=request_data.get("model"),
|
||||||
)
|
)
|
||||||
|
JWTAuthManager.can_rbac_role_call_route(
|
||||||
|
rbac_role=rbac_role,
|
||||||
|
general_settings=general_settings,
|
||||||
|
route=route,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def check_admin_access(
|
async def check_admin_access(
|
||||||
|
@ -685,6 +788,21 @@ class JWTAuthManager:
|
||||||
|
|
||||||
return user_object, org_object, end_user_object
|
return user_object, org_object, end_user_object
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_object_id(
|
||||||
|
user_id: Optional[str],
|
||||||
|
team_id: Optional[str],
|
||||||
|
enforce_rbac: bool,
|
||||||
|
is_proxy_admin: bool,
|
||||||
|
) -> Literal[True]:
|
||||||
|
"""If enforce_rbac is true, validate that a valid rbac id is returned for spend tracking"""
|
||||||
|
if enforce_rbac and not is_proxy_admin and not user_id and not team_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="No user or team id found in token. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def auth_builder(
|
async def auth_builder(
|
||||||
api_key: str,
|
api_key: str,
|
||||||
|
@ -701,10 +819,18 @@ class JWTAuthManager:
|
||||||
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
|
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
|
||||||
|
|
||||||
# Check RBAC
|
# Check RBAC
|
||||||
|
rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
|
||||||
await JWTAuthManager.check_rbac_role(
|
await JWTAuthManager.check_rbac_role(
|
||||||
jwt_handler, jwt_valid_token, general_settings, request_data
|
jwt_handler,
|
||||||
|
jwt_valid_token,
|
||||||
|
general_settings,
|
||||||
|
request_data,
|
||||||
|
route,
|
||||||
|
rbac_role,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
|
||||||
|
|
||||||
# Get basic user info
|
# Get basic user info
|
||||||
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
|
||||||
user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info(
|
user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info(
|
||||||
|
@ -716,6 +842,16 @@ class JWTAuthManager:
|
||||||
end_user_id = jwt_handler.get_end_user_id(
|
end_user_id = jwt_handler.get_end_user_id(
|
||||||
token=jwt_valid_token, default_value=None
|
token=jwt_valid_token, default_value=None
|
||||||
)
|
)
|
||||||
|
team_id: Optional[str] = None
|
||||||
|
team_object: Optional[LiteLLM_TeamTable] = None
|
||||||
|
object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)
|
||||||
|
|
||||||
|
if rbac_role and object_id:
|
||||||
|
|
||||||
|
if rbac_role == LitellmUserRoles.TEAM:
|
||||||
|
team_id = object_id
|
||||||
|
elif rbac_role == LitellmUserRoles.INTERNAL_USER:
|
||||||
|
user_id = object_id
|
||||||
|
|
||||||
# Check admin access
|
# Check admin access
|
||||||
admin_result = await JWTAuthManager.check_admin_access(
|
admin_result = await JWTAuthManager.check_admin_access(
|
||||||
|
@ -726,15 +862,20 @@ class JWTAuthManager:
|
||||||
|
|
||||||
# Get team with model access
|
# Get team with model access
|
||||||
## SPECIFIC TEAM ID
|
## SPECIFIC TEAM ID
|
||||||
team_id, team_object = await JWTAuthManager.find_and_validate_specific_team_id(
|
|
||||||
jwt_handler,
|
if not team_id:
|
||||||
jwt_valid_token,
|
team_id, team_object = (
|
||||||
prisma_client,
|
await JWTAuthManager.find_and_validate_specific_team_id(
|
||||||
user_api_key_cache,
|
jwt_handler,
|
||||||
parent_otel_span,
|
jwt_valid_token,
|
||||||
proxy_logging_obj,
|
prisma_client,
|
||||||
)
|
user_api_key_cache,
|
||||||
if not team_object:
|
parent_otel_span,
|
||||||
|
proxy_logging_obj,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not team_object and not team_id:
|
||||||
## CHECK USER GROUP ACCESS
|
## CHECK USER GROUP ACCESS
|
||||||
all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token)
|
all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token)
|
||||||
team_id, team_object = await JWTAuthManager.find_team_with_model_access(
|
team_id, team_object = await JWTAuthManager.find_team_with_model_access(
|
||||||
|
@ -762,6 +903,14 @@ class JWTAuthManager:
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate that a valid rbac id is returned for spend tracking
|
||||||
|
JWTAuthManager.validate_object_id(
|
||||||
|
user_id=user_id,
|
||||||
|
team_id=team_id,
|
||||||
|
enforce_rbac=general_settings.get("enforce_rbac", False),
|
||||||
|
is_proxy_admin=False,
|
||||||
|
)
|
||||||
|
|
||||||
return JWTAuthBuilderResult(
|
return JWTAuthBuilderResult(
|
||||||
is_proxy_admin=False,
|
is_proxy_admin=False,
|
||||||
team_id=team_id,
|
team_id=team_id,
|
||||||
|
|
|
@ -2093,6 +2093,14 @@ class ProxyConfig:
|
||||||
health_check_interval = general_settings.get("health_check_interval", 300)
|
health_check_interval = general_settings.get("health_check_interval", 300)
|
||||||
health_check_details = general_settings.get("health_check_details", True)
|
health_check_details = general_settings.get("health_check_details", True)
|
||||||
|
|
||||||
|
### RBAC ###
|
||||||
|
rbac_role_permissions = general_settings.get("role_permissions", None)
|
||||||
|
if rbac_role_permissions is not None:
|
||||||
|
general_settings["role_permissions"] = [ # validate role permissions
|
||||||
|
RoleBasedPermissions(**role_permission)
|
||||||
|
for role_permission in rbac_role_permissions
|
||||||
|
]
|
||||||
|
|
||||||
## check if user has set a premium feature in general_settings
|
## check if user has set a premium feature in general_settings
|
||||||
if (
|
if (
|
||||||
general_settings.get("enforced_params") is not None
|
general_settings.get("enforced_params") is not None
|
||||||
|
|
|
@ -468,7 +468,7 @@ class BaseLLMChatTest(ABC):
|
||||||
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
|
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.flaky(retries=4, delay=1)
|
@pytest.mark.flaky(retries=4, delay=2)
|
||||||
def test_image_url(self, detail, image_url):
|
def test_image_url(self, detail, image_url):
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
from litellm.utils import supports_vision
|
from litellm.utils import supports_vision
|
||||||
|
@ -515,9 +515,13 @@ class BaseLLMChatTest(ABC):
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
response = self.completion_function(
|
try:
|
||||||
**base_completion_call_args, messages=messages
|
response = self.completion_function(
|
||||||
)
|
**base_completion_call_args, messages=messages
|
||||||
|
)
|
||||||
|
except litellm.InternalServerError:
|
||||||
|
pytest.skip("Model is overloaded")
|
||||||
|
|
||||||
assert response is not None
|
assert response is not None
|
||||||
|
|
||||||
@pytest.mark.flaky(retries=4, delay=1)
|
@pytest.mark.flaky(retries=4, delay=1)
|
||||||
|
|
|
@ -21,7 +21,7 @@ from datetime import datetime, timedelta
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import Request
|
from fastapi import Request, HTTPException
|
||||||
from fastapi.routing import APIRoute
|
from fastapi.routing import APIRoute
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
import litellm
|
import litellm
|
||||||
|
@ -1164,3 +1164,22 @@ async def test_end_user_jwt_auth(monkeypatch):
|
||||||
mock_client.call_args.kwargs[
|
mock_client.call_args.kwargs[
|
||||||
"end_user_id"
|
"end_user_id"
|
||||||
] == "81b3e52a-67a6-4efb-9645-70527e101479"
|
] == "81b3e52a-67a6-4efb-9645-70527e101479"
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_rbac_role_call_route():
|
||||||
|
from litellm.proxy.auth.handle_jwt import JWTAuthManager
|
||||||
|
from litellm.proxy._types import RoleBasedPermissions
|
||||||
|
from litellm.proxy._types import LitellmUserRoles
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
JWTAuthManager.can_rbac_role_call_route(
|
||||||
|
rbac_role=LitellmUserRoles.TEAM,
|
||||||
|
general_settings={
|
||||||
|
"role_permissions": [
|
||||||
|
RoleBasedPermissions(
|
||||||
|
role=LitellmUserRoles.TEAM, routes=["/v1/chat/completions"]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
route="/v1/embeddings",
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue