# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import ssl
import time
from abc import ABC, abstractmethod
from asyncio import Lock
from pathlib import Path
from typing import Any, Self
from urllib.parse import parse_qs

import httpx
from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator

from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="auth")


class AuthResponse(BaseModel):
    """The format of the authentication response from the auth endpoint."""

    principal: str
    # further attributes that may be used for access control decisions
    attributes: dict[str, list[str]] | None = None
    message: str | None = Field(
        default=None, description="Optional message providing additional context about the authentication result."
    )


class AuthRequestContext(BaseModel):
    """Request details forwarded to a custom authentication endpoint."""

    path: str = Field(description="The path of the request being authenticated")

    headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")

    params: dict[str, list[str]] = Field(
        description="Query parameters from the original request, parsed as dictionary of lists"
    )


class AuthRequest(BaseModel):
    """JSON body POSTed to a custom authentication endpoint."""

    api_key: str = Field(description="The API key extracted from the Authorization header")

    request: AuthRequestContext = Field(description="Context information about the request being authenticated")


class AuthProvider(ABC):
    """Abstract base class for authentication providers."""

    @abstractmethod
    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token and return access attributes."""
        pass

    @abstractmethod
    async def close(self):
        """Clean up any resources."""
        pass


def get_attributes_from_claims(claims: dict[str, Any], mapping: dict[str, str]) -> dict[str, list[str]]:
    """Map token claims onto access-control attributes.

    For every ``claim_key -> attribute_key`` entry in ``mapping`` whose claim is present and not null:
    list claims contribute their items, string claims are split on whitespace, and any other scalar
    is converted with ``str()``. Several claims may feed the same attribute; their values are
    concatenated in mapping order.

    Args:
        claims: Decoded JWT claims or a token introspection response.
        mapping: Claim name -> attribute name.

    Returns:
        Attribute name -> list of values. Attributes with no contributing claim are absent.
        The returned lists are always new objects; ``claims`` is never mutated.
    """
    attributes: dict[str, list[str]] = {}
    for claim_key, attribute_key in mapping.items():
        claim = claims.get(claim_key)
        if claim is None:
            continue
        if isinstance(claim, list):
            values = claim
        elif isinstance(claim, str):
            values = claim.split()
        else:
            # Non-string scalars (e.g. numeric IDs) previously crashed on ``.split()``.
            values = [str(claim)]
        # setdefault + extend always copies into a list owned by ``attributes``, so a list claim
        # is never aliased and later extended in place.
        attributes.setdefault(attribute_key, []).extend(values)
    return attributes


class OAuth2JWKSConfig(BaseModel):
    # The JWKS URI for collecting public keys
    uri: str
    token: str | None = Field(default=None, description="token to authorise access to jwks")
    key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")


class OAuth2IntrospectionConfig(BaseModel):
    url: str
    client_id: str
    client_secret: str
    # When True the client credentials go in the form body; otherwise they are sent as HTTP basic auth.
    send_secret_in_body: bool = False


class OAuth2TokenAuthProviderConfig(BaseModel):
    audience: str = "llama-stack"
    verify_tls: bool = True
    tls_cafile: Path | None = None
    issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
    claims_mapping: dict[str, str] = Field(
        default_factory=lambda: {
            "sub": "roles",
            "username": "roles",
            "groups": "teams",
            "team": "teams",
            "project": "projects",
            "tenant": "namespaces",
            "namespace": "namespaces",
        },
    )
    jwks: OAuth2JWKSConfig | None
    introspection: OAuth2IntrospectionConfig | None = None

    # ``@field_validator`` must be the outermost decorator: pydantic only registers validators whose
    # class attribute is its own descriptor, so ``@classmethod`` on top silently disabled this check.
    @field_validator("claims_mapping")
    @classmethod
    def validate_claims_mapping(cls, v):
        """Reject mapping entries whose target attribute name is empty."""
        for key, value in v.items():
            if not value:
                raise ValueError(f"claims_mapping value cannot be empty: {key}")
        return v

    @model_validator(mode="after")
    def validate_mode(self) -> Self:
        """Require exactly one of ``jwks`` or ``introspection``."""
        if not self.jwks and not self.introspection:
            raise ValueError("One of jwks or introspection must be configured")
        if self.jwks and self.introspection:
            raise ValueError("At present only one of jwks or introspection should be configured")
        return self


class OAuth2TokenAuthProvider(AuthProvider):
    """
    JWT token authentication provider that validates a JWT token and extracts access attributes.

    This should be the standard authentication provider for most use cases.
    """

    def __init__(self, config: OAuth2TokenAuthProviderConfig):
        self.config = config
        # JWKS cache: key ID -> JWK object, plus the time.time() of the last successful refresh.
        self._jwks_at: float = 0.0
        self._jwks: dict[str, dict[str, Any]] = {}
        self._jwks_lock = Lock()

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate ``token`` with whichever mechanism (JWKS or introspection) is configured."""
        if self.config.jwks:
            return await self.validate_jwt_token(token, scope)

        if self.config.introspection:
            return await self.introspect_token(token, scope)

        raise ValueError("One of jwks or introspection must be configured")

    def _tls_verify(self) -> ssl.SSLContext | bool:
        """Return the ``verify`` argument for outbound ``httpx.AsyncClient`` calls.

        A configured ``tls_cafile`` takes precedence and becomes an SSL context trusting that bundle.
        Otherwise the ``verify_tls`` flag is passed through (never ``None``, which some httpx versions
        treat as "do not verify").
        """
        if self.config.tls_cafile:
            return ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
        return self.config.verify_tls

    async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token using the JWT token.

        Raises:
            ValueError: If the signature, key ID, audience or issuer check fails, or ``sub`` is missing.
        """
        await self._refresh_jwks()

        try:
            header = jwt.get_unverified_header(token)
            kid = header["kid"]
            if kid not in self._jwks:
                raise ValueError(f"Unknown key ID: {kid}")
            key_data = self._jwks[kid]
            # The header is not yet verified, so prefer the algorithm the JWKS publishes for this key
            # (RFC 8725 section 3.1) and only fall back to the header when the key declares none.
            algorithm = key_data.get("alg") or header.get("alg", "RS256")
            claims = jwt.decode(
                token,
                key_data,
                algorithms=[algorithm],
                audience=self.config.audience,
                issuer=self.config.issuer,
            )
        except Exception as exc:
            # Do not echo the bearer token: this message may be logged or returned to the client.
            raise ValueError("Invalid JWT token") from exc

        # There are other standard claims, the most relevant of which is `scope`.
        # We should incorporate these into the access attributes.
        principal = claims.get("sub")
        if not principal:
            raise ValueError("JWT token is missing the 'sub' claim")
        access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
        return User(
            principal=principal,
            attributes=access_attributes,
        )

    async def introspect_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token using token introspection as defined by RFC 7662.

        Raises:
            ValueError: If introspection is not configured, the endpoint rejects the request, the token
                is inactive, the response has no principal, or any other error occurs.
            httpx.TimeoutException: If the introspection endpoint does not answer in time.
        """
        form = {
            "token": token,
        }
        if self.config.introspection is None:
            raise ValueError("Introspection is not configured")

        if self.config.introspection.send_secret_in_body:
            form["client_id"] = self.config.introspection.client_id
            form["client_secret"] = self.config.introspection.client_secret
            auth = None
        else:
            auth = (self.config.introspection.client_id, self.config.introspection.client_secret)

        # Built outside the try so a bad CA file surfaces as-is rather than as a generic error.
        verify = self._tls_verify()
        try:
            async with httpx.AsyncClient(verify=verify) as client:
                response = await client.post(
                    self.config.introspection.url,
                    data=form,
                    auth=auth,
                    timeout=10.0,  # Add a reasonable timeout
                )
                if response.status_code != 200:
                    logger.warning(f"Token introspection failed with status code: {response.status_code}")
                    raise ValueError(f"Token introspection failed: {response.status_code}")

                fields = response.json()
                if not fields.get("active"):
                    raise ValueError("Token not active")
                # RFC 7662 makes both "sub" and "username" optional; accept whichever is present.
                principal = fields.get("sub") or fields.get("username")
                if not principal:
                    raise ValueError("Token introspection response has no 'sub' or 'username'")
                access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
                return User(
                    principal=principal,
                    attributes=access_attributes,
                )
        except httpx.TimeoutException:
            logger.exception("Token introspection request timed out")
            raise
        except ValueError:
            # Re-raise ValueError exceptions to preserve their message
            raise
        except Exception as e:
            logger.exception("Error during token introspection")
            raise ValueError("Token introspection error") from e

    async def close(self):
        pass

    async def _refresh_jwks(self) -> None:
        """
        Refresh the JWKS cache.

        This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
        If the cache is expired, we refresh the JWKS from the JWKS URI.

        Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
            * It doesn't have user authentication flows
            * It doesn't have refresh tokens
        """
        async with self._jwks_lock:
            if self.config.jwks is None:
                raise ValueError("JWKS is not configured")
            if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
                headers = {}
                if self.config.jwks.token:
                    headers["Authorization"] = f"Bearer {self.config.jwks.token}"
                async with httpx.AsyncClient(verify=self._tls_verify()) as client:
                    res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
                    res.raise_for_status()
                    jwks_data = res.json()["keys"]
                    updated = {}
                    for k in jwks_data:
                        kid = k.get("kid")
                        if kid is None:
                            # "kid" is optional per RFC 7517, but tokens are matched by key ID in
                            # validate_jwt_token, so such a key is unusable; skip it instead of failing.
                            continue
                        # Store the entire key object as it may be needed for different algorithms
                        updated[kid] = k
                    self._jwks = updated
                    self._jwks_at = time.time()


class CustomAuthProviderConfig(BaseModel):
    endpoint: str


class CustomAuthProvider(AuthProvider):
    """Custom authentication provider that uses an external endpoint."""

    def __init__(self, config: CustomAuthProviderConfig):
        self.config = config
        self._client = None

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token using the custom authentication endpoint.

        ``scope`` is read ASGI-style: ``headers`` as (bytes, bytes) pairs, ``path`` as str and
        ``query_string`` as bytes. The token plus the request path, headers (minus ``authorization``)
        and parsed query parameters are POSTed as JSON to the configured endpoint, which must answer
        200 with an ``AuthResponse``-shaped body.

        Raises:
            ValueError: On a non-200 status, an unparsable response, or any other request failure.
            httpx.TimeoutException: If the endpoint does not answer in time.
        """
        if scope is None:
            scope = {}

        headers = dict(scope.get("headers", []))
        path = scope.get("path", "")
        request_headers = {k.decode(): v.decode() for k, v in headers.items()}

        # Remove sensitive headers
        if "authorization" in request_headers:
            del request_headers["authorization"]

        query_string = scope.get("query_string", b"").decode()
        params = parse_qs(query_string)

        # Build the auth request model
        auth_request = AuthRequest(
            api_key=token,
            request=AuthRequestContext(
                path=path,
                headers=request_headers,
                params=params,
            ),
        )

        # Validate with authentication endpoint
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.config.endpoint,
                    json=auth_request.model_dump(),
                    timeout=10.0,  # Add a reasonable timeout
                )
                if response.status_code != 200:
                    logger.warning(f"Authentication failed with status code: {response.status_code}")
                    raise ValueError(f"Authentication failed: {response.status_code}")

                # Parse and validate the auth response
                try:
                    response_data = response.json()
                    auth_response = AuthResponse(**response_data)
                    return User(principal=auth_response.principal, attributes=auth_response.attributes)
                except Exception as e:
                    logger.exception("Error parsing authentication response")
                    raise ValueError("Invalid authentication response format") from e

        except httpx.TimeoutException:
            logger.exception("Authentication request timed out")
            raise
        except ValueError:
            # Re-raise ValueError exceptions to preserve their message
            raise
        except Exception as e:
            logger.exception("Error during authentication")
            raise ValueError("Authentication service error") from e

    async def close(self):
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None


def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
    """Factory function to create the appropriate auth provider."""
    provider_type = config.provider_type.lower()

    if provider_type == "custom":
        return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
    elif provider_type == "oauth2_token":
        return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
    else:
        supported_providers = ", ".join([t.value for t in AuthProviderType])
        raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")