feat: fine grained access control policy (#2264)

This allows a set of rules to be defined for determining access to
resources. The rules are (loosely) based on the cedar policy format.

A rule defines a list of action either to permit or to forbid. It may
specify a principal or a resource that must match for the rule to take
effect. It may also specify a condition, either a 'when' or an 'unless',
with additional constraints as to where the rule applies.

A list of rules is held for each type to be protected and tried in order
to find a match. If a match is found, the request is permitted or
forbidden depening on the type of rule. If no match is found, the
request is denied. If no rules are specified for a given type, a rule
that allows any action as long as the resource attributes match the user
attributes is added (i.e. the previous behaviour is the default.

Some examples in yaml:

```
    model:
    - permit:
      principal: user-1
      actions: [create, read, delete]
      comment: user-1 has full access to all models
    - permit:
      principal: user-2
      actions: [read]
      resource: model-1
      comment: user-2 has read access to model-1 only
    - permit:
      actions: [read]
      when:
        user_in: resource.namespaces
      comment: any user has read access to models with matching attributes
    vector_db:
    - forbid:
      actions: [create, read, delete]
      unless:
        user_in: role::admin
      comment: only user with admin role can use vector_db resources
```

---------

Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
grs 2025-06-03 17:51:12 -04:00 committed by GitHub
parent 8bee2954be
commit 7c1998db25
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 956 additions and 450 deletions

View file

@ -16,43 +16,18 @@ from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class TokenValidationResult(BaseModel):
principal: str | None = Field(
default=None,
description="The principal (username or persistent identifier) of the authenticated user",
)
access_attributes: AccessAttributes | None = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
These attributes determine which resources the user can access.
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
Each attribute category contains a list of values that the user has for that category.
During access control checks, these values are compared against resource requirements.
Example with standard categories:
```json
{
"roles": ["admin", "data-scientist"],
"teams": ["ml-team"],
"projects": ["llama-3"],
"namespaces": ["research"]
}
```
""",
)
class AuthResponse(TokenValidationResult):
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
principal: str
# further attributes that may be used for access control decisions
attributes: dict[str, list[str]] | None = None
message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
@ -78,7 +53,7 @@ class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token and return access attributes."""
pass
@ -88,10 +63,10 @@ class AuthProvider(ABC):
pass
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
attributes = AccessAttributes()
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
attributes: dict[str, list[str]] = {}
for claim_key, attribute_key in mapping.items():
if claim_key not in claims or not hasattr(attributes, attribute_key):
if claim_key not in claims:
continue
claim = claims[claim_key]
if isinstance(claim, list):
@ -99,11 +74,10 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
else:
values = claim.split()
current = getattr(attributes, attribute_key)
if current:
current.extend(values)
if attribute_key in attributes:
attributes[attribute_key].extend(values)
else:
setattr(attributes, attribute_key, values)
attributes[attribute_key] = values
return attributes
@ -145,8 +119,6 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
if value not in AccessAttributes.model_fields:
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
return v
@model_validator(mode="after")
@ -171,14 +143,14 @@ class OAuth2TokenAuthProvider(AuthProvider):
self._jwks: dict[str, str] = {}
self._jwks_lock = Lock()
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
async def validate_token(self, token: str, scope: dict | None = None) -> User:
if self.config.jwks:
return await self.validate_jwt_token(token, scope)
if self.config.introspection:
return await self.introspect_token(token, scope)
raise ValueError("One of jwks or introspection must be configured")
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the JWT token."""
await self._refresh_jwks()
@ -203,12 +175,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
# We should incorporate these into the access attributes.
principal = claims["sub"]
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
return TokenValidationResult(
return User(
principal=principal,
access_attributes=access_attributes,
attributes=access_attributes,
)
async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
async def introspect_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using token introspection as defined by RFC 7662."""
form = {
"token": token,
@ -242,9 +214,9 @@ class OAuth2TokenAuthProvider(AuthProvider):
raise ValueError("Token not active")
principal = fields["sub"] or fields["username"]
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
return TokenValidationResult(
return User(
principal=principal,
access_attributes=access_attributes,
attributes=access_attributes,
)
except httpx.TimeoutException:
logger.exception("Token introspection request timed out")
@ -299,7 +271,7 @@ class CustomAuthProvider(AuthProvider):
self.config = config
self._client = None
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using the custom authentication endpoint."""
if scope is None:
scope = {}
@ -333,6 +305,7 @@ class CustomAuthProvider(AuthProvider):
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
print("MADE CALL")
if response.status_code != 200:
logger.warning(f"Authentication failed with status code: {response.status_code}")
raise ValueError(f"Authentication failed: {response.status_code}")
@ -341,7 +314,7 @@ class CustomAuthProvider(AuthProvider):
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
return auth_response
return User(auth_response.principal, auth_response.attributes)
except Exception as e:
logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e