llama-stack-mirror/llama_stack/distribution/server/auth_providers.py
grs 7c1998db25
feat: fine grained access control policy (#2264)
This allows a set of rules to be defined for determining access to
resources. The rules are (loosely) based on the cedar policy format.

A rule defines a list of actions either to permit or to forbid. It may
specify a principal or a resource that must match for the rule to take
effect. It may also specify a condition, either a 'when' or an 'unless',
with additional constraints as to where the rule applies.

A list of rules is held for each type to be protected and tried in order
to find a match. If a match is found, the request is permitted or
forbidden depending on the type of rule. If no match is found, the
request is denied. If no rules are specified for a given type, a rule
that allows any action as long as the resource attributes match the user
attributes is added (i.e. the previous behaviour is the default).

Some examples in yaml:

```
    model:
    - permit:
      principal: user-1
      actions: [create, read, delete]
      comment: user-1 has full access to all models
    - permit:
      principal: user-2
      actions: [read]
      resource: model-1
      comment: user-2 has read access to model-1 only
    - permit:
      actions: [read]
      when:
        user_in: resource.namespaces
      comment: any user has read access to models with matching attributes
    vector_db:
    - forbid:
      actions: [create, read, delete]
      unless:
        user_in: role::admin
      comment: only user with admin role can use vector_db resources
```

---------

Signed-off-by: Gordon Sim <gsim@redhat.com>
2025-06-03 14:51:12 -07:00

349 lines
13 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import ssl
import time
from abc import ABC, abstractmethod
from asyncio import Lock
from pathlib import Path
from urllib.parse import parse_qs
import httpx
from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
from llama_stack.log import get_logger
# Module-level logger for this file, created under the "auth" logging category.
logger = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel):
    """The format of the authentication response from the auth endpoint."""

    # Identifier of the authenticated caller.
    principal: str
    # Optional extra attributes that may be used for access control decisions.
    attributes: dict[str, list[str]] | None = None
    # Free-form explanation supplied by the auth endpoint, if any.
    message: str | None = Field(
        default=None,
        description="Optional message providing additional context about the authentication result.",
    )
class AuthRequestContext(BaseModel):
    # Describes the HTTP request that is being authenticated: its path, its
    # headers and its parsed query parameters.
    path: str = Field(
        description="The path of the request being authenticated",
    )
    headers: dict[str, str] = Field(
        description="HTTP headers from the original request (excluding Authorization)",
    )
    params: dict[str, list[str]] = Field(
        description="Query parameters from the original request, parsed as dictionary of lists",
    )
class AuthRequest(BaseModel):
    # Body sent to an external authentication endpoint: the credential to
    # check plus a description of the request it arrived with.
    api_key: str = Field(
        description="The API key extracted from the Authorization header",
    )
    request: AuthRequestContext = Field(
        description="Context information about the request being authenticated",
    )
class AuthProvider(ABC):
    """Interface every authentication provider has to implement."""

    @abstractmethod
    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Check ``token`` and return the ``User`` it identifies.

        ``scope`` is an optional mapping describing the request being
        authenticated; an implementation is free to ignore it.
        """

    @abstractmethod
    async def close(self):
        """Release whatever resources the provider is holding."""
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
    """Translate token claims into access-control attributes.

    Every entry of ``mapping`` names a claim and the attribute it feeds. A
    claim whose value is a list contributes its elements; any other value is
    treated as a whitespace-separated string and split. Several claims may
    map to the same attribute, in which case their values are concatenated in
    ``mapping`` order.

    Args:
        claims: Claim name -> claim value (a string or a list of strings).
        mapping: Claim name -> attribute name.

    Returns:
        Attribute name -> list of values. Claims that are missing from
        ``claims``, or present with a ``None`` value, contribute nothing.
        The returned lists never alias lists held by ``claims``.
    """
    attributes: dict[str, list[str]] = {}
    for claim_key, attribute_key in mapping.items():
        claim = claims.get(claim_key)
        if claim is None:
            # Absent (or explicitly null) claim: nothing to map.
            continue
        # Copy list claims so that extending the attribute below with values
        # from another claim never mutates the caller's ``claims`` mapping.
        values = list(claim) if isinstance(claim, list) else claim.split()
        attributes.setdefault(attribute_key, []).extend(values)
    return attributes
class OAuth2JWKSConfig(BaseModel):
    # The JWKS URI for collecting public keys
    uri: str
    # Age, in seconds, after which the cached key set is fetched again.
    key_recheck_period: int = Field(
        default=3600,
        description="The period to recheck the JWKS URI for key updates",
    )
class OAuth2IntrospectionConfig(BaseModel):
    # Token introspection endpoint that tokens are POSTed to.
    url: str
    # Credentials this server presents to the introspection endpoint.
    client_id: str
    client_secret: str
    # True: send the credentials as form fields next to the token.
    # False: send them as the request's (client_id, client_secret) auth pair.
    send_secret_in_body: bool = False
class OAuth2TokenAuthProviderConfig(BaseModel):
    # Settings for OAuth2 token validation. Exactly one of `jwks` (local JWT
    # verification) or `introspection` (remote check) must be configured; this
    # is enforced by `validate_mode` below.

    # Expected audience and issuer of incoming JWTs.
    audience: str = "llama-stack"
    # TLS settings for outbound calls to the identity provider. When
    # `tls_cafile` is set it is used as the CA bundle.
    verify_tls: bool = True
    tls_cafile: Path | None = None
    issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
    # Claim name -> access attribute name. Several claims may feed one attribute.
    claims_mapping: dict[str, str] = Field(
        default_factory=lambda: {
            "sub": "roles",
            "username": "roles",
            "groups": "teams",
            "team": "teams",
            "project": "projects",
            "tenant": "namespaces",
            "namespace": "namespaces",
        },
    )
    # Defaults to None so an introspection-only configuration does not have to
    # spell out `jwks: null`; `validate_mode` still requires one of the two.
    jwks: OAuth2JWKSConfig | None = None
    introspection: OAuth2IntrospectionConfig | None = None

    # `@field_validator` must be the outermost decorator: when `@classmethod`
    # wraps it instead, pydantic does not register the validator and it never runs.
    @field_validator("claims_mapping")
    @classmethod
    def validate_claims_mapping(cls, v):
        """Reject a claims mapping that maps any claim to an empty attribute name."""
        for key, value in v.items():
            if not value:
                raise ValueError(f"claims_mapping value cannot be empty: {key}")
        return v

    @model_validator(mode="after")
    def validate_mode(self) -> Self:
        """Ensure exactly one of ``jwks`` and ``introspection`` is configured."""
        if not self.jwks and not self.introspection:
            raise ValueError("One of jwks or introspection must be configured")
        if self.jwks and self.introspection:
            raise ValueError("At present only one of jwks or introspection should be configured")
        return self
class OAuth2TokenAuthProvider(AuthProvider):
    """
    JWT token authentication provider that validates a JWT token and extracts access attributes.

    This should be the standard authentication provider for most use cases.
    Depending on the configuration, a token is either verified locally against
    keys fetched from a JWKS URI or checked remotely through token introspection.
    """

    def __init__(self, config: OAuth2TokenAuthProviderConfig):
        self.config = config
        # time.time() of the last successful JWKS fetch; 0.0 makes the first call refresh.
        self._jwks_at: float = 0.0
        # Cached keys indexed by key ID ("kid"); each value is the full JWK object.
        self._jwks: dict[str, dict] = {}
        # Held while the key cache is checked and refreshed.
        self._jwks_lock = Lock()

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate ``token`` with whichever mechanism is configured.

        Raises:
            ValueError: if the token is rejected or no mechanism is configured.
        """
        if self.config.jwks:
            return await self.validate_jwt_token(token, scope)
        if self.config.introspection:
            return await self.introspect_token(token, scope)
        raise ValueError("One of jwks or introspection must be configured")

    async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
        """Verify a JWT against the cached JWKS keys and build a ``User`` from its claims.

        Raises:
            ValueError: if the token cannot be decoded or verified (unknown
                key ID, bad signature, wrong audience/issuer, ...).
        """
        await self._refresh_jwks()
        try:
            header = jwt.get_unverified_header(token)
            kid = header["kid"]
            if kid not in self._jwks:
                raise ValueError(f"Unknown key ID: {kid}")
            key_data = self._jwks[kid]
            # NOTE(review): the algorithm comes from the still-unverified token
            # header; consider pinning the accepted algorithms in configuration.
            algorithm = header.get("alg", "RS256")
            claims = jwt.decode(
                token,
                key_data,
                algorithms=[algorithm],
                audience=self.config.audience,
                issuer=self.config.issuer,
            )
        except Exception as exc:
            # Do not echo the bearer token: exception messages may be logged or
            # returned to the client.
            raise ValueError("Invalid JWT token") from exc
        # There are other standard claims, the most relevant of which is `scope`.
        # We should incorporate these into the access attributes.
        principal = claims["sub"]
        access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
        return User(
            principal=principal,
            attributes=access_attributes,
        )

    async def introspect_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token using token introspection as defined by RFC 7662.

        Raises:
            ValueError: if introspection is not configured, the endpoint does
                not answer 200, the token is not active, the response names no
                principal, or any other error occurs while calling the endpoint.
            httpx.TimeoutException: if the introspection request times out.
        """
        introspection = self.config.introspection
        if introspection is None:
            raise ValueError("Introspection is not configured")
        form = {
            "token": token,
        }
        if introspection.send_secret_in_body:
            form["client_id"] = introspection.client_id
            form["client_secret"] = introspection.client_secret
            auth = None
        else:
            auth = (introspection.client_id, introspection.client_secret)
        # Same precedence as the JWKS fetch: a configured CA file wins,
        # otherwise `verify_tls` decides. Passing None here would leave the
        # outcome to httpx's version-specific handling of a missing value.
        verify: ssl.SSLContext | bool = self.config.verify_tls
        if self.config.tls_cafile:
            verify = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
        try:
            async with httpx.AsyncClient(verify=verify) as client:
                response = await client.post(
                    introspection.url,
                    data=form,
                    auth=auth,
                    timeout=10.0,  # Add a reasonable timeout
                )
                if response.status_code != 200:
                    logger.warning(f"Token introspection failed with status code: {response.status_code}")
                    raise ValueError(f"Token introspection failed: {response.status_code}")
                fields = response.json()
                if not fields.get("active"):
                    raise ValueError("Token not active")
                # Both fields are optional in an introspection response, so look
                # them up without raising and fall back from one to the other.
                principal = fields.get("sub") or fields.get("username")
                if not principal:
                    raise ValueError("Token introspection response has neither sub nor username")
                access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
                return User(
                    principal=principal,
                    attributes=access_attributes,
                )
        except httpx.TimeoutException:
            logger.exception("Token introspection request timed out")
            raise
        except ValueError:
            # Re-raise ValueError exceptions to preserve their message
            raise
        except Exception as e:
            logger.exception("Error during token introspection")
            raise ValueError("Token introspection error") from e

    async def close(self):
        """Nothing to release: HTTP clients are opened and closed per request."""

    async def _refresh_jwks(self) -> None:
        """
        Refresh the JWKS cache.

        This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
        If the cache is expired, we refresh the JWKS from the JWKS URI.

        Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
            * It doesn't have user authentication flows
            * It doesn't have refresh tokens

        Raises:
            ValueError: if JWKS is not configured.
        """
        async with self._jwks_lock:
            if self.config.jwks is None:
                raise ValueError("JWKS is not configured")
            if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
                verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
                async with httpx.AsyncClient(verify=verify) as client:
                    res = await client.get(self.config.jwks.uri, timeout=5)
                    res.raise_for_status()
                    jwks_data = res.json()["keys"]
                    updated = {}
                    for k in jwks_data:
                        kid = k.get("kid")
                        if kid is None:
                            # Tokens are matched to keys by "kid", so a key
                            # without one can never be selected; skip it
                            # instead of failing the whole refresh.
                            continue
                        # Store the entire key object as it may be needed for different algorithms
                        updated[kid] = k
                    self._jwks = updated
                    self._jwks_at = time.time()
class CustomAuthProviderConfig(BaseModel):
    # URL of the external authentication service that requests are POSTed to.
    endpoint: str
class CustomAuthProvider(AuthProvider):
    """Custom authentication provider that uses an external endpoint."""

    def __init__(self, config: CustomAuthProviderConfig):
        self.config = config
        # Nothing in this class assigns a client here (each validation opens
        # its own); close() releases one if it is ever set.
        self._client = None

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token using the custom authentication endpoint.

        The token is POSTed to the configured endpoint together with the
        request path, headers (minus ``authorization``) and query parameters
        taken from ``scope``.

        Raises:
            ValueError: if the endpoint does not answer 200, its response
                cannot be parsed, or any other error occurs during the call.
            httpx.TimeoutException: if the request times out.
        """
        if scope is None:
            scope = {}
        headers = dict(scope.get("headers", []))
        path = scope.get("path", "")
        request_headers = {k.decode(): v.decode() for k, v in headers.items()}
        # Remove sensitive headers
        if "authorization" in request_headers:
            del request_headers["authorization"]
        query_string = scope.get("query_string", b"").decode()
        params = parse_qs(query_string)
        # Build the auth request model
        auth_request = AuthRequest(
            api_key=token,
            request=AuthRequestContext(
                path=path,
                headers=request_headers,
                params=params,
            ),
        )
        # Validate with authentication endpoint
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.config.endpoint,
                    json=auth_request.model_dump(),
                    timeout=10.0,  # Add a reasonable timeout
                )
                if response.status_code != 200:
                    logger.warning(f"Authentication failed with status code: {response.status_code}")
                    raise ValueError(f"Authentication failed: {response.status_code}")
                # Parse and validate the auth response
                try:
                    response_data = response.json()
                    auth_response = AuthResponse(**response_data)
                    # Keyword arguments, matching how User is built elsewhere in this file.
                    return User(
                        principal=auth_response.principal,
                        attributes=auth_response.attributes,
                    )
                except Exception as e:
                    logger.exception("Error parsing authentication response")
                    raise ValueError("Invalid authentication response format") from e
        except httpx.TimeoutException:
            logger.exception("Authentication request timed out")
            raise
        except ValueError:
            # Re-raise ValueError exceptions to preserve their message
            raise
        except Exception as e:
            logger.exception("Error during authentication")
            raise ValueError("Authentication service error") from e

    async def close(self):
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
    """Build the auth provider selected by ``config.provider_type``.

    The provider-specific settings in ``config.config`` are validated against
    the matching config model before the provider is constructed.

    Raises:
        ValueError: if the provider type is not one of the supported types.
    """
    kind = config.provider_type.lower()

    if kind == "custom":
        custom_settings = CustomAuthProviderConfig.model_validate(config.config)
        return CustomAuthProvider(custom_settings)

    if kind == "oauth2_token":
        oauth2_settings = OAuth2TokenAuthProviderConfig.model_validate(config.config)
        return OAuth2TokenAuthProvider(oauth2_settings)

    # Unknown type: list every supported value in the error.
    known_types = ", ".join(member.value for member in AuthProviderType)
    raise ValueError(f"Unsupported auth provider type: {kind}. Supported types are: {known_types}")