mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
Add Kubernetes authentication provider support
- Add KubernetesAuthProvider class for token validation using Kubernetes SelfSubjectReview API - Add KubernetesAuthProviderConfig with configurable API server URL, TLS settings, and claims mapping - Implement authentication via POST requests to /apis/authentication.k8s.io/v1/selfsubjectreviews endpoint - Add support for parsing Kubernetes SelfSubjectReview response format to extract user information - Add KUBERNETES provider type to AuthProviderType enum - Update create_auth_provider factory function to handle 'kubernetes' provider type - Add comprehensive unit tests for KubernetesAuthProvider functionality - Add documentation with configuration examples and usage instructions The provider validates tokens by sending SelfSubjectReview requests to the Kubernetes API server and extracts user information from the userInfo structure in the response. Signed-off-by: Akram Ben Aissi <akram.benaissi@gmail.com>
This commit is contained in:
parent
44e1a40595
commit
e0a47f86b3
5 changed files with 311 additions and 4 deletions
|
@ -354,6 +354,47 @@ You can easily validate a request by running:
|
||||||
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
|
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Kubernetes Authentication Provider
|
||||||
|
|
||||||
|
The server can be configured to use Kubernetes SelfSubjectReview API to validate tokens directly against the Kubernetes API server:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
auth:
|
||||||
|
provider_config:
|
||||||
|
type: "kubernetes"
|
||||||
|
api_server_url: "https://kubernetes.default.svc"
|
||||||
|
claims_mapping:
|
||||||
|
username: "roles"
|
||||||
|
groups: "roles"
|
||||||
|
uid: "uid_attr"
|
||||||
|
verify_tls: true
|
||||||
|
tls_cafile: "/path/to/ca.crt"
|
||||||
|
```
|
||||||
|
|
||||||
|
Configuration options:
|
||||||
|
- `api_server_url`: The Kubernetes API server URL (e.g., https://kubernetes.default.svc:6443)
|
||||||
|
- `verify_tls`: Whether to verify TLS certificates (default: true)
|
||||||
|
- `tls_cafile`: Path to CA certificate file for TLS verification
|
||||||
|
- `claims_mapping`: Mapping of Kubernetes user claims to access attributes
|
||||||
|
|
||||||
|
The provider validates tokens by sending a SelfSubjectReview request to the Kubernetes API server at `/apis/authentication.k8s.io/v1/selfsubjectreviews`. The provider extracts user information from the response:
|
||||||
|
- Username from the `userInfo.username` field
|
||||||
|
- Groups from the `userInfo.groups` field
|
||||||
|
- UID from the `userInfo.uid` field
|
||||||
|
|
||||||
|
To obtain a token for testing:
|
||||||
|
```bash
|
||||||
|
kubectl create namespace llama-stack
|
||||||
|
kubectl create serviceaccount llama-stack-auth -n llama-stack
|
||||||
|
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
|
||||||
|
```
|
||||||
|
|
||||||
|
You can validate a request by running:
|
||||||
|
```bash
|
||||||
|
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
|
||||||
|
```
|
||||||
|
|
||||||
#### GitHub Token Provider
|
#### GitHub Token Provider
|
||||||
Validates GitHub personal access tokens or OAuth tokens directly:
|
Validates GitHub personal access tokens or OAuth tokens directly:
|
||||||
```yaml
|
```yaml
|
||||||
|
|
|
@ -79,3 +79,10 @@ class ConflictError(ValueError):
|
||||||
|
|
||||||
def __init__(self, message: str) -> None:
|
def __init__(self, message: str) -> None:
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenValidationError(ValueError):
|
||||||
|
"""raised when token validation fails during authentication"""
|
||||||
|
|
||||||
|
def __init__(self, message: str) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any, Literal, Self
|
from typing import Annotated, Any, Literal, Self
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
@ -212,6 +213,7 @@ class AuthProviderType(StrEnum):
|
||||||
OAUTH2_TOKEN = "oauth2_token"
|
OAUTH2_TOKEN = "oauth2_token"
|
||||||
GITHUB_TOKEN = "github_token"
|
GITHUB_TOKEN = "github_token"
|
||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
|
KUBERNETES = "kubernetes"
|
||||||
|
|
||||||
|
|
||||||
class OAuth2TokenAuthConfig(BaseModel):
|
class OAuth2TokenAuthConfig(BaseModel):
|
||||||
|
@ -282,8 +284,45 @@ class GitHubTokenAuthConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class KubernetesAuthProviderConfig(BaseModel):
|
||||||
|
"""Configuration for Kubernetes authentication provider."""
|
||||||
|
|
||||||
|
type: Literal[AuthProviderType.KUBERNETES] = AuthProviderType.KUBERNETES
|
||||||
|
api_server_url: str = Field(
|
||||||
|
default="https://kubernetes.default.svc",
|
||||||
|
description="Kubernetes API server URL (e.g., https://api.cluster.domain:6443)",
|
||||||
|
)
|
||||||
|
verify_tls: bool = Field(default=True, description="Whether to verify TLS certificates")
|
||||||
|
tls_cafile: Path | None = Field(default=None, description="Path to CA certificate file for TLS verification")
|
||||||
|
claims_mapping: dict[str, str] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"username": "roles",
|
||||||
|
"groups": "roles",
|
||||||
|
},
|
||||||
|
description="Mapping of Kubernetes user claims to access attributes",
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("api_server_url")
|
||||||
|
@classmethod
|
||||||
|
def validate_api_server_url(cls, v):
|
||||||
|
parsed = urlparse(v)
|
||||||
|
if not parsed.scheme or not parsed.netloc:
|
||||||
|
raise ValueError(f"api_server_url must be a valid URL with scheme and host: {v}")
|
||||||
|
if parsed.scheme not in ["http", "https"]:
|
||||||
|
raise ValueError(f"api_server_url scheme must be http or https: {v}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("claims_mapping")
|
||||||
|
@classmethod
|
||||||
|
def validate_claims_mapping(cls, v):
|
||||||
|
for key, value in v.items():
|
||||||
|
if not value:
|
||||||
|
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
AuthProviderConfig = Annotated[
|
AuthProviderConfig = Annotated[
|
||||||
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig,
|
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig | KubernetesAuthProviderConfig,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -8,16 +8,18 @@ import ssl
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from asyncio import Lock
|
from asyncio import Lock
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urljoin, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.apis.common.errors import TokenValidationError
|
||||||
from llama_stack.core.datatypes import (
|
from llama_stack.core.datatypes import (
|
||||||
AuthenticationConfig,
|
AuthenticationConfig,
|
||||||
CustomAuthConfig,
|
CustomAuthConfig,
|
||||||
GitHubTokenAuthConfig,
|
GitHubTokenAuthConfig,
|
||||||
|
KubernetesAuthProviderConfig,
|
||||||
OAuth2TokenAuthConfig,
|
OAuth2TokenAuthConfig,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
|
@ -162,7 +164,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
auth=auth,
|
auth=auth,
|
||||||
timeout=10.0, # Add a reasonable timeout
|
timeout=10.0, # Add a reasonable timeout
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != httpx.codes.OK:
|
||||||
logger.warning(f"Token introspection failed with status code: {response.status_code}")
|
logger.warning(f"Token introspection failed with status code: {response.status_code}")
|
||||||
raise ValueError(f"Token introspection failed: {response.status_code}")
|
raise ValueError(f"Token introspection failed: {response.status_code}")
|
||||||
|
|
||||||
|
@ -272,7 +274,7 @@ class CustomAuthProvider(AuthProvider):
|
||||||
json=auth_request.model_dump(),
|
json=auth_request.model_dump(),
|
||||||
timeout=10.0, # Add a reasonable timeout
|
timeout=10.0, # Add a reasonable timeout
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != httpx.codes.OK:
|
||||||
logger.warning(f"Authentication failed with status code: {response.status_code}")
|
logger.warning(f"Authentication failed with status code: {response.status_code}")
|
||||||
raise ValueError(f"Authentication failed: {response.status_code}")
|
raise ValueError(f"Authentication failed: {response.status_code}")
|
||||||
|
|
||||||
|
@ -374,6 +376,89 @@ async def _get_github_user_info(access_token: str, github_api_base_url: str) ->
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class KubernetesAuthProvider(AuthProvider):
|
||||||
|
"""
|
||||||
|
Kubernetes authentication provider that validates tokens using the Kubernetes SelfSubjectReview API.
|
||||||
|
This provider integrates with Kubernetes API server by using the
|
||||||
|
/apis/authentication.k8s.io/v1/selfsubjectreviews endpoint to validate tokens and extract user information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: KubernetesAuthProviderConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def _httpx_verify_value(self) -> bool | str:
|
||||||
|
"""
|
||||||
|
Build the value for httpx's `verify` parameter.
|
||||||
|
- False disables verification.
|
||||||
|
- Path string points to a CA bundle.
|
||||||
|
- True uses system defaults.
|
||||||
|
"""
|
||||||
|
if not self.config.verify_tls:
|
||||||
|
return False
|
||||||
|
if self.config.tls_cafile:
|
||||||
|
return self.config.tls_cafile.as_posix()
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
|
"""Validate a token using Kubernetes SelfSubjectReview API endpoint."""
|
||||||
|
# Build the Kubernetes SelfSubjectReview API endpoint URL
|
||||||
|
review_api_url = urljoin(self.config.api_server_url, "/apis/authentication.k8s.io/v1/selfsubjectreviews")
|
||||||
|
|
||||||
|
# Create SelfSubjectReview request body
|
||||||
|
review_request = {"apiVersion": "authentication.k8s.io/v1", "kind": "SelfSubjectReview"}
|
||||||
|
verify = self._httpx_verify_value()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(verify=verify, timeout=10.0) as client:
|
||||||
|
response = await client.post(
|
||||||
|
review_api_url,
|
||||||
|
json=review_request,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == httpx.codes.UNAUTHORIZED:
|
||||||
|
raise TokenValidationError("Invalid token")
|
||||||
|
if response.status_code != httpx.codes.CREATED:
|
||||||
|
logger.warning(f"Kubernetes SelfSubjectReview API failed with status code: {response.status_code}")
|
||||||
|
raise TokenValidationError(f"Token validation failed: {response.status_code}")
|
||||||
|
|
||||||
|
review_response = response.json()
|
||||||
|
# Extract user information from SelfSubjectReview response
|
||||||
|
status = review_response.get("status", {})
|
||||||
|
if not status:
|
||||||
|
raise ValueError("No status found in SelfSubjectReview response")
|
||||||
|
|
||||||
|
user_info = status.get("userInfo", {})
|
||||||
|
if not user_info:
|
||||||
|
raise ValueError("No userInfo found in SelfSubjectReview response")
|
||||||
|
|
||||||
|
username = user_info.get("username")
|
||||||
|
if not username:
|
||||||
|
raise ValueError("No username found in SelfSubjectReview response")
|
||||||
|
|
||||||
|
# Build user attributes from Kubernetes user info
|
||||||
|
user_attributes = get_attributes_from_claims(user_info, self.config.claims_mapping)
|
||||||
|
|
||||||
|
return User(
|
||||||
|
principal=username,
|
||||||
|
attributes=user_attributes,
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.warning("Kubernetes SelfSubjectReview API request timed out")
|
||||||
|
raise ValueError("Token validation timeout") from None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error during token validation: {str(e)}")
|
||||||
|
raise ValueError(f"Token validation error: {str(e)}") from e
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close any resources."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||||
"""Factory function to create the appropriate auth provider."""
|
"""Factory function to create the appropriate auth provider."""
|
||||||
provider_config = config.provider_config
|
provider_config = config.provider_config
|
||||||
|
@ -384,5 +469,7 @@ def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||||
return OAuth2TokenAuthProvider(provider_config)
|
return OAuth2TokenAuthProvider(provider_config)
|
||||||
elif isinstance(provider_config, GitHubTokenAuthConfig):
|
elif isinstance(provider_config, GitHubTokenAuthConfig):
|
||||||
return GitHubTokenAuthProvider(provider_config)
|
return GitHubTokenAuthProvider(provider_config)
|
||||||
|
elif isinstance(provider_config, KubernetesAuthProviderConfig):
|
||||||
|
return KubernetesAuthProvider(provider_config)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
|
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
|
||||||
|
|
|
@ -774,3 +774,136 @@ def test_has_required_scope_function():
|
||||||
|
|
||||||
# Test no user (auth disabled)
|
# Test no user (auth disabled)
|
||||||
assert _has_required_scope("test.read", None)
|
assert _has_required_scope("test.read", None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_kubernetes_api_server():
|
||||||
|
return "https://api.cluster.example.com:6443"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def kubernetes_auth_app(mock_kubernetes_api_server):
|
||||||
|
app = FastAPI()
|
||||||
|
auth_config = AuthenticationConfig(
|
||||||
|
provider_config={
|
||||||
|
"type": "kubernetes",
|
||||||
|
"api_server_url": mock_kubernetes_api_server,
|
||||||
|
"verify_tls": False,
|
||||||
|
"claims_mapping": {
|
||||||
|
"username": "roles",
|
||||||
|
"groups": "roles",
|
||||||
|
"uid": "uid_attr",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
def test_endpoint():
|
||||||
|
return {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def kubernetes_auth_client(kubernetes_auth_app):
|
||||||
|
return TestClient(kubernetes_auth_app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_auth_header_kubernetes_auth(kubernetes_auth_client):
|
||||||
|
response = kubernetes_auth_client.get("/test")
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Authentication required" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_auth_header_format_kubernetes_auth(kubernetes_auth_client):
|
||||||
|
response = kubernetes_auth_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_kubernetes_selfsubjectreview_success(*args, **kwargs):
|
||||||
|
return MockResponse(
|
||||||
|
201,
|
||||||
|
{
|
||||||
|
"apiVersion": "authentication.k8s.io/v1",
|
||||||
|
"kind": "SelfSubjectReview",
|
||||||
|
"metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"},
|
||||||
|
"status": {
|
||||||
|
"userInfo": {
|
||||||
|
"username": "alice",
|
||||||
|
"uid": "alice-uid-123",
|
||||||
|
"groups": ["system:authenticated", "developers", "admins"],
|
||||||
|
"extra": {"scopes.authorization.openshift.io": ["user:full"]},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_kubernetes_selfsubjectreview_failure(*args, **kwargs):
|
||||||
|
return MockResponse(401, {"message": "Unauthorized"})
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_kubernetes_selfsubjectreview_http_error(*args, **kwargs):
|
||||||
|
return MockResponse(500, {"message": "Internal Server Error"})
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_success)
|
||||||
|
def test_valid_kubernetes_auth_authentication(kubernetes_auth_client, valid_token):
|
||||||
|
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_failure)
|
||||||
|
def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token):
|
||||||
|
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Invalid token" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_http_error)
|
||||||
|
def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token):
|
||||||
|
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Token validation failed" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mock_kubernetes_api_server):
|
||||||
|
with patch("httpx.AsyncClient.post") as mock_post:
|
||||||
|
mock_response = MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"apiVersion": "authentication.k8s.io/v1",
|
||||||
|
"kind": "SelfSubjectReview",
|
||||||
|
"metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"},
|
||||||
|
"status": {
|
||||||
|
"userInfo": {
|
||||||
|
"username": "test-user",
|
||||||
|
"uid": "test-uid",
|
||||||
|
"groups": ["test-group"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||||
|
|
||||||
|
# Verify the request was made with correct parameters
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
call_args = mock_post.call_args
|
||||||
|
|
||||||
|
# Check URL (passed as positional argument)
|
||||||
|
assert call_args[0][0] == f"{mock_kubernetes_api_server}/apis/authentication.k8s.io/v1/selfsubjectreviews"
|
||||||
|
|
||||||
|
# Check headers (passed as keyword argument)
|
||||||
|
headers = call_args[1]["headers"]
|
||||||
|
assert headers["Authorization"] == f"Bearer {valid_token}"
|
||||||
|
assert headers["Content-Type"] == "application/json"
|
||||||
|
|
||||||
|
# Check request body (passed as keyword argument)
|
||||||
|
request_body = call_args[1]["json"]
|
||||||
|
assert request_body["apiVersion"] == "authentication.k8s.io/v1"
|
||||||
|
assert request_body["kind"] == "SelfSubjectReview"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue