chore: remove k8s auth in favor of k8s jwks endpoint (#2216)

# What does this PR do?

Kubernetes since 1.20 exposes a JWKS endpoint that we can use with our
recent oauth2 recent implementation.
The CI test has been kept intact for validation.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-05-21 16:23:54 +02:00 committed by GitHub
parent 2890243107
commit c25acedbcd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 147 additions and 359 deletions

View file

@ -11,12 +11,10 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.datatypes import AuthenticationConfig
from llama_stack.distribution.server.auth import AuthenticationMiddleware
from llama_stack.distribution.server.auth_providers import (
AuthProviderConfig,
AuthProviderType,
TokenValidationResult,
get_attributes_from_claims,
)
@ -62,7 +60,7 @@ def invalid_token():
@pytest.fixture
def http_app(mock_auth_endpoint):
app = FastAPI()
auth_config = AuthProviderConfig(
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.CUSTOM,
config={"endpoint": mock_auth_endpoint},
)
@ -78,7 +76,7 @@ def http_app(mock_auth_endpoint):
@pytest.fixture
def k8s_app():
app = FastAPI()
auth_config = AuthProviderConfig(
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.KUBERNETES,
config={"api_server_url": "https://kubernetes.default.svc"},
)
@ -118,7 +116,7 @@ def mock_scope():
@pytest.fixture
def mock_http_middleware(mock_auth_endpoint):
mock_app = AsyncMock()
auth_config = AuthProviderConfig(
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.CUSTOM,
config={"endpoint": mock_auth_endpoint},
)
@ -128,7 +126,7 @@ def mock_http_middleware(mock_auth_endpoint):
@pytest.fixture
def mock_k8s_middleware():
mock_app = AsyncMock()
auth_config = AuthProviderConfig(
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.KUBERNETES,
config={"api_server_url": "https://kubernetes.default.svc"},
)
@ -284,116 +282,13 @@ async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope):
assert attributes["roles"] == ["test.jwt.token"]
# Kubernetes Tests
def test_missing_auth_header_k8s(k8s_client):
response = k8s_client.get("/test")
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
def test_invalid_auth_header_format_k8s(k8s_client):
response = k8s_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
@patch("kubernetes.client.ApiClient")
def test_valid_k8s_authentication(mock_api_client, k8s_client, valid_token):
# Mock the Kubernetes client
mock_client = AsyncMock()
mock_api_client.return_value = mock_client
# Mock successful token validation
mock_client.set_default_header = AsyncMock()
# Mock the token validation to return valid access attributes
with patch("llama_stack.distribution.server.auth_providers.KubernetesAuthProvider.validate_token") as mock_validate:
mock_validate.return_value = TokenValidationResult(
principal="test-principal",
access_attributes=AccessAttributes(
roles=["admin"], teams=["ml-team"], projects=["llama-3"], namespaces=["research"]
),
)
response = k8s_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"}
@patch("kubernetes.client.ApiClient")
def test_invalid_k8s_authentication(mock_api_client, k8s_client, invalid_token):
# Mock the Kubernetes client
mock_client = AsyncMock()
mock_api_client.return_value = mock_client
# Mock failed token validation by raising an exception
with patch("llama_stack.distribution.server.auth_providers.KubernetesAuthProvider.validate_token") as mock_validate:
mock_validate.side_effect = ValueError("Invalid or expired token")
response = k8s_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
assert response.status_code == 401
assert "Invalid or expired token" in response.json()["error"]["message"]
@pytest.mark.asyncio
async def test_k8s_middleware_with_access_attributes(mock_k8s_middleware, mock_scope):
middleware, mock_app = mock_k8s_middleware
mock_receive = AsyncMock()
mock_send = AsyncMock()
with patch("kubernetes.client.ApiClient") as mock_api_client:
mock_client = AsyncMock()
mock_api_client.return_value = mock_client
# Mock successful token validation
mock_client.set_default_header = AsyncMock()
# Mock token payload with access attributes
mock_token_parts = ["header", "eyJzdWIiOiJhZG1pbiIsImdyb3VwcyI6WyJtbC10ZWFtIl19", "signature"]
mock_scope["headers"][1] = (b"authorization", f"Bearer {'.'.join(mock_token_parts)}".encode())
await middleware(mock_scope, mock_receive, mock_send)
assert "user_attributes" in mock_scope
assert mock_scope["user_attributes"]["roles"] == ["admin"]
assert mock_scope["user_attributes"]["teams"] == ["ml-team"]
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
@pytest.mark.asyncio
async def test_k8s_middleware_no_attributes(mock_k8s_middleware, mock_scope):
"""Test middleware behavior with no access attributes"""
middleware, mock_app = mock_k8s_middleware
mock_receive = AsyncMock()
mock_send = AsyncMock()
with patch("kubernetes.client.ApiClient") as mock_api_client:
mock_client = AsyncMock()
mock_api_client.return_value = mock_client
# Mock successful token validation
mock_client.set_default_header = AsyncMock()
# Mock token payload without access attributes
mock_token_parts = ["header", "eyJzdWIiOiJhZG1pbiJ9", "signature"]
mock_scope["headers"][1] = (b"authorization", f"Bearer {'.'.join(mock_token_parts)}".encode())
await middleware(mock_scope, mock_receive, mock_send)
assert "user_attributes" in mock_scope
attributes = mock_scope["user_attributes"]
assert "roles" in attributes
assert attributes["roles"] == ["admin"]
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
# oauth2 token provider tests
@pytest.fixture
def oauth2_app():
app = FastAPI()
auth_config = AuthProviderConfig(
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": {
@ -530,7 +425,7 @@ def mock_introspection_endpoint():
@pytest.fixture
def introspection_app(mock_introspection_endpoint):
app = FastAPI()
auth_config = AuthProviderConfig(
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": None,
@ -549,7 +444,7 @@ def introspection_app(mock_introspection_endpoint):
@pytest.fixture
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
app = FastAPI()
auth_config = AuthProviderConfig(
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": None,