diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index c9677b3b6..452c3d95f 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -354,6 +354,47 @@ You can easily validate a request by running: curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers ``` +#### Kubernetes Authentication Provider + +The server can be configured to use the Kubernetes SelfSubjectReview API to validate tokens directly against the Kubernetes API server: + +```yaml +server: + auth: + provider_config: + type: "kubernetes" + api_server_url: "https://kubernetes.default.svc" + claims_mapping: + username: "roles" + groups: "roles" + uid: "uid_attr" + verify_tls: true + tls_cafile: "/path/to/ca.crt" +``` + +Configuration options: +- `api_server_url`: The Kubernetes API server URL (default: https://kubernetes.default.svc; e.g., https://kubernetes.default.svc:6443) +- `verify_tls`: Whether to verify TLS certificates (default: true) +- `tls_cafile`: Optional path to a CA certificate file for TLS verification (only used when `verify_tls` is true) +- `claims_mapping`: Mapping of Kubernetes user claims to access attributes + +The provider validates tokens by sending a SelfSubjectReview request to the Kubernetes API server at `/apis/authentication.k8s.io/v1/selfsubjectreviews`. 
class KubernetesAuthProviderConfig(BaseModel):
    """Configuration for the Kubernetes authentication provider.

    Holds the API server location, the TLS verification settings and the
    mapping from Kubernetes user claims to access attributes.
    """

    type: Literal[AuthProviderType.KUBERNETES] = AuthProviderType.KUBERNETES
    api_server_url: str = Field(
        default="https://kubernetes.default.svc",
        description="Kubernetes API server URL (e.g., https://api.cluster.domain:6443)",
    )
    verify_tls: bool = Field(default=True, description="Whether to verify TLS certificates")
    tls_cafile: Path | None = Field(default=None, description="Path to CA certificate file for TLS verification")
    claims_mapping: dict[str, str] = Field(
        default_factory=lambda: {
            "username": "roles",
            "groups": "roles",
        },
        description="Mapping of Kubernetes user claims to access attributes",
    )

    @field_validator("api_server_url")
    @classmethod
    def validate_api_server_url(cls, v: str) -> str:
        """Require an http(s) URL that names an actual host.

        Raises:
            ValueError: if the scheme or host is missing, or the scheme is not
                ``http``/``https``.
        """
        parsed = urlparse(v)
        # Check `hostname`, not `netloc`: "https://:6443" or "https://user@" have a
        # non-empty netloc but no host, and would only fail later at request time.
        if not parsed.scheme or not parsed.hostname:
            raise ValueError(f"api_server_url must be a valid URL with scheme and host: {v}")
        if parsed.scheme not in ("http", "https"):
            raise ValueError(f"api_server_url scheme must be http or https: {v}")
        return v

    @field_validator("claims_mapping")
    @classmethod
    def validate_claims_mapping(cls, v: dict[str, str]) -> dict[str, str]:
        """Reject mappings whose target attribute name is empty.

        Raises:
            ValueError: naming the first claim key that maps to an empty value.
        """
        for key, value in v.items():
            if not value:
                raise ValueError(f"claims_mapping value cannot be empty: {key}")
        return v
class KubernetesAuthProvider(AuthProvider):
    """
    Kubernetes authentication provider that validates tokens using the Kubernetes SelfSubjectReview API.

    The bearer token is posted to the API server's
    /apis/authentication.k8s.io/v1/selfsubjectreviews endpoint; the ``userInfo``
    returned in the review status supplies the principal and, through the
    configured ``claims_mapping``, the user's access attributes.
    """

    def __init__(self, config: KubernetesAuthProviderConfig):
        self.config = config

    def _httpx_verify_value(self) -> bool | str:
        """
        Build the value for httpx's `verify` parameter.
        - False disables verification.
        - Path string points to a CA bundle.
        - True uses system defaults.
        """
        if not self.config.verify_tls:
            return False
        if self.config.tls_cafile:
            return self.config.tls_cafile.as_posix()
        return True

    def _review_api_url(self) -> str:
        """Return the SelfSubjectReview endpoint URL under the configured API server URL."""
        parsed = urlparse(self.config.api_server_url)
        # Joining an absolute "/apis/..." path would discard any path prefix in
        # api_server_url (e.g. an API server reached through a path-prefixed proxy).
        # Normalise the base to end with "/" and join a relative path instead so
        # the prefix is kept; URLs without a path produce the same result as before.
        base_url = f"{parsed.scheme}://{parsed.netloc}{parsed.path.rstrip('/')}/"
        return urljoin(base_url, "apis/authentication.k8s.io/v1/selfsubjectreviews")

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token using Kubernetes SelfSubjectReview API endpoint.

        Args:
            token: Bearer token to present to the Kubernetes API server.
            scope: Accepted for interface compatibility; not used by this provider.

        Returns:
            A ``User`` whose principal is the reviewed username and whose
            attributes are derived from ``userInfo`` via ``claims_mapping``.

        Raises:
            TokenValidationError: the API server answered 401, or any status
                other than 201 Created.
            ValueError: the request timed out, or any other failure occurred
                (transport error, malformed or incomplete review response).
        """
        review_api_url = self._review_api_url()

        # Create SelfSubjectReview request body
        review_request = {"apiVersion": "authentication.k8s.io/v1", "kind": "SelfSubjectReview"}
        verify = self._httpx_verify_value()

        try:
            async with httpx.AsyncClient(verify=verify, timeout=10.0) as client:
                response = await client.post(
                    review_api_url,
                    json=review_request,
                    headers={
                        "Authorization": f"Bearer {token}",
                        "Content-Type": "application/json",
                    },
                )

                if response.status_code == httpx.codes.UNAUTHORIZED:
                    raise TokenValidationError("Invalid token")
                if response.status_code != httpx.codes.CREATED:
                    logger.warning(f"Kubernetes SelfSubjectReview API failed with status code: {response.status_code}")
                    raise TokenValidationError(f"Token validation failed: {response.status_code}")

                review_response = response.json()
                # Extract user information from SelfSubjectReview response
                status = review_response.get("status", {})
                if not status:
                    raise ValueError("No status found in SelfSubjectReview response")

                user_info = status.get("userInfo", {})
                if not user_info:
                    raise ValueError("No userInfo found in SelfSubjectReview response")

                username = user_info.get("username")
                if not username:
                    raise ValueError("No username found in SelfSubjectReview response")

                # Build user attributes from Kubernetes user info
                user_attributes = get_attributes_from_claims(user_info, self.config.claims_mapping)

                return User(
                    principal=username,
                    attributes=user_attributes,
                )

        except httpx.TimeoutException:
            logger.warning("Kubernetes SelfSubjectReview API request timed out")
            raise ValueError("Token validation timeout") from None
        except TokenValidationError:
            # Deliberate rejection raised above: let it propagate with its own type
            # and message instead of being re-wrapped by the catch-all below. It
            # subclasses ValueError, so callers catching ValueError are unaffected.
            raise
        except Exception as e:
            logger.warning(f"Error during token validation: {str(e)}")
            raise ValueError(f"Token validation error: {str(e)}") from e

    async def close(self):
        """No-op: an HTTP client is created per validation call, so nothing is held open."""
async def mock_kubernetes_selfsubjectreview_success(*args, **kwargs):
    """Stand-in for ``httpx.AsyncClient.post`` returning a 201 SelfSubjectReview for user "alice"."""
    user_info = {
        "username": "alice",
        "uid": "alice-uid-123",
        "groups": ["system:authenticated", "developers", "admins"],
        "extra": {"scopes.authorization.openshift.io": ["user:full"]},
    }
    review_body = {
        "apiVersion": "authentication.k8s.io/v1",
        "kind": "SelfSubjectReview",
        "metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"},
        "status": {"userInfo": user_info},
    }
    return MockResponse(201, review_body)
def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mock_kubernetes_api_server):
    """Check the URL, headers and body of the SelfSubjectReview request sent for a bearer token."""
    with patch("httpx.AsyncClient.post") as mock_post:
        # The provider treats anything other than 201 Created as a failed review,
        # so the mocked reply must use 201 for the request to authenticate.
        mock_response = MockResponse(
            201,
            {
                "apiVersion": "authentication.k8s.io/v1",
                "kind": "SelfSubjectReview",
                "metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"},
                "status": {
                    "userInfo": {
                        "username": "test-user",
                        "uid": "test-uid",
                        "groups": ["test-group"],
                    }
                },
            },
        )
        mock_post.return_value = mock_response

        response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
        # Guard against inspecting the payload of a request that was actually rejected
        assert response.status_code == 200

        # Verify the request was made with correct parameters
        mock_post.assert_called_once()
        call_args = mock_post.call_args

        # Check URL (passed as positional argument)
        assert call_args[0][0] == f"{mock_kubernetes_api_server}/apis/authentication.k8s.io/v1/selfsubjectreviews"

        # Check headers (passed as keyword argument)
        headers = call_args[1]["headers"]
        assert headers["Authorization"] == f"Bearer {valid_token}"
        assert headers["Content-Type"] == "application/json"

        # Check request body (passed as keyword argument)
        request_body = call_args[1]["json"]
        assert request_body["apiVersion"] == "authentication.k8s.io/v1"
        assert request_body["kind"] == "SelfSubjectReview"