mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-25 21:57:45 +00:00
Merge cdda96f881
into cbe89d2bdd
This commit is contained in:
commit
1b06226527
4 changed files with 290 additions and 2 deletions
|
@ -342,6 +342,46 @@ server:
|
|||
```
|
||||
|
||||
The provider fetches user information from GitHub and maps it to access attributes based on the `claims_mapping` configuration.
|
||||
#### Kubernetes Authentication Provider
|
||||
|
||||
The server can be configured to use Kubernetes SelfSubjectReview API to validate tokens directly against the Kubernetes API server:
|
||||
|
||||
```yaml
|
||||
server:
|
||||
auth:
|
||||
provider_config:
|
||||
type: "kubernetes"
|
||||
api_server_url: https://kubernetes.default.svc
|
||||
claims_mapping:
|
||||
username: "roles"
|
||||
groups: "roles"
|
||||
uid: "uid_attr"
|
||||
verify_tls: true
|
||||
tls_cafile: "/path/to/ca.crt"
|
||||
```
|
||||
|
||||
Configuration options:
|
||||
- `api_server_url`: The Kubernetes API server URL (e.g., https://kubernetes.default.svc:6443)
|
||||
- `verify_tls`: Whether to verify TLS certificates (default: true)
|
||||
- `tls_cafile`: Path to CA certificate file for TLS verification
|
||||
- `claims_mapping`: Mapping of Kubernetes user claims to access attributes
|
||||
|
||||
The provider validates tokens by sending a SelfSubjectReview request to the Kubernetes API server at `/apis/authentication.k8s.io/v1/selfsubjectreviews`. The provider extracts user information from the response:
|
||||
- Username from the `userInfo.username` field
|
||||
- Groups from the `userInfo.groups` field
|
||||
- UID from the `userInfo.uid` field
|
||||
|
||||
To obtain a token for testing:
|
||||
```bash
|
||||
kubectl create namespace llama-stack
|
||||
kubectl create serviceaccount llama-stack-auth -n llama-stack
|
||||
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
|
||||
```
|
||||
|
||||
You can validate a request by running:
|
||||
```bash
|
||||
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
|
||||
```
|
||||
|
||||
#### Custom Provider
|
||||
Validates tokens against a custom authentication endpoint:
|
||||
|
|
|
@ -187,6 +187,7 @@ class AuthProviderType(StrEnum):
|
|||
OAUTH2_TOKEN = "oauth2_token"
|
||||
GITHUB_TOKEN = "github_token"
|
||||
CUSTOM = "custom"
|
||||
KUBERNETES = "kubernetes"
|
||||
|
||||
|
||||
class OAuth2TokenAuthConfig(BaseModel):
|
||||
|
@ -257,8 +258,35 @@ class GitHubTokenAuthConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class KubernetesAuthProviderConfig(BaseModel):
|
||||
"""Configuration for Kubernetes authentication provider."""
|
||||
|
||||
type: Literal[AuthProviderType.KUBERNETES] = AuthProviderType.KUBERNETES
|
||||
api_server_url: str = Field(
|
||||
default="https://kubernetes.default.svc",
|
||||
description="Kubernetes API server URL (e.g., https://api.cluster.domain:6443)",
|
||||
)
|
||||
verify_tls: bool = Field(default=True, description="Whether to verify TLS certificates")
|
||||
tls_cafile: Path | None = Field(default=None, description="Path to CA certificate file for TLS verification")
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"username": "roles",
|
||||
"groups": "roles",
|
||||
},
|
||||
description="Mapping of Kubernetes user claims to access attributes",
|
||||
)
|
||||
|
||||
@field_validator("claims_mapping")
|
||||
@classmethod
|
||||
def validate_claims_mapping(cls, v):
|
||||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
return v
|
||||
|
||||
|
||||
AuthProviderConfig = Annotated[
|
||||
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig,
|
||||
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig | KubernetesAuthProviderConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ from llama_stack.distribution.datatypes import (
|
|||
AuthenticationConfig,
|
||||
CustomAuthConfig,
|
||||
GitHubTokenAuthConfig,
|
||||
KubernetesAuthProviderConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
User,
|
||||
)
|
||||
|
@ -176,7 +177,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
attributes=access_attributes,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Token introspection request timed out")
|
||||
logger.warning("Token introspection request timed out")
|
||||
raise
|
||||
except ValueError:
|
||||
# Re-raise ValueError exceptions to preserve their message
|
||||
|
@ -374,6 +375,90 @@ async def _get_github_user_info(access_token: str, github_api_base_url: str) ->
|
|||
}
|
||||
|
||||
|
||||
class KubernetesAuthProvider(AuthProvider):
|
||||
"""
|
||||
Kubernetes authentication provider that validates tokens using the Kubernetes SelfSubjectReview API.
|
||||
This provider integrates with Kubernetes API server by using the
|
||||
/apis/authentication.k8s.io/v1/selfsubjectreviews endpoint to validate tokens and extract user information.
|
||||
"""
|
||||
|
||||
def __init__(self, config: KubernetesAuthProviderConfig):
|
||||
self.config = config
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a token using Kubernetes SelfSubjectReview API endpoint."""
|
||||
|
||||
# Configure SSL context
|
||||
ssl_ctxt = None
|
||||
if self.config.tls_cafile:
|
||||
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
|
||||
elif not self.config.verify_tls:
|
||||
ssl_ctxt = ssl.create_default_context()
|
||||
ssl_ctxt.check_hostname = False
|
||||
ssl_ctxt.verify_mode = ssl.CERT_NONE
|
||||
|
||||
# Build the Kubernetes SelfSubjectReview API endpoint URL
|
||||
review_api_url = f"{self.config.api_server_url.rstrip('/')}/apis/authentication.k8s.io/v1/selfsubjectreviews"
|
||||
|
||||
# Create SelfSubjectReview request body
|
||||
review_request = {"apiVersion": "authentication.k8s.io/v1", "kind": "SelfSubjectReview"}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(verify=ssl_ctxt if ssl_ctxt else self.config.verify_tls) as client:
|
||||
response = await client.post(
|
||||
review_api_url,
|
||||
json=review_request,
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
timeout=10.0,
|
||||
)
|
||||
|
||||
if response.status_code == 401:
|
||||
raise ValueError("Invalid token")
|
||||
elif response.status_code != 201:
|
||||
logger.warning(f"Kubernetes SelfSubjectReview API failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Token validation failed: {response.status_code}")
|
||||
|
||||
review_response = response.json()
|
||||
|
||||
# Extract user information from SelfSubjectReview response
|
||||
status = review_response.get("status", {})
|
||||
if not status:
|
||||
raise ValueError("No status found in SelfSubjectReview response")
|
||||
|
||||
user_info = status.get("userInfo", {})
|
||||
if not user_info:
|
||||
raise ValueError("No userInfo found in SelfSubjectReview response")
|
||||
|
||||
username = user_info.get("username")
|
||||
if not username:
|
||||
raise ValueError("No username found in SelfSubjectReview response")
|
||||
|
||||
# Build user attributes from Kubernetes user info
|
||||
user_attributes = get_attributes_from_claims(user_info, self.config.claims_mapping)
|
||||
|
||||
return User(
|
||||
principal=username,
|
||||
attributes=user_attributes,
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("Kubernetes SelfSubjectReview API request timed out")
|
||||
raise ValueError("Token validation timeout") from None
|
||||
except ValueError:
|
||||
# Re-raise ValueError exceptions to preserve their message
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during token validation: {str(e)}")
|
||||
raise ValueError(f"Token validation error: {str(e)}") from e
|
||||
|
||||
async def close(self):
|
||||
"""Close any resources."""
|
||||
pass
|
||||
|
||||
|
||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||
"""Factory function to create the appropriate auth provider."""
|
||||
provider_config = config.provider_config
|
||||
|
@ -384,5 +469,7 @@ def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
|||
return OAuth2TokenAuthProvider(provider_config)
|
||||
elif isinstance(provider_config, GitHubTokenAuthConfig):
|
||||
return GitHubTokenAuthProvider(provider_config)
|
||||
elif isinstance(provider_config, KubernetesAuthProviderConfig):
|
||||
return KubernetesAuthProvider(provider_config)
|
||||
else:
|
||||
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
|
||||
|
|
|
@ -581,3 +581,136 @@ def test_valid_introspection_with_custom_mapping_authentication(
|
|||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_kubernetes_api_server():
|
||||
return "https://api.cluster.example.com:6443"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kubernetes_auth_app(mock_kubernetes_api_server):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_config={
|
||||
"type": "kubernetes",
|
||||
"api_server_url": mock_kubernetes_api_server,
|
||||
"verify_tls": False,
|
||||
"claims_mapping": {
|
||||
"username": "roles",
|
||||
"groups": "roles",
|
||||
"uid": "uid_attr",
|
||||
},
|
||||
},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kubernetes_auth_client(kubernetes_auth_app):
|
||||
return TestClient(kubernetes_auth_app)
|
||||
|
||||
|
||||
def test_missing_auth_header_kubernetes_auth(kubernetes_auth_client):
|
||||
response = kubernetes_auth_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_kubernetes_auth(kubernetes_auth_client):
|
||||
response = kubernetes_auth_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_kubernetes_selfsubjectreview_success(*args, **kwargs):
|
||||
return MockResponse(
|
||||
201,
|
||||
{
|
||||
"apiVersion": "authentication.k8s.io/v1",
|
||||
"kind": "SelfSubjectReview",
|
||||
"metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"},
|
||||
"status": {
|
||||
"userInfo": {
|
||||
"username": "alice",
|
||||
"uid": "alice-uid-123",
|
||||
"groups": ["system:authenticated", "developers", "admins"],
|
||||
"extra": {"scopes.authorization.openshift.io": ["user:full"]},
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def mock_kubernetes_selfsubjectreview_failure(*args, **kwargs):
|
||||
return MockResponse(401, {"message": "Unauthorized"})
|
||||
|
||||
|
||||
async def mock_kubernetes_selfsubjectreview_http_error(*args, **kwargs):
|
||||
return MockResponse(500, {"message": "Internal Server Error"})
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_success)
|
||||
def test_valid_kubernetes_auth_authentication(kubernetes_auth_client, valid_token):
|
||||
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_failure)
|
||||
def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token):
|
||||
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_http_error)
|
||||
def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token):
|
||||
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Token validation failed" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mock_kubernetes_api_server):
|
||||
with patch("httpx.AsyncClient.post") as mock_post:
|
||||
mock_response = MockResponse(
|
||||
200,
|
||||
{
|
||||
"apiVersion": "authentication.k8s.io/v1",
|
||||
"kind": "SelfSubjectReview",
|
||||
"metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"},
|
||||
"status": {
|
||||
"userInfo": {
|
||||
"username": "test-user",
|
||||
"uid": "test-uid",
|
||||
"groups": ["test-group"],
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||
|
||||
# Verify the request was made with correct parameters
|
||||
mock_post.assert_called_once()
|
||||
call_args = mock_post.call_args
|
||||
|
||||
# Check URL (passed as positional argument)
|
||||
assert call_args[0][0] == f"{mock_kubernetes_api_server}/apis/authentication.k8s.io/v1/selfsubjectreviews"
|
||||
|
||||
# Check headers (passed as keyword argument)
|
||||
headers = call_args[1]["headers"]
|
||||
assert headers["Authorization"] == f"Bearer {valid_token}"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
|
||||
# Check request body (passed as keyword argument)
|
||||
request_body = call_args[1]["json"]
|
||||
assert request_body["apiVersion"] == "authentication.k8s.io/v1"
|
||||
assert request_body["kind"] == "SelfSubjectReview"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue