mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
chore: remove k8s auth in favor of k8s jwks endpoint (#2216)
# What does this PR do? Kubernetes has exposed a JWKS endpoint since version 1.20, which we can use with our recent OAuth2 implementation. The CI test has been kept intact for validation. Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
2890243107
commit
c25acedbcd
9 changed files with 147 additions and 359 deletions
|
@ -220,14 +220,14 @@ class LoggingConfig(BaseModel):
|
|||
class AuthProviderType(str, Enum):
|
||||
"""Supported authentication provider types."""
|
||||
|
||||
KUBERNETES = "kubernetes"
|
||||
OAUTH2_TOKEN = "oauth2_token"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class AuthenticationConfig(BaseModel):
|
||||
provider_type: AuthProviderType = Field(
|
||||
...,
|
||||
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
|
||||
description="Type of authentication provider",
|
||||
)
|
||||
config: dict[str, Any] = Field(
|
||||
...,
|
||||
|
|
|
@ -8,7 +8,8 @@ import json
|
|||
|
||||
import httpx
|
||||
|
||||
from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
||||
from llama_stack.distribution.server.auth_providers import create_auth_provider
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
@ -77,7 +78,7 @@ class AuthenticationMiddleware:
|
|||
access resources that don't have access_attributes defined.
|
||||
"""
|
||||
|
||||
def __init__(self, app, auth_config: AuthProviderConfig):
|
||||
def __init__(self, app, auth_config: AuthenticationConfig):
|
||||
self.app = app
|
||||
self.auth_provider = create_auth_provider(auth_config)
|
||||
|
||||
|
|
|
@ -4,13 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import ssl
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Lock
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
|
@ -18,7 +16,7 @@ from jose import jwt
|
|||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
@ -76,21 +74,6 @@ class AuthRequest(BaseModel):
|
|||
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
|
||||
|
||||
|
||||
class AuthProviderType(str, Enum):
|
||||
"""Supported authentication provider types."""
|
||||
|
||||
KUBERNETES = "kubernetes"
|
||||
CUSTOM = "custom"
|
||||
OAUTH2_TOKEN = "oauth2_token"
|
||||
|
||||
|
||||
class AuthProviderConfig(BaseModel):
|
||||
"""Base configuration for authentication providers."""
|
||||
|
||||
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
|
||||
config: dict[str, Any] = Field(..., description="Provider-specific configuration")
|
||||
|
||||
|
||||
class AuthProvider(ABC):
|
||||
"""Abstract base class for authentication providers."""
|
||||
|
||||
|
@ -105,83 +88,6 @@ class AuthProvider(ABC):
|
|||
pass
|
||||
|
||||
|
||||
class KubernetesAuthProviderConfig(BaseModel):
|
||||
api_server_url: str
|
||||
ca_cert_path: str | None = None
|
||||
|
||||
|
||||
class KubernetesAuthProvider(AuthProvider):
|
||||
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
|
||||
|
||||
def __init__(self, config: KubernetesAuthProviderConfig):
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
async def _get_client(self):
|
||||
"""Get or create a Kubernetes client."""
|
||||
if self._client is None:
|
||||
# kubernetes-client has not async support, see:
|
||||
# https://github.com/kubernetes-client/python/issues/323
|
||||
from kubernetes import client
|
||||
from kubernetes.client import ApiClient
|
||||
|
||||
# Configure the client
|
||||
configuration = client.Configuration()
|
||||
configuration.host = self.config.api_server_url
|
||||
if self.config.ca_cert_path:
|
||||
configuration.ssl_ca_cert = self.config.ca_cert_path
|
||||
configuration.verify_ssl = bool(self.config.ca_cert_path)
|
||||
|
||||
# Create API client
|
||||
self._client = ApiClient(configuration)
|
||||
return self._client
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
"""Validate a Kubernetes token and return access attributes."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
|
||||
# Set the token in the client
|
||||
client.set_default_header("Authorization", f"Bearer {token}")
|
||||
|
||||
# Make a request to validate the token
|
||||
# We use the /api endpoint which requires authentication
|
||||
from kubernetes.client import CoreV1Api
|
||||
|
||||
api = CoreV1Api(client)
|
||||
api.get_api_resources(_request_timeout=3.0) # Set timeout for this specific request
|
||||
|
||||
# If we get here, the token is valid
|
||||
# Extract user info from the token claims
|
||||
import base64
|
||||
|
||||
# Decode the token (without verification since we've already validated it)
|
||||
token_parts = token.split(".")
|
||||
payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)))
|
||||
|
||||
# Extract user information from the token
|
||||
username = payload.get("sub", "")
|
||||
groups = payload.get("groups", [])
|
||||
|
||||
return TokenValidationResult(
|
||||
principal=username,
|
||||
access_attributes=AccessAttributes(
|
||||
roles=[username], # Use username as a role
|
||||
teams=groups, # Use Kubernetes groups as teams
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to validate Kubernetes token")
|
||||
raise ValueError("Invalid or expired token") from e
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client."""
|
||||
if self._client:
|
||||
self._client.close()
|
||||
self._client = None
|
||||
|
||||
|
||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
|
||||
attributes = AccessAttributes()
|
||||
for claim_key, attribute_key in mapping.items():
|
||||
|
@ -212,11 +118,13 @@ class OAuth2IntrospectionConfig(BaseModel):
|
|||
client_id: str
|
||||
client_secret: str
|
||||
send_secret_in_body: bool = False
|
||||
tls_cafile: str | None = None
|
||||
|
||||
|
||||
class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||
audience: str = "llama-stack"
|
||||
verify_tls: bool = True
|
||||
tls_cafile: Path | None = None
|
||||
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"sub": "roles",
|
||||
|
@ -265,16 +173,14 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
if self.config.jwks:
|
||||
return await self.validate_jwt_token(token, self.config.jwks, scope)
|
||||
return await self.validate_jwt_token(token, scope)
|
||||
if self.config.introspection:
|
||||
return await self.introspect_token(token, self.config.introspection, scope)
|
||||
return await self.introspect_token(token, scope)
|
||||
raise ValueError("One of jwks or introspection must be configured")
|
||||
|
||||
async def validate_jwt_token(
|
||||
self, token: str, config: OAuth2JWKSConfig, scope: dict | None = None
|
||||
) -> TokenValidationResult:
|
||||
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
"""Validate a token using the JWT token."""
|
||||
await self._refresh_jwks(config)
|
||||
await self._refresh_jwks()
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
|
@ -288,7 +194,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
key_data,
|
||||
algorithms=[algorithm],
|
||||
audience=self.config.audience,
|
||||
options={"verify_exp": True},
|
||||
issuer=self.config.issuer,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Invalid JWT token: {token}") from exc
|
||||
|
@ -302,26 +208,27 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
access_attributes=access_attributes,
|
||||
)
|
||||
|
||||
async def introspect_token(
|
||||
self, token: str, config: OAuth2IntrospectionConfig, scope: dict | None = None
|
||||
) -> TokenValidationResult:
|
||||
async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
"""Validate a token using token introspection as defined by RFC 7662."""
|
||||
form = {
|
||||
"token": token,
|
||||
}
|
||||
if config.send_secret_in_body:
|
||||
form["client_id"] = config.client_id
|
||||
form["client_secret"] = config.client_secret
|
||||
if self.config.introspection is None:
|
||||
raise ValueError("Introspection is not configured")
|
||||
|
||||
if self.config.introspection.send_secret_in_body:
|
||||
form["client_id"] = self.config.introspection.client_id
|
||||
form["client_secret"] = self.config.introspection.client_secret
|
||||
auth = None
|
||||
else:
|
||||
auth = (config.client_id, config.client_secret)
|
||||
auth = (self.config.introspection.client_id, self.config.introspection.client_secret)
|
||||
ssl_ctxt = None
|
||||
if config.tls_cafile:
|
||||
ssl_ctxt = ssl.create_default_context(cafile=config.tls_cafile)
|
||||
if self.config.tls_cafile:
|
||||
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
|
||||
try:
|
||||
async with httpx.AsyncClient(verify=ssl_ctxt) as client:
|
||||
response = await client.post(
|
||||
config.url,
|
||||
self.config.introspection.url,
|
||||
data=form,
|
||||
auth=auth,
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
|
@ -352,11 +259,24 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
async def close(self):
|
||||
pass
|
||||
|
||||
async def _refresh_jwks(self, config: OAuth2JWKSConfig) -> None:
|
||||
async def _refresh_jwks(self) -> None:
|
||||
"""
|
||||
Refresh the JWKS cache.
|
||||
|
||||
This is a simple cache that expires after a certain amount of time (defined by `cache_ttl`).
|
||||
If the cache is expired, we refresh the JWKS from the JWKS URI.
|
||||
|
||||
Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
|
||||
* It doesn't have user authentication flows
|
||||
* It doesn't have refresh tokens
|
||||
"""
|
||||
async with self._jwks_lock:
|
||||
if time.time() - self._jwks_at > config.cache_ttl:
|
||||
async with httpx.AsyncClient() as client:
|
||||
res = await client.get(config.uri, timeout=5)
|
||||
if self.config.jwks is None:
|
||||
raise ValueError("JWKS is not configured")
|
||||
if time.time() - self._jwks_at > self.config.jwks.cache_ttl:
|
||||
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
|
||||
async with httpx.AsyncClient(verify=verify) as client:
|
||||
res = await client.get(self.config.jwks.uri, timeout=5)
|
||||
res.raise_for_status()
|
||||
jwks_data = res.json()["keys"]
|
||||
updated = {}
|
||||
|
@ -443,13 +363,11 @@ class CustomAuthProvider(AuthProvider):
|
|||
self._client = None
|
||||
|
||||
|
||||
def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
|
||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||
"""Factory function to create the appropriate auth provider."""
|
||||
provider_type = config.provider_type.lower()
|
||||
|
||||
if provider_type == "kubernetes":
|
||||
return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
|
||||
elif provider_type == "custom":
|
||||
if provider_type == "custom":
|
||||
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
||||
elif provider_type == "oauth2_token":
|
||||
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue