feat: introduce JWKSAuthProvider

This commit is contained in:
Ashwin Bharambe 2025-05-15 17:02:23 -07:00
parent 1341916caf
commit fd86961c88
2 changed files with 125 additions and 42 deletions

View file

@ -93,7 +93,7 @@ class AuthenticationMiddleware:
# Validate token and get access attributes # Validate token and get access attributes
try: try:
access_attributes = await self.auth_provider.validate_token(token, scope) validation_result = await self.auth_provider.validate_token(token, scope)
except httpx.TimeoutException: except httpx.TimeoutException:
logger.exception("Authentication request timed out") logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout") return await self._send_auth_error(send, "Authentication service timeout")
@ -105,17 +105,20 @@ class AuthenticationMiddleware:
return await self._send_auth_error(send, "Authentication service error") return await self._send_auth_error(send, "Authentication service error")
# Store attributes in request scope for access control # Store attributes in request scope for access control
if access_attributes: if validation_result.access_attributes:
user_attributes = access_attributes.model_dump(exclude_none=True) user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
else: else:
logger.warning("No access attributes, setting namespace to token by default") logger.warning("No access attributes, setting namespace to token by default")
user_attributes = { user_attributes = {
"namespaces": [token], "roles": [token],
} }
# Store attributes in request scope # Store attributes in request scope
scope["user_attributes"] = user_attributes scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes") scope["principal"] = validation_result.principal
logger.debug(
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
)
return await self.app(scope, receive, send) return await self.app(scope, receive, send)

View file

@ -5,11 +5,13 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from urllib.parse import parse_qs from urllib.parse import parse_qs
import httpx import httpx
from jose import jwt
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.distribution.datatypes import AccessAttributes
@ -18,9 +20,11 @@ from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth") logger = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel): class TokenValidationResult(BaseModel):
"""The format of the authentication response from the auth endpoint.""" principal: str | None = Field(
default=None,
description="The principal (username or persistent identifier) of the authenticated user",
)
access_attributes: AccessAttributes | None = Field( access_attributes: AccessAttributes | None = Field(
default=None, default=None,
description=""" description="""
@ -43,6 +47,10 @@ class AuthResponse(BaseModel):
""", """,
) )
class AuthResponse(TokenValidationResult):
"""The format of the authentication response from the auth endpoint."""
message: str | None = Field( message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result." default=None, description="Optional message providing additional context about the authentication result."
) )
@ -82,7 +90,7 @@ class AuthProvider(ABC):
"""Abstract base class for authentication providers.""" """Abstract base class for authentication providers."""
@abstractmethod @abstractmethod
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token and return access attributes.""" """Validate a token and return access attributes."""
pass pass
@ -92,12 +100,16 @@ class AuthProvider(ABC):
pass pass
class KubernetesAuthProviderConfig(BaseModel):
api_server_url: str
ca_cert_path: str | None = None
class KubernetesAuthProvider(AuthProvider): class KubernetesAuthProvider(AuthProvider):
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server.""" """Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
def __init__(self, config: dict[str, str]): def __init__(self, config: KubernetesAuthProviderConfig):
self.api_server_url = config["api_server_url"] self.config = config
self.ca_cert_path = config.get("ca_cert_path")
self._client = None self._client = None
async def _get_client(self): async def _get_client(self):
@ -110,16 +122,16 @@ class KubernetesAuthProvider(AuthProvider):
# Configure the client # Configure the client
configuration = client.Configuration() configuration = client.Configuration()
configuration.host = self.api_server_url configuration.host = self.config.api_server_url
if self.ca_cert_path: if self.config.ca_cert_path:
configuration.ssl_ca_cert = self.ca_cert_path configuration.ssl_ca_cert = self.config.ca_cert_path
configuration.verify_ssl = bool(self.ca_cert_path) configuration.verify_ssl = bool(self.config.ca_cert_path)
# Create API client # Create API client
self._client = ApiClient(configuration) self._client = ApiClient(configuration)
return self._client return self._client
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a Kubernetes token and return access attributes.""" """Validate a Kubernetes token and return access attributes."""
try: try:
client = await self._get_client() client = await self._get_client()
@ -146,9 +158,12 @@ class KubernetesAuthProvider(AuthProvider):
username = payload.get("sub", "") username = payload.get("sub", "")
groups = payload.get("groups", []) groups = payload.get("groups", [])
return AccessAttributes( return TokenValidationResult(
roles=[username], # Use username as a role principal=username,
teams=groups, # Use Kubernetes groups as teams access_attributes=AccessAttributes(
roles=[username], # Use username as a role
teams=groups, # Use Kubernetes groups as teams
),
) )
except Exception as e: except Exception as e:
@ -162,18 +177,93 @@ class KubernetesAuthProvider(AuthProvider):
self._client = None self._client = None
JWT_AUDIENCE = "llama-stack"
class JWKSAuthProviderConfig(BaseModel):
"""Configuration for JWT token authentication provider."""
# The JWKS URI for collecting public keys
jwks_uri: str
algorithm: str = "RS256"
cache_ttl: int = 3600
class JWKSAuthProvider(AuthProvider):
"""JWT token authentication provider that validates tokens against the JWT token."""
def __init__(self, config: JWKSAuthProviderConfig):
self.config = config
self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {}
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the JWT token."""
await self._refresh_jwks()
try:
kid = jwt.get_unverified_header(token)["kid"]
key = self._jwks[kid] # raises if unknown
claims = jwt.decode(
token,
key,
algorithms=[self.config.algorithm],
audience=JWT_AUDIENCE,
options={"verify_exp": True},
)
except Exception as exc:
raise ValueError(f"invalid token: {token}") from exc
principal = f"{claims['iss']}:{claims['sub']}"
teams = claims.get("teams", [])
if not teams:
if team := claims.get("team", claims.get("team_id")):
teams = [team]
projects = claims.get("projects", [])
if not projects:
if project := claims.get("project", claims.get("project_id")):
projects = [project]
namespaces = claims.get("namespaces", [])
if not namespaces:
if namespace := claims.get("namespace", claims.get("tenant")):
namespaces = [namespace]
return TokenValidationResult(
principal=principal,
access_attributes=AccessAttributes(
roles=claims.get("groups", claims.get("roles", [])), # Okta / Auth0
teams=teams,
projects=projects,
namespaces=namespaces,
),
)
async def close(self):
"""Close the HTTP client."""
async def _refresh_jwks(self) -> None:
if time.time() - self._jwks_at > self.config.cache_ttl:
with httpx.AsyncClient() as client:
res = await client.get(self.config.jwks_uri, timeout=5)
res.raise_for_status()
self._jwks = {k["kid"]: k for k in res.json()["keys"]}
self._jwks_at = time.time()
class CustomAuthProviderConfig(BaseModel):
endpoint: str
class CustomAuthProvider(AuthProvider): class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint.""" """Custom authentication provider that uses an external endpoint."""
def __init__(self, config: dict[str, str]): def __init__(self, config: CustomAuthProviderConfig):
self.endpoint = config["endpoint"] self.config = config
self._client = None self._client = None
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the custom authentication endpoint.""" """Validate a token using the custom authentication endpoint."""
if not self.endpoint:
raise ValueError("Authentication endpoint not configured")
if scope is None: if scope is None:
scope = {} scope = {}
@ -202,7 +292,7 @@ class CustomAuthProvider(AuthProvider):
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
self.endpoint, self.config.endpoint,
json=auth_request.model_dump(), json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout timeout=10.0, # Add a reasonable timeout
) )
@ -214,19 +304,7 @@ class CustomAuthProvider(AuthProvider):
try: try:
response_data = response.json() response_data = response.json()
auth_response = AuthResponse(**response_data) auth_response = AuthResponse(**response_data)
return auth_response
# Store attributes in request scope for access control
if auth_response.access_attributes:
return auth_response.access_attributes
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [token],
}
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
return auth_response.access_attributes
except Exception as e: except Exception as e:
logger.exception("Error parsing authentication response") logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e raise ValueError("Invalid authentication response format") from e
@ -253,9 +331,11 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
provider_type = config.provider_type.lower() provider_type = config.provider_type.lower()
if provider_type == "kubernetes": if provider_type == "kubernetes":
return KubernetesAuthProvider(config.config) return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
elif provider_type == "custom": elif provider_type == "custom":
return CustomAuthProvider(config.config) return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
elif provider_type == "jwks":
return JWKSAuthProvider(JWKSAuthProviderConfig.model_validate(config.config))
else: else:
supported_providers = ", ".join([t.value for t in AuthProviderType]) supported_providers = ", ".join([t.value for t in AuthProviderType])
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}") raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")