mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-28 01:01:59 +00:00
Merge branch 'main' into patch-metadata
This commit is contained in:
commit
5a807da6af
6 changed files with 2551 additions and 2257 deletions
|
|
@ -93,7 +93,7 @@ class AuthenticationMiddleware:
|
|||
|
||||
# Validate token and get access attributes
|
||||
try:
|
||||
access_attributes = await self.auth_provider.validate_token(token, scope)
|
||||
validation_result = await self.auth_provider.validate_token(token, scope)
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Authentication request timed out")
|
||||
return await self._send_auth_error(send, "Authentication service timeout")
|
||||
|
|
@ -105,17 +105,20 @@ class AuthenticationMiddleware:
|
|||
return await self._send_auth_error(send, "Authentication service error")
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if access_attributes:
|
||||
user_attributes = access_attributes.model_dump(exclude_none=True)
|
||||
if validation_result.access_attributes:
|
||||
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to token by default")
|
||||
user_attributes = {
|
||||
"namespaces": [token],
|
||||
"roles": [token],
|
||||
}
|
||||
|
||||
# Store attributes in request scope
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
|
||||
scope["principal"] = validation_result.principal
|
||||
logger.debug(
|
||||
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
|
||||
)
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,12 +5,14 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -18,9 +20,11 @@ from llama_stack.log import get_logger
|
|||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
class TokenValidationResult(BaseModel):
|
||||
principal: str | None = Field(
|
||||
default=None,
|
||||
description="The principal (username or persistent identifier) of the authenticated user",
|
||||
)
|
||||
access_attributes: AccessAttributes | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
|
|
@ -43,6 +47,10 @@ class AuthResponse(BaseModel):
|
|||
""",
|
||||
)
|
||||
|
||||
|
||||
class AuthResponse(TokenValidationResult):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
message: str | None = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
|
@ -69,6 +77,7 @@ class AuthProviderType(str, Enum):
|
|||
|
||||
KUBERNETES = "kubernetes"
|
||||
CUSTOM = "custom"
|
||||
OAUTH2_TOKEN = "oauth2_token"
|
||||
|
||||
|
||||
class AuthProviderConfig(BaseModel):
|
||||
|
|
@ -82,7 +91,7 @@ class AuthProvider(ABC):
|
|||
"""Abstract base class for authentication providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
"""Validate a token and return access attributes."""
|
||||
pass
|
||||
|
||||
|
|
@ -92,12 +101,16 @@ class AuthProvider(ABC):
|
|||
pass
|
||||
|
||||
|
||||
class KubernetesAuthProviderConfig(BaseModel):
    """Settings consumed by the Kubernetes authentication provider."""

    # Used as the host of the Kubernetes API client configuration.
    api_server_url: str

    # Optional CA bundle path for the API client. The provider enables TLS
    # verification only when this is set (verify_ssl = bool(ca_cert_path)).
    ca_cert_path: str | None = None
|
||||
|
||||
|
||||
class KubernetesAuthProvider(AuthProvider):
|
||||
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
|
||||
|
||||
def __init__(self, config: dict[str, str]):
|
||||
self.api_server_url = config["api_server_url"]
|
||||
self.ca_cert_path = config.get("ca_cert_path")
|
||||
def __init__(self, config: KubernetesAuthProviderConfig):
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
async def _get_client(self):
|
||||
|
|
@ -110,16 +123,16 @@ class KubernetesAuthProvider(AuthProvider):
|
|||
|
||||
# Configure the client
|
||||
configuration = client.Configuration()
|
||||
configuration.host = self.api_server_url
|
||||
if self.ca_cert_path:
|
||||
configuration.ssl_ca_cert = self.ca_cert_path
|
||||
configuration.verify_ssl = bool(self.ca_cert_path)
|
||||
configuration.host = self.config.api_server_url
|
||||
if self.config.ca_cert_path:
|
||||
configuration.ssl_ca_cert = self.config.ca_cert_path
|
||||
configuration.verify_ssl = bool(self.config.ca_cert_path)
|
||||
|
||||
# Create API client
|
||||
self._client = ApiClient(configuration)
|
||||
return self._client
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
"""Validate a Kubernetes token and return access attributes."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
|
|
@ -146,9 +159,12 @@ class KubernetesAuthProvider(AuthProvider):
|
|||
username = payload.get("sub", "")
|
||||
groups = payload.get("groups", [])
|
||||
|
||||
return AccessAttributes(
|
||||
roles=[username], # Use username as a role
|
||||
teams=groups, # Use Kubernetes groups as teams
|
||||
return TokenValidationResult(
|
||||
principal=username,
|
||||
access_attributes=AccessAttributes(
|
||||
roles=[username], # Use username as a role
|
||||
teams=groups, # Use Kubernetes groups as teams
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -162,18 +178,125 @@ class KubernetesAuthProvider(AuthProvider):
|
|||
self._client = None
|
||||
|
||||
|
||||
def get_attributes_from_claims(claims: dict[str, str | list[str]], mapping: dict[str, str]) -> AccessAttributes:
    """Build access attributes from token claims using a claim-to-attribute mapping.

    Args:
        claims: Decoded token claims. A string claim is split on whitespace into
            individual values; a list claim contributes its elements as-is.
        mapping: Maps a claim name to the name of the ``AccessAttributes`` field
            that receives its values. Several claims may target the same field;
            their values are concatenated in mapping order.

    Returns:
        A freshly created ``AccessAttributes``. Claims that are absent, that map
        to a name ``AccessAttributes`` does not have, or whose value is neither
        a string nor a list are ignored.
    """
    attributes = AccessAttributes()
    for claim_key, attribute_key in mapping.items():
        if claim_key not in claims or not hasattr(attributes, attribute_key):
            continue

        claim = claims[claim_key]
        if isinstance(claim, str):
            values = claim.split()
        elif isinstance(claim, (list, tuple)):
            # Copy so the attribute never aliases (and later mutates) the
            # caller's claims payload.
            values = list(claim)
        else:
            # Numbers, booleans, null, nested objects: there is no defined way
            # to turn these into attribute values, so grant nothing for them
            # instead of failing on `.split()`.
            continue

        existing = getattr(attributes, attribute_key)
        # Always assign a new list rather than extending in place, so a list
        # that came from `claims` is never modified.
        setattr(attributes, attribute_key, [*existing, *values] if existing else values)
    return attributes
|
||||
|
||||
|
||||
class OAuth2TokenAuthProviderConfig(BaseModel):
    """Configuration for the OAuth2 (JWT) token authentication provider."""

    # The JWKS URI for collecting public keys
    jwks_uri: str
    # Seconds a fetched key set is reused before it is fetched again
    cache_ttl: int = 3600
    # Audience passed to JWT verification when decoding a token
    audience: str = "llama-stack"
    # Token claim name -> AccessAttributes field that receives the claim's values
    claims_mapping: dict[str, str] = Field(
        default_factory=lambda: {
            "sub": "roles",
            "username": "roles",
            "groups": "teams",
            "team": "teams",
            "project": "projects",
            "tenant": "namespaces",
            "namespace": "namespaces",
        },
    )

    # `field_validator` must be the outermost decorator, with `classmethod`
    # directly beneath it; this is the order pydantic documents. With the two
    # swapped the validator is not attached to the field.
    @field_validator("claims_mapping")
    @classmethod
    def validate_claims_mapping(cls, v):
        """Reject mappings whose target is empty or not an ``AccessAttributes`` field.

        Raises:
            ValueError: If any mapping value is empty or is not a field name of
                ``AccessAttributes``.
        """
        for key, value in v.items():
            if not value:
                raise ValueError(f"claims_mapping value cannot be empty: {key}")
            if value not in AccessAttributes.model_fields:
                raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
        return v
|
||||
|
||||
|
||||
class OAuth2TokenAuthProvider(AuthProvider):
    """
    JWT token authentication provider that validates a JWT token and extracts access attributes.

    This should be the standard authentication provider for most use cases.
    """

    def __init__(self, config: OAuth2TokenAuthProviderConfig):
        self.config = config
        # time.time() of the last successful JWKS fetch; 0.0 forces a fetch on first use.
        self._jwks_at: float = 0.0
        # Key ID -> full JWK object as returned by the JWKS endpoint.
        self._jwks: dict[str, dict] = {}

    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
        """Verify a JWT against the cached JWKS and map its claims to access attributes.

        Args:
            token: The encoded JWT presented by the caller.
            scope: Accepted for interface compatibility; not used by this provider.

        Returns:
            A ``TokenValidationResult`` whose principal is the token's ``sub``
            claim and whose access attributes are derived from the claims via
            ``config.claims_mapping``.

        Raises:
            ValueError: If the token cannot be verified or carries no ``sub`` claim.
        """
        await self._refresh_jwks()

        try:
            header = jwt.get_unverified_header(token)
            kid = header["kid"]
            if kid not in self._jwks:
                raise ValueError(f"Unknown key ID: {kid}")
            key_data = self._jwks[kid]
            # NOTE(review): `alg` is read from the unverified header; consider
            # pinning an allow-list of algorithms in the provider config.
            algorithm = header.get("alg", "RS256")
            claims = jwt.decode(
                token,
                key_data,
                algorithms=[algorithm],
                audience=self.config.audience,
                options={"verify_exp": True},
            )
        except Exception as exc:
            # Do not echo the token: exception messages end up in logs, and the
            # token is a bearer credential. The cause stays chained for debugging.
            raise ValueError("Invalid JWT token") from exc

        # There are other standard claims, the most relevant of which is `scope`.
        # We should incorporate these into the access attributes.
        if "sub" not in claims:
            # Report as a validation failure rather than letting a bare KeyError escape.
            raise ValueError("Invalid JWT token: missing 'sub' claim")
        principal = claims["sub"]
        access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
        return TokenValidationResult(
            principal=principal,
            access_attributes=access_attributes,
        )

    async def close(self):
        """Do nothing: no long-lived HTTP client is held (each JWKS fetch uses its own)."""

    async def _refresh_jwks(self) -> None:
        """Re-fetch the JWKS when the cached copy is older than ``config.cache_ttl`` seconds.

        Raises:
            httpx.HTTPError: If the request fails or returns an error status.
            KeyError: If the response lacks ``keys`` or a key lacks ``kid``.
        """
        if time.time() - self._jwks_at <= self.config.cache_ttl:
            return

        async with httpx.AsyncClient() as client:
            res = await client.get(self.config.jwks_uri, timeout=5)
            res.raise_for_status()
            jwks_data = res.json()["keys"]

        # Build the replacement map completely before swapping it in, so a
        # malformed response cannot leave the cache empty or half-filled.
        # Store the entire key object as it may be needed for different algorithms.
        self._jwks = {k["kid"]: k for k in jwks_data}
        self._jwks_at = time.time()
|
||||
|
||||
|
||||
class CustomAuthProviderConfig(BaseModel):
    """Settings consumed by the custom (external endpoint) authentication provider."""

    # URL the provider POSTs its JSON authentication request to.
    endpoint: str
|
||||
|
||||
|
||||
class CustomAuthProvider(AuthProvider):
|
||||
"""Custom authentication provider that uses an external endpoint."""
|
||||
|
||||
def __init__(self, config: dict[str, str]):
|
||||
self.endpoint = config["endpoint"]
|
||||
def __init__(self, config: CustomAuthProviderConfig):
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
"""Validate a token using the custom authentication endpoint."""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Authentication endpoint not configured")
|
||||
|
||||
if scope is None:
|
||||
scope = {}
|
||||
|
||||
|
|
@ -202,7 +325,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.endpoint,
|
||||
self.config.endpoint,
|
||||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
|
|
@ -214,19 +337,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
try:
|
||||
response_data = response.json()
|
||||
auth_response = AuthResponse(**response_data)
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if auth_response.access_attributes:
|
||||
return auth_response.access_attributes
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to api_key by default")
|
||||
user_attributes = {
|
||||
"namespaces": [token],
|
||||
}
|
||||
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||
return auth_response.access_attributes
|
||||
return auth_response
|
||||
except Exception as e:
|
||||
logger.exception("Error parsing authentication response")
|
||||
raise ValueError("Invalid authentication response format") from e
|
||||
|
|
@ -253,9 +364,11 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
|
|||
provider_type = config.provider_type.lower()
|
||||
|
||||
if provider_type == "kubernetes":
|
||||
return KubernetesAuthProvider(config.config)
|
||||
return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
|
||||
elif provider_type == "custom":
|
||||
return CustomAuthProvider(config.config)
|
||||
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
||||
elif provider_type == "oauth2_token":
|
||||
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
||||
else:
|
||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue