From fd86961c883df8c7dc91efca40243a053b17367d Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 15 May 2025 17:02:23 -0700 Subject: [PATCH] feat: introduce JWKSAuthProvider --- llama_stack/distribution/server/auth.py | 13 +- .../distribution/server/auth_providers.py | 154 +++++++++++++----- 2 files changed, 125 insertions(+), 42 deletions(-) diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index 429232ece..83436c51f 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -93,7 +93,7 @@ class AuthenticationMiddleware: # Validate token and get access attributes try: - access_attributes = await self.auth_provider.validate_token(token, scope) + validation_result = await self.auth_provider.validate_token(token, scope) except httpx.TimeoutException: logger.exception("Authentication request timed out") return await self._send_auth_error(send, "Authentication service timeout") @@ -105,17 +105,20 @@ class AuthenticationMiddleware: return await self._send_auth_error(send, "Authentication service error") # Store attributes in request scope for access control - if access_attributes: - user_attributes = access_attributes.model_dump(exclude_none=True) + if validation_result.access_attributes: + user_attributes = validation_result.access_attributes.model_dump(exclude_none=True) else: logger.warning("No access attributes, setting namespace to token by default") user_attributes = { - "namespaces": [token], + "roles": [token], } # Store attributes in request scope scope["user_attributes"] = user_attributes - logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes") + scope["principal"] = validation_result.principal + logger.debug( + f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes" + ) return await self.app(scope, receive, send) diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 1b19f8923..fc78138ff 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -5,11 +5,13 @@ # the root directory of this source tree. import json +import time from abc import ABC, abstractmethod from enum import Enum from urllib.parse import parse_qs import httpx +from jose import jwt from pydantic import BaseModel, Field from llama_stack.distribution.datatypes import AccessAttributes @@ -18,9 +20,11 @@ from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") -class AuthResponse(BaseModel): - """The format of the authentication response from the auth endpoint.""" - +class TokenValidationResult(BaseModel): + principal: str | None = Field( + default=None, + description="The principal (username or persistent identifier) of the authenticated user", + ) access_attributes: AccessAttributes | None = Field( default=None, description=""" @@ -43,6 +47,10 @@ class AuthResponse(BaseModel): """, ) + +class AuthResponse(TokenValidationResult): + """The format of the authentication response from the auth endpoint.""" + message: str | None = Field( default=None, description="Optional message providing additional context about the authentication result." ) @@ -82,7 +90,7 @@ class AuthProvider(ABC): """Abstract base class for authentication providers.""" @abstractmethod - async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: + async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: """Validate a token and return access attributes.""" pass @@ -92,12 +100,16 @@ class AuthProvider(ABC): pass +class KubernetesAuthProviderConfig(BaseModel): + api_server_url: str + ca_cert_path: str | None = None + + class KubernetesAuthProvider(AuthProvider): """Kubernetes authentication provider that validates tokens against the Kubernetes API server.""" - def __init__(self, config: dict[str, str]): - self.api_server_url = config["api_server_url"] - self.ca_cert_path = config.get("ca_cert_path") + def __init__(self, config: KubernetesAuthProviderConfig): + self.config = config self._client = None async def _get_client(self): @@ -110,16 +122,16 @@ class KubernetesAuthProvider(AuthProvider): # Configure the client configuration = client.Configuration() - configuration.host = self.api_server_url - if self.ca_cert_path: - configuration.ssl_ca_cert = self.ca_cert_path - configuration.verify_ssl = bool(self.ca_cert_path) + configuration.host = self.config.api_server_url + if self.config.ca_cert_path: + configuration.ssl_ca_cert = self.config.ca_cert_path + configuration.verify_ssl = bool(self.config.ca_cert_path) # Create API client self._client = ApiClient(configuration) return self._client - async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: + async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: """Validate a Kubernetes token and return access attributes.""" try: client = await self._get_client() @@ -146,9 +158,12 @@ class KubernetesAuthProvider(AuthProvider): username = payload.get("sub", "") groups = payload.get("groups", []) - return AccessAttributes( - roles=[username], # Use username as a role - teams=groups, # Use Kubernetes groups as teams + return TokenValidationResult( + principal=username, + access_attributes=AccessAttributes( + roles=[username], # Use username as a role + teams=groups, # Use Kubernetes groups as teams + ), ) except Exception as e: @@ -162,18 +177,93 @@ class KubernetesAuthProvider(AuthProvider): self._client = None +JWT_AUDIENCE = "llama-stack" + + +class JWKSAuthProviderConfig(BaseModel): + """Configuration for JWT token authentication provider.""" + + # The JWKS URI for collecting public keys + jwks_uri: str + algorithm: str = "RS256" + cache_ttl: int = 3600 + + +class JWKSAuthProvider(AuthProvider): + """JWT token authentication provider that validates tokens against the JWT token.""" + + def __init__(self, config: JWKSAuthProviderConfig): + self.config = config + self._jwks_at: float = 0.0 + self._jwks: dict[str, str] = {} + + async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + """Validate a token using the JWT token.""" + await self._refresh_jwks() + + try: + kid = jwt.get_unverified_header(token)["kid"] + key = self._jwks[kid] # raises if unknown + claims = jwt.decode( + token, + key, + algorithms=[self.config.algorithm], + audience=JWT_AUDIENCE, + options={"verify_exp": True}, + ) + except Exception as exc: + raise ValueError(f"invalid token: {token}") from exc + + principal = f"{claims['iss']}:{claims['sub']}" + + teams = claims.get("teams", []) + if not teams: + if team := claims.get("team", claims.get("team_id")): + teams = [team] + projects = claims.get("projects", []) + if not projects: + if project := claims.get("project", claims.get("project_id")): + projects = [project] + namespaces = claims.get("namespaces", []) + if not namespaces: + if namespace := claims.get("namespace", claims.get("tenant")): + namespaces = [namespace] + + return TokenValidationResult( + principal=principal, + access_attributes=AccessAttributes( + roles=claims.get("groups", claims.get("roles", [])), # Okta / Auth0 + teams=teams, + projects=projects, + namespaces=namespaces, + ), + ) + + async def close(self): + """Close the HTTP client.""" + + async def _refresh_jwks(self) -> None: + if time.time() - self._jwks_at > self.config.cache_ttl: + with httpx.AsyncClient() as client: + res = await client.get(self.config.jwks_uri, timeout=5) + res.raise_for_status() + self._jwks = {k["kid"]: k for k in res.json()["keys"]} + self._jwks_at = time.time() + + +class CustomAuthProviderConfig(BaseModel): + endpoint: str + + class CustomAuthProvider(AuthProvider): """Custom authentication provider that uses an external endpoint.""" - def __init__(self, config: dict[str, str]): - self.endpoint = config["endpoint"] + def __init__(self, config: CustomAuthProviderConfig): + self.config = config self._client = None - async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: + async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: """Validate a token using the custom authentication endpoint.""" - if not self.endpoint: - raise ValueError("Authentication endpoint not configured") - if scope is None: scope = {} @@ -202,7 +292,7 @@ class CustomAuthProvider(AuthProvider): try: async with httpx.AsyncClient() as client: response = await client.post( - self.endpoint, + self.config.endpoint, json=auth_request.model_dump(), timeout=10.0, # Add a reasonable timeout ) @@ -214,19 +304,7 @@ class CustomAuthProvider(AuthProvider): try: response_data = response.json() auth_response = AuthResponse(**response_data) - - # Store attributes in request scope for access control - if auth_response.access_attributes: - return auth_response.access_attributes - else: - logger.warning("No access attributes, setting namespace to api_key by default") - user_attributes = { - "namespaces": [token], - } - - scope["user_attributes"] = user_attributes - logger.debug(f"Authentication successful: {len(user_attributes)} attributes") - return auth_response.access_attributes + return auth_response except Exception as e: logger.exception("Error parsing authentication response") raise ValueError("Invalid authentication response format") from e @@ -253,9 +331,11 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider: provider_type = config.provider_type.lower() if provider_type == "kubernetes": - return KubernetesAuthProvider(config.config) + return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config)) elif provider_type == "custom": - return CustomAuthProvider(config.config) + return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config)) + elif provider_type == "jwks": + return JWKSAuthProvider(JWKSAuthProviderConfig.model_validate(config.config)) else: supported_providers = ", ".join([t.value for t in AuthProviderType]) raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")