chore: remove k8s auth in favor of k8s jwks endpoint (#2216)

# What does this PR do?

Kubernetes since 1.20 exposes a JWKS endpoint that we can use with our
recent oauth2 recent implementation.
The CI test has been kept intact for validation.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-05-21 16:23:54 +02:00 committed by GitHub
parent 2890243107
commit c25acedbcd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 147 additions and 359 deletions

View file

@ -220,14 +220,14 @@ class LoggingConfig(BaseModel):
class AuthProviderType(str, Enum):
"""Supported authentication provider types."""
KUBERNETES = "kubernetes"
OAUTH2_TOKEN = "oauth2_token"
CUSTOM = "custom"
class AuthenticationConfig(BaseModel):
provider_type: AuthProviderType = Field(
...,
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
description="Type of authentication provider",
)
config: dict[str, Any] = Field(
...,

View file

@ -8,7 +8,8 @@ import json
import httpx
from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider
from llama_stack.distribution.datatypes import AuthenticationConfig
from llama_stack.distribution.server.auth_providers import create_auth_provider
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -77,7 +78,7 @@ class AuthenticationMiddleware:
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_config: AuthProviderConfig):
def __init__(self, app, auth_config: AuthenticationConfig):
self.app = app
self.auth_provider = create_auth_provider(auth_config)

View file

@ -4,13 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import ssl
import time
from abc import ABC, abstractmethod
from asyncio import Lock
from enum import Enum
from typing import Any
from pathlib import Path
from urllib.parse import parse_qs
import httpx
@ -18,7 +16,7 @@ from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -76,21 +74,6 @@ class AuthRequest(BaseModel):
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthProviderType(str, Enum):
"""Supported authentication provider types."""
KUBERNETES = "kubernetes"
CUSTOM = "custom"
OAUTH2_TOKEN = "oauth2_token"
class AuthProviderConfig(BaseModel):
"""Base configuration for authentication providers."""
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
config: dict[str, Any] = Field(..., description="Provider-specific configuration")
class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@ -105,83 +88,6 @@ class AuthProvider(ABC):
pass
class KubernetesAuthProviderConfig(BaseModel):
api_server_url: str
ca_cert_path: str | None = None
class KubernetesAuthProvider(AuthProvider):
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
def __init__(self, config: KubernetesAuthProviderConfig):
self.config = config
self._client = None
async def _get_client(self):
"""Get or create a Kubernetes client."""
if self._client is None:
# kubernetes-client has not async support, see:
# https://github.com/kubernetes-client/python/issues/323
from kubernetes import client
from kubernetes.client import ApiClient
# Configure the client
configuration = client.Configuration()
configuration.host = self.config.api_server_url
if self.config.ca_cert_path:
configuration.ssl_ca_cert = self.config.ca_cert_path
configuration.verify_ssl = bool(self.config.ca_cert_path)
# Create API client
self._client = ApiClient(configuration)
return self._client
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a Kubernetes token and return access attributes."""
try:
client = await self._get_client()
# Set the token in the client
client.set_default_header("Authorization", f"Bearer {token}")
# Make a request to validate the token
# We use the /api endpoint which requires authentication
from kubernetes.client import CoreV1Api
api = CoreV1Api(client)
api.get_api_resources(_request_timeout=3.0) # Set timeout for this specific request
# If we get here, the token is valid
# Extract user info from the token claims
import base64
# Decode the token (without verification since we've already validated it)
token_parts = token.split(".")
payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)))
# Extract user information from the token
username = payload.get("sub", "")
groups = payload.get("groups", [])
return TokenValidationResult(
principal=username,
access_attributes=AccessAttributes(
roles=[username], # Use username as a role
teams=groups, # Use Kubernetes groups as teams
),
)
except Exception as e:
logger.exception("Failed to validate Kubernetes token")
raise ValueError("Invalid or expired token") from e
async def close(self):
"""Close the HTTP client."""
if self._client:
self._client.close()
self._client = None
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
attributes = AccessAttributes()
for claim_key, attribute_key in mapping.items():
@ -212,11 +118,13 @@ class OAuth2IntrospectionConfig(BaseModel):
client_id: str
client_secret: str
send_secret_in_body: bool = False
tls_cafile: str | None = None
class OAuth2TokenAuthProviderConfig(BaseModel):
audience: str = "llama-stack"
verify_tls: bool = True
tls_cafile: Path | None = None
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
@ -265,16 +173,14 @@ class OAuth2TokenAuthProvider(AuthProvider):
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
if self.config.jwks:
return await self.validate_jwt_token(token, self.config.jwks, scope)
return await self.validate_jwt_token(token, scope)
if self.config.introspection:
return await self.introspect_token(token, self.config.introspection, scope)
return await self.introspect_token(token, scope)
raise ValueError("One of jwks or introspection must be configured")
async def validate_jwt_token(
self, token: str, config: OAuth2JWKSConfig, scope: dict | None = None
) -> TokenValidationResult:
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the JWT token."""
await self._refresh_jwks(config)
await self._refresh_jwks()
try:
header = jwt.get_unverified_header(token)
@ -288,7 +194,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
key_data,
algorithms=[algorithm],
audience=self.config.audience,
options={"verify_exp": True},
issuer=self.config.issuer,
)
except Exception as exc:
raise ValueError(f"Invalid JWT token: {token}") from exc
@ -302,26 +208,27 @@ class OAuth2TokenAuthProvider(AuthProvider):
access_attributes=access_attributes,
)
async def introspect_token(
self, token: str, config: OAuth2IntrospectionConfig, scope: dict | None = None
) -> TokenValidationResult:
async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using token introspection as defined by RFC 7662."""
form = {
"token": token,
}
if config.send_secret_in_body:
form["client_id"] = config.client_id
form["client_secret"] = config.client_secret
if self.config.introspection is None:
raise ValueError("Introspection is not configured")
if self.config.introspection.send_secret_in_body:
form["client_id"] = self.config.introspection.client_id
form["client_secret"] = self.config.introspection.client_secret
auth = None
else:
auth = (config.client_id, config.client_secret)
auth = (self.config.introspection.client_id, self.config.introspection.client_secret)
ssl_ctxt = None
if config.tls_cafile:
ssl_ctxt = ssl.create_default_context(cafile=config.tls_cafile)
if self.config.tls_cafile:
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
try:
async with httpx.AsyncClient(verify=ssl_ctxt) as client:
response = await client.post(
config.url,
self.config.introspection.url,
data=form,
auth=auth,
timeout=10.0, # Add a reasonable timeout
@ -352,11 +259,24 @@ class OAuth2TokenAuthProvider(AuthProvider):
async def close(self):
pass
async def _refresh_jwks(self, config: OAuth2JWKSConfig) -> None:
async def _refresh_jwks(self) -> None:
"""
Refresh the JWKS cache.
This is a simple cache that expires after a certain amount of time (defined by `cache_ttl`).
If the cache is expired, we refresh the JWKS from the JWKS URI.
Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
* It doesn't have user authentication flows
* It doesn't have refresh tokens
"""
async with self._jwks_lock:
if time.time() - self._jwks_at > config.cache_ttl:
async with httpx.AsyncClient() as client:
res = await client.get(config.uri, timeout=5)
if self.config.jwks is None:
raise ValueError("JWKS is not configured")
if time.time() - self._jwks_at > self.config.jwks.cache_ttl:
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
async with httpx.AsyncClient(verify=verify) as client:
res = await client.get(self.config.jwks.uri, timeout=5)
res.raise_for_status()
jwks_data = res.json()["keys"]
updated = {}
@ -443,13 +363,11 @@ class CustomAuthProvider(AuthProvider):
self._client = None
def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
"""Factory function to create the appropriate auth provider."""
provider_type = config.provider_type.lower()
if provider_type == "kubernetes":
return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
elif provider_type == "custom":
if provider_type == "custom":
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
elif provider_type == "oauth2_token":
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))