mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
Merge cdda96f881
into cbe89d2bdd
This commit is contained in:
commit
1b06226527
4 changed files with 290 additions and 2 deletions
|
@ -342,6 +342,46 @@ server:
|
||||||
```
|
```
|
||||||
|
|
||||||
The provider fetches user information from GitHub and maps it to access attributes based on the `claims_mapping` configuration.
|
The provider fetches user information from GitHub and maps it to access attributes based on the `claims_mapping` configuration.
|
||||||
|
#### Kubernetes Authentication Provider
|
||||||
|
|
||||||
|
The server can be configured to use Kubernetes SelfSubjectReview API to validate tokens directly against the Kubernetes API server:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
auth:
|
||||||
|
provider_config:
|
||||||
|
type: "kubernetes"
|
||||||
|
api_server_url: https://kubernetes.default.svc
|
||||||
|
claims_mapping:
|
||||||
|
username: "roles"
|
||||||
|
groups: "roles"
|
||||||
|
uid: "uid_attr"
|
||||||
|
verify_tls: true
|
||||||
|
tls_cafile: "/path/to/ca.crt"
|
||||||
|
```
|
||||||
|
|
||||||
|
Configuration options:
|
||||||
|
- `api_server_url`: The Kubernetes API server URL (e.g., https://kubernetes.default.svc:6443)
|
||||||
|
- `verify_tls`: Whether to verify TLS certificates (default: true)
|
||||||
|
- `tls_cafile`: Path to CA certificate file for TLS verification
|
||||||
|
- `claims_mapping`: Mapping of Kubernetes user claims to access attributes
|
||||||
|
|
||||||
|
The provider validates tokens by sending a SelfSubjectReview request to the Kubernetes API server at `/apis/authentication.k8s.io/v1/selfsubjectreviews`. The provider extracts user information from the response:
|
||||||
|
- Username from the `userInfo.username` field
|
||||||
|
- Groups from the `userInfo.groups` field
|
||||||
|
- UID from the `userInfo.uid` field
|
||||||
|
|
||||||
|
To obtain a token for testing:
|
||||||
|
```bash
|
||||||
|
kubectl create namespace llama-stack
|
||||||
|
kubectl create serviceaccount llama-stack-auth -n llama-stack
|
||||||
|
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
|
||||||
|
```
|
||||||
|
|
||||||
|
You can validate a request by running:
|
||||||
|
```bash
|
||||||
|
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
|
||||||
|
```
|
||||||
|
|
||||||
#### Custom Provider
|
#### Custom Provider
|
||||||
Validates tokens against a custom authentication endpoint:
|
Validates tokens against a custom authentication endpoint:
|
||||||
|
|
|
@ -187,6 +187,7 @@ class AuthProviderType(StrEnum):
|
||||||
OAUTH2_TOKEN = "oauth2_token"
|
OAUTH2_TOKEN = "oauth2_token"
|
||||||
GITHUB_TOKEN = "github_token"
|
GITHUB_TOKEN = "github_token"
|
||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
|
KUBERNETES = "kubernetes"
|
||||||
|
|
||||||
|
|
||||||
class OAuth2TokenAuthConfig(BaseModel):
|
class OAuth2TokenAuthConfig(BaseModel):
|
||||||
|
@ -257,8 +258,35 @@ class GitHubTokenAuthConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class KubernetesAuthProviderConfig(BaseModel):
|
||||||
|
"""Configuration for Kubernetes authentication provider."""
|
||||||
|
|
||||||
|
type: Literal[AuthProviderType.KUBERNETES] = AuthProviderType.KUBERNETES
|
||||||
|
api_server_url: str = Field(
|
||||||
|
default="https://kubernetes.default.svc",
|
||||||
|
description="Kubernetes API server URL (e.g., https://api.cluster.domain:6443)",
|
||||||
|
)
|
||||||
|
verify_tls: bool = Field(default=True, description="Whether to verify TLS certificates")
|
||||||
|
tls_cafile: Path | None = Field(default=None, description="Path to CA certificate file for TLS verification")
|
||||||
|
claims_mapping: dict[str, str] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"username": "roles",
|
||||||
|
"groups": "roles",
|
||||||
|
},
|
||||||
|
description="Mapping of Kubernetes user claims to access attributes",
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("claims_mapping")
|
||||||
|
@classmethod
|
||||||
|
def validate_claims_mapping(cls, v):
|
||||||
|
for key, value in v.items():
|
||||||
|
if not value:
|
||||||
|
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
AuthProviderConfig = Annotated[
|
AuthProviderConfig = Annotated[
|
||||||
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig,
|
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig | KubernetesAuthProviderConfig,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ from llama_stack.distribution.datatypes import (
|
||||||
AuthenticationConfig,
|
AuthenticationConfig,
|
||||||
CustomAuthConfig,
|
CustomAuthConfig,
|
||||||
GitHubTokenAuthConfig,
|
GitHubTokenAuthConfig,
|
||||||
|
KubernetesAuthProviderConfig,
|
||||||
OAuth2TokenAuthConfig,
|
OAuth2TokenAuthConfig,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
|
@ -176,7 +177,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
attributes=access_attributes,
|
attributes=access_attributes,
|
||||||
)
|
)
|
||||||
except httpx.TimeoutException:
|
except httpx.TimeoutException:
|
||||||
logger.exception("Token introspection request timed out")
|
logger.warning("Token introspection request timed out")
|
||||||
raise
|
raise
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# Re-raise ValueError exceptions to preserve their message
|
# Re-raise ValueError exceptions to preserve their message
|
||||||
|
@ -374,6 +375,90 @@ async def _get_github_user_info(access_token: str, github_api_base_url: str) ->
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class KubernetesAuthProvider(AuthProvider):
|
||||||
|
"""
|
||||||
|
Kubernetes authentication provider that validates tokens using the Kubernetes SelfSubjectReview API.
|
||||||
|
This provider integrates with Kubernetes API server by using the
|
||||||
|
/apis/authentication.k8s.io/v1/selfsubjectreviews endpoint to validate tokens and extract user information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: KubernetesAuthProviderConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
|
"""Validate a token using Kubernetes SelfSubjectReview API endpoint."""
|
||||||
|
|
||||||
|
# Configure SSL context
|
||||||
|
ssl_ctxt = None
|
||||||
|
if self.config.tls_cafile:
|
||||||
|
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
|
||||||
|
elif not self.config.verify_tls:
|
||||||
|
ssl_ctxt = ssl.create_default_context()
|
||||||
|
ssl_ctxt.check_hostname = False
|
||||||
|
ssl_ctxt.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
|
# Build the Kubernetes SelfSubjectReview API endpoint URL
|
||||||
|
review_api_url = f"{self.config.api_server_url.rstrip('/')}/apis/authentication.k8s.io/v1/selfsubjectreviews"
|
||||||
|
|
||||||
|
# Create SelfSubjectReview request body
|
||||||
|
review_request = {"apiVersion": "authentication.k8s.io/v1", "kind": "SelfSubjectReview"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(verify=ssl_ctxt if ssl_ctxt else self.config.verify_tls) as client:
|
||||||
|
response = await client.post(
|
||||||
|
review_api_url,
|
||||||
|
json=review_request,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 401:
|
||||||
|
raise ValueError("Invalid token")
|
||||||
|
elif response.status_code != 201:
|
||||||
|
logger.warning(f"Kubernetes SelfSubjectReview API failed with status code: {response.status_code}")
|
||||||
|
raise ValueError(f"Token validation failed: {response.status_code}")
|
||||||
|
|
||||||
|
review_response = response.json()
|
||||||
|
|
||||||
|
# Extract user information from SelfSubjectReview response
|
||||||
|
status = review_response.get("status", {})
|
||||||
|
if not status:
|
||||||
|
raise ValueError("No status found in SelfSubjectReview response")
|
||||||
|
|
||||||
|
user_info = status.get("userInfo", {})
|
||||||
|
if not user_info:
|
||||||
|
raise ValueError("No userInfo found in SelfSubjectReview response")
|
||||||
|
|
||||||
|
username = user_info.get("username")
|
||||||
|
if not username:
|
||||||
|
raise ValueError("No username found in SelfSubjectReview response")
|
||||||
|
|
||||||
|
# Build user attributes from Kubernetes user info
|
||||||
|
user_attributes = get_attributes_from_claims(user_info, self.config.claims_mapping)
|
||||||
|
|
||||||
|
return User(
|
||||||
|
principal=username,
|
||||||
|
attributes=user_attributes,
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.warning("Kubernetes SelfSubjectReview API request timed out")
|
||||||
|
raise ValueError("Token validation timeout") from None
|
||||||
|
except ValueError:
|
||||||
|
# Re-raise ValueError exceptions to preserve their message
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error during token validation: {str(e)}")
|
||||||
|
raise ValueError(f"Token validation error: {str(e)}") from e
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close any resources."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||||
"""Factory function to create the appropriate auth provider."""
|
"""Factory function to create the appropriate auth provider."""
|
||||||
provider_config = config.provider_config
|
provider_config = config.provider_config
|
||||||
|
@ -384,5 +469,7 @@ def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||||
return OAuth2TokenAuthProvider(provider_config)
|
return OAuth2TokenAuthProvider(provider_config)
|
||||||
elif isinstance(provider_config, GitHubTokenAuthConfig):
|
elif isinstance(provider_config, GitHubTokenAuthConfig):
|
||||||
return GitHubTokenAuthProvider(provider_config)
|
return GitHubTokenAuthProvider(provider_config)
|
||||||
|
elif isinstance(provider_config, KubernetesAuthProviderConfig):
|
||||||
|
return KubernetesAuthProvider(provider_config)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
|
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
|
||||||
|
|
|
@ -581,3 +581,136 @@ def test_valid_introspection_with_custom_mapping_authentication(
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"message": "Authentication successful"}
|
assert response.json() == {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_kubernetes_api_server():
|
||||||
|
return "https://api.cluster.example.com:6443"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def kubernetes_auth_app(mock_kubernetes_api_server):
|
||||||
|
app = FastAPI()
|
||||||
|
auth_config = AuthenticationConfig(
|
||||||
|
provider_config={
|
||||||
|
"type": "kubernetes",
|
||||||
|
"api_server_url": mock_kubernetes_api_server,
|
||||||
|
"verify_tls": False,
|
||||||
|
"claims_mapping": {
|
||||||
|
"username": "roles",
|
||||||
|
"groups": "roles",
|
||||||
|
"uid": "uid_attr",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
def test_endpoint():
|
||||||
|
return {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def kubernetes_auth_client(kubernetes_auth_app):
|
||||||
|
return TestClient(kubernetes_auth_app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_auth_header_kubernetes_auth(kubernetes_auth_client):
|
||||||
|
response = kubernetes_auth_client.get("/test")
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Authentication required" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_auth_header_format_kubernetes_auth(kubernetes_auth_client):
|
||||||
|
response = kubernetes_auth_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_kubernetes_selfsubjectreview_success(*args, **kwargs):
|
||||||
|
return MockResponse(
|
||||||
|
201,
|
||||||
|
{
|
||||||
|
"apiVersion": "authentication.k8s.io/v1",
|
||||||
|
"kind": "SelfSubjectReview",
|
||||||
|
"metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"},
|
||||||
|
"status": {
|
||||||
|
"userInfo": {
|
||||||
|
"username": "alice",
|
||||||
|
"uid": "alice-uid-123",
|
||||||
|
"groups": ["system:authenticated", "developers", "admins"],
|
||||||
|
"extra": {"scopes.authorization.openshift.io": ["user:full"]},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_kubernetes_selfsubjectreview_failure(*args, **kwargs):
|
||||||
|
return MockResponse(401, {"message": "Unauthorized"})
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_kubernetes_selfsubjectreview_http_error(*args, **kwargs):
|
||||||
|
return MockResponse(500, {"message": "Internal Server Error"})
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_success)
|
||||||
|
def test_valid_kubernetes_auth_authentication(kubernetes_auth_client, valid_token):
|
||||||
|
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_failure)
|
||||||
|
def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token):
|
||||||
|
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Invalid token" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_http_error)
|
||||||
|
def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token):
|
||||||
|
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Token validation failed" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mock_kubernetes_api_server):
|
||||||
|
with patch("httpx.AsyncClient.post") as mock_post:
|
||||||
|
mock_response = MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"apiVersion": "authentication.k8s.io/v1",
|
||||||
|
"kind": "SelfSubjectReview",
|
||||||
|
"metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"},
|
||||||
|
"status": {
|
||||||
|
"userInfo": {
|
||||||
|
"username": "test-user",
|
||||||
|
"uid": "test-uid",
|
||||||
|
"groups": ["test-group"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||||
|
|
||||||
|
# Verify the request was made with correct parameters
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
call_args = mock_post.call_args
|
||||||
|
|
||||||
|
# Check URL (passed as positional argument)
|
||||||
|
assert call_args[0][0] == f"{mock_kubernetes_api_server}/apis/authentication.k8s.io/v1/selfsubjectreviews"
|
||||||
|
|
||||||
|
# Check headers (passed as keyword argument)
|
||||||
|
headers = call_args[1]["headers"]
|
||||||
|
assert headers["Authorization"] == f"Bearer {valid_token}"
|
||||||
|
assert headers["Content-Type"] == "application/json"
|
||||||
|
|
||||||
|
# Check request body (passed as keyword argument)
|
||||||
|
request_body = call_args[1]["json"]
|
||||||
|
assert request_body["apiVersion"] == "authentication.k8s.io/v1"
|
||||||
|
assert request_body["kind"] == "SelfSubjectReview"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue