diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 9548780c6..60da9d5ae 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -342,6 +342,46 @@ server: ``` The provider fetches user information from GitHub and maps it to access attributes based on the `claims_mapping` configuration. +#### Kubernetes Authentication Provider + +The server can be configured to use the Kubernetes SelfSubjectReview API to validate tokens directly against the Kubernetes API server: + +```yaml +server: + auth: + provider_config: + type: "kubernetes" + api_server_url: https://kubernetes.default.svc + claims_mapping: + username: "roles" + groups: "roles" + uid: "uid_attr" + verify_tls: true + tls_cafile: "/path/to/ca.crt" +``` + +Configuration options: +- `api_server_url`: The Kubernetes API server URL (e.g., https://kubernetes.default.svc:6443) +- `verify_tls`: Whether to verify TLS certificates (default: true) +- `tls_cafile`: Path to a CA certificate file used for TLS verification +- `claims_mapping`: Mapping of Kubernetes user info fields (e.g., `username`, `groups`, `uid`) to access attribute names + +The provider validates tokens by sending a SelfSubjectReview request to the Kubernetes API server at `/apis/authentication.k8s.io/v1/selfsubjectreviews`. 
The provider extracts user information from the response: +- Username from the `userInfo.username` field +- Groups from the `userInfo.groups` field +- UID from the `userInfo.uid` field + +To obtain a token for testing: +```bash +kubectl create namespace llama-stack +kubectl create serviceaccount llama-stack-auth -n llama-stack +kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token +``` + +You can validate a request by running: +```bash +curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers +``` #### Custom Provider Validates tokens against a custom authentication endpoint: diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index ead1331f3..6d09d41da 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -181,6 +181,7 @@ class AuthProviderType(StrEnum): OAUTH2_TOKEN = "oauth2_token" GITHUB_TOKEN = "github_token" CUSTOM = "custom" + KUBERNETES = "kubernetes" class OAuth2TokenAuthConfig(BaseModel): @@ -251,8 +252,35 @@ class GitHubTokenAuthConfig(BaseModel): ) +class KubernetesAuthProviderConfig(BaseModel): + """Configuration for Kubernetes authentication provider.""" + + type: Literal[AuthProviderType.KUBERNETES] = AuthProviderType.KUBERNETES + api_server_url: str = Field( + default="https://kubernetes.default.svc", + description="Kubernetes API server URL (e.g., https://api.cluster.domain:6443)", + ) + verify_tls: bool = Field(default=True, description="Whether to verify TLS certificates") + tls_cafile: Path | None = Field(default=None, description="Path to CA certificate file for TLS verification") + claims_mapping: dict[str, str] = Field( + default_factory=lambda: { + "username": "roles", + "groups": "roles", + }, + description="Mapping of Kubernetes user claims to access attributes", + ) + + @field_validator("claims_mapping") + @classmethod + def validate_claims_mapping(cls, v): + for key, value in v.items(): 
class KubernetesAuthProvider(AuthProvider):
    """
    Authentication provider backed by the Kubernetes SelfSubjectReview API.

    A bearer token is validated by POSTing an empty SelfSubjectReview object to
    ``<api_server_url>/apis/authentication.k8s.io/v1/selfsubjectreviews`` with the
    token in the ``Authorization`` header. On success the ``status.userInfo``
    section of the response supplies the principal (``username``) and the
    claims that are mapped to access attributes through ``config.claims_mapping``.
    """

    # Path of the SelfSubjectReview endpoint, relative to the API server URL.
    _REVIEW_PATH = "/apis/authentication.k8s.io/v1/selfsubjectreviews"
    # Seconds to wait for the API server before giving up on a request.
    _REQUEST_TIMEOUT = 10.0

    def __init__(self, config: KubernetesAuthProviderConfig):
        """Store the provider configuration; no network or file I/O happens here."""
        self.config = config
        # Value for httpx's ``verify=`` argument, built on first use and then
        # reused so the CA bundle is not re-read from disk on every request.
        self._tls_verify: ssl.SSLContext | bool | None = None

    def _get_tls_verify(self) -> ssl.SSLContext | bool:
        """Return the cached TLS verification setting, building it on first use.

        Precedence is unchanged from the per-request logic it replaces: a
        configured ``tls_cafile`` wins (verification stays enabled against that
        CA), otherwise ``verify_tls=False`` yields a non-verifying context, and
        the default is plain ``True``.

        NOTE: because the context is cached, a CA file replaced on disk is only
        picked up after the provider is re-created.
        """
        if self._tls_verify is None:
            if self.config.tls_cafile:
                self._tls_verify = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
            elif not self.config.verify_tls:
                insecure_ctx = ssl.create_default_context()
                # check_hostname must be disabled before verify_mode can be CERT_NONE.
                insecure_ctx.check_hostname = False
                insecure_ctx.verify_mode = ssl.CERT_NONE
                self._tls_verify = insecure_ctx
            else:
                self._tls_verify = True
        return self._tls_verify

    @staticmethod
    def _extract_user_info(review_response: object) -> dict:
        """Return the non-empty ``status.userInfo`` mapping of a SelfSubjectReview response.

        Raises:
            ValueError: if the payload is not a JSON object, or ``status`` /
                ``userInfo`` / ``username`` is missing, empty, or of the wrong type.
        """
        if not isinstance(review_response, dict):
            raise ValueError("Unexpected SelfSubjectReview response format")

        status = review_response.get("status")
        if not status or not isinstance(status, dict):
            raise ValueError("No status found in SelfSubjectReview response")

        user_info = status.get("userInfo")
        if not user_info or not isinstance(user_info, dict):
            raise ValueError("No userInfo found in SelfSubjectReview response")

        if not user_info.get("username"):
            raise ValueError("No username found in SelfSubjectReview response")

        return user_info

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token using Kubernetes SelfSubjectReview API endpoint.

        Args:
            token: Bearer token forwarded to the Kubernetes API server.
            scope: Accepted for interface compatibility; not used by this provider.

        Returns:
            A ``User`` whose principal is the Kubernetes username and whose
            attributes are derived from ``userInfo`` via ``claims_mapping``.

        Raises:
            ValueError: if the token is rejected (401), the API server answers
                with anything other than 201, the request times out, the
                response is malformed, or any other request error occurs.
        """
        # Resolved outside the try block so TLS misconfiguration (e.g. an
        # unreadable CA file) surfaces as-is rather than as a token error.
        verify = self._get_tls_verify()

        review_api_url = f"{self.config.api_server_url.rstrip('/')}{self._REVIEW_PATH}"

        # An empty SelfSubjectReview: the API server fills in ``status`` for the caller's identity.
        review_request = {"apiVersion": "authentication.k8s.io/v1", "kind": "SelfSubjectReview"}

        try:
            async with httpx.AsyncClient(verify=verify) as client:
                response = await client.post(
                    review_api_url,
                    json=review_request,
                    headers={
                        "Authorization": f"Bearer {token}",
                        "Content-Type": "application/json",
                    },
                    timeout=self._REQUEST_TIMEOUT,
                )

                if response.status_code == 401:
                    raise ValueError("Invalid token")
                elif response.status_code != 201:
                    logger.warning(f"Kubernetes SelfSubjectReview API failed with status code: {response.status_code}")
                    raise ValueError(f"Token validation failed: {response.status_code}")

                user_info = self._extract_user_info(response.json())

                # Build user attributes from Kubernetes user info
                user_attributes = get_attributes_from_claims(user_info, self.config.claims_mapping)

                return User(
                    principal=user_info["username"],
                    attributes=user_attributes,
                )

        except httpx.TimeoutException:
            logger.warning("Kubernetes SelfSubjectReview API request timed out")
            raise ValueError("Token validation timeout") from None
        except ValueError:
            # Re-raise ValueError exceptions to preserve their message
            raise
        except Exception as e:
            logger.warning(f"Error during token validation: {str(e)}")
            raise ValueError(f"Token validation error: {str(e)}") from e

    async def close(self):
        """Close any resources (none are held between requests)."""
        pass
"claims_mapping": { + "username": "roles", + "groups": "roles", + "uid": "uid_attr", + }, + }, + ) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + + @app.get("/test") + def test_endpoint(): + return {"message": "Authentication successful"} + + return app + + +@pytest.fixture +def kubernetes_auth_client(kubernetes_auth_app): + return TestClient(kubernetes_auth_app) + + +def test_missing_auth_header_kubernetes_auth(kubernetes_auth_client): + response = kubernetes_auth_client.get("/test") + assert response.status_code == 401 + assert "Authentication required" in response.json()["error"]["message"] + + +def test_invalid_auth_header_format_kubernetes_auth(kubernetes_auth_client): + response = kubernetes_auth_client.get("/test", headers={"Authorization": "InvalidFormat token123"}) + assert response.status_code == 401 + assert "Invalid Authorization header format" in response.json()["error"]["message"] + + +async def mock_kubernetes_selfsubjectreview_success(*args, **kwargs): + return MockResponse( + 201, + { + "apiVersion": "authentication.k8s.io/v1", + "kind": "SelfSubjectReview", + "metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"}, + "status": { + "userInfo": { + "username": "alice", + "uid": "alice-uid-123", + "groups": ["system:authenticated", "developers", "admins"], + "extra": {"scopes.authorization.openshift.io": ["user:full"]}, + } + }, + }, + ) + + +async def mock_kubernetes_selfsubjectreview_failure(*args, **kwargs): + return MockResponse(401, {"message": "Unauthorized"}) + + +async def mock_kubernetes_selfsubjectreview_http_error(*args, **kwargs): + return MockResponse(500, {"message": "Internal Server Error"}) + + +@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_success) +def test_valid_kubernetes_auth_authentication(kubernetes_auth_client, valid_token): + response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"}) + assert response.status_code == 200 + assert 
response.json() == {"message": "Authentication successful"} + + +@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_failure) +def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token): + response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"}) + assert response.status_code == 401 + assert "Invalid token" in response.json()["error"]["message"] + + +@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_http_error) +def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token): + response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"}) + assert response.status_code == 401 + assert "Token validation failed" in response.json()["error"]["message"] + + +def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mock_kubernetes_api_server): + with patch("httpx.AsyncClient.post") as mock_post: + mock_response = MockResponse( + 200, + { + "apiVersion": "authentication.k8s.io/v1", + "kind": "SelfSubjectReview", + "metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"}, + "status": { + "userInfo": { + "username": "test-user", + "uid": "test-uid", + "groups": ["test-group"], + } + }, + }, + ) + mock_post.return_value = mock_response + + kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"}) + + # Verify the request was made with correct parameters + mock_post.assert_called_once() + call_args = mock_post.call_args + + # Check URL (passed as positional argument) + assert call_args[0][0] == f"{mock_kubernetes_api_server}/apis/authentication.k8s.io/v1/selfsubjectreviews" + + # Check headers (passed as keyword argument) + headers = call_args[1]["headers"] + assert headers["Authorization"] == f"Bearer {valid_token}" + assert headers["Content-Type"] == "application/json" + + # Check request body (passed as keyword argument) + request_body = call_args[1]["json"] + assert 
request_body["apiVersion"] == "authentication.k8s.io/v1" + assert request_body["kind"] == "SelfSubjectReview"