mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-07 20:50:52 +00:00
Merge branch 'main' into prompt-api
This commit is contained in:
commit
931f7da61b
14 changed files with 728 additions and 191 deletions
|
@ -7,6 +7,7 @@
|
|||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Literal, Self
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
|
@ -212,6 +213,7 @@ class AuthProviderType(StrEnum):
|
|||
OAUTH2_TOKEN = "oauth2_token"
|
||||
GITHUB_TOKEN = "github_token"
|
||||
CUSTOM = "custom"
|
||||
KUBERNETES = "kubernetes"
|
||||
|
||||
|
||||
class OAuth2TokenAuthConfig(BaseModel):
|
||||
|
@ -282,8 +284,45 @@ class GitHubTokenAuthConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class KubernetesAuthProviderConfig(BaseModel):
|
||||
"""Configuration for Kubernetes authentication provider."""
|
||||
|
||||
type: Literal[AuthProviderType.KUBERNETES] = AuthProviderType.KUBERNETES
|
||||
api_server_url: str = Field(
|
||||
default="https://kubernetes.default.svc",
|
||||
description="Kubernetes API server URL (e.g., https://api.cluster.domain:6443)",
|
||||
)
|
||||
verify_tls: bool = Field(default=True, description="Whether to verify TLS certificates")
|
||||
tls_cafile: Path | None = Field(default=None, description="Path to CA certificate file for TLS verification")
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"username": "roles",
|
||||
"groups": "roles",
|
||||
},
|
||||
description="Mapping of Kubernetes user claims to access attributes",
|
||||
)
|
||||
|
||||
@field_validator("api_server_url")
|
||||
@classmethod
|
||||
def validate_api_server_url(cls, v):
|
||||
parsed = urlparse(v)
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
raise ValueError(f"api_server_url must be a valid URL with scheme and host: {v}")
|
||||
if parsed.scheme not in ["http", "https"]:
|
||||
raise ValueError(f"api_server_url scheme must be http or https: {v}")
|
||||
return v
|
||||
|
||||
@field_validator("claims_mapping")
|
||||
@classmethod
|
||||
def validate_claims_mapping(cls, v):
|
||||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
return v
|
||||
|
||||
|
||||
AuthProviderConfig = Annotated[
|
||||
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig,
|
||||
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig | KubernetesAuthProviderConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
|
|
@ -8,16 +8,18 @@ import ssl
|
|||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Lock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from urllib.parse import parse_qs, urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.errors import TokenValidationError
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationConfig,
|
||||
CustomAuthConfig,
|
||||
GitHubTokenAuthConfig,
|
||||
KubernetesAuthProviderConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
User,
|
||||
)
|
||||
|
@ -162,7 +164,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
auth=auth,
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
if response.status_code != httpx.codes.OK:
|
||||
logger.warning(f"Token introspection failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Token introspection failed: {response.status_code}")
|
||||
|
||||
|
@ -272,7 +274,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
if response.status_code != httpx.codes.OK:
|
||||
logger.warning(f"Authentication failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Authentication failed: {response.status_code}")
|
||||
|
||||
|
@ -374,6 +376,89 @@ async def _get_github_user_info(access_token: str, github_api_base_url: str) ->
|
|||
}
|
||||
|
||||
|
||||
class KubernetesAuthProvider(AuthProvider):
|
||||
"""
|
||||
Kubernetes authentication provider that validates tokens using the Kubernetes SelfSubjectReview API.
|
||||
This provider integrates with Kubernetes API server by using the
|
||||
/apis/authentication.k8s.io/v1/selfsubjectreviews endpoint to validate tokens and extract user information.
|
||||
"""
|
||||
|
||||
def __init__(self, config: KubernetesAuthProviderConfig):
|
||||
self.config = config
|
||||
|
||||
def _httpx_verify_value(self) -> bool | str:
|
||||
"""
|
||||
Build the value for httpx's `verify` parameter.
|
||||
- False disables verification.
|
||||
- Path string points to a CA bundle.
|
||||
- True uses system defaults.
|
||||
"""
|
||||
if not self.config.verify_tls:
|
||||
return False
|
||||
if self.config.tls_cafile:
|
||||
return self.config.tls_cafile.as_posix()
|
||||
return True
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a token using Kubernetes SelfSubjectReview API endpoint."""
|
||||
# Build the Kubernetes SelfSubjectReview API endpoint URL
|
||||
review_api_url = urljoin(self.config.api_server_url, "/apis/authentication.k8s.io/v1/selfsubjectreviews")
|
||||
|
||||
# Create SelfSubjectReview request body
|
||||
review_request = {"apiVersion": "authentication.k8s.io/v1", "kind": "SelfSubjectReview"}
|
||||
verify = self._httpx_verify_value()
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(verify=verify, timeout=10.0) as client:
|
||||
response = await client.post(
|
||||
review_api_url,
|
||||
json=review_request,
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == httpx.codes.UNAUTHORIZED:
|
||||
raise TokenValidationError("Invalid token")
|
||||
if response.status_code != httpx.codes.CREATED:
|
||||
logger.warning(f"Kubernetes SelfSubjectReview API failed with status code: {response.status_code}")
|
||||
raise TokenValidationError(f"Token validation failed: {response.status_code}")
|
||||
|
||||
review_response = response.json()
|
||||
# Extract user information from SelfSubjectReview response
|
||||
status = review_response.get("status", {})
|
||||
if not status:
|
||||
raise ValueError("No status found in SelfSubjectReview response")
|
||||
|
||||
user_info = status.get("userInfo", {})
|
||||
if not user_info:
|
||||
raise ValueError("No userInfo found in SelfSubjectReview response")
|
||||
|
||||
username = user_info.get("username")
|
||||
if not username:
|
||||
raise ValueError("No username found in SelfSubjectReview response")
|
||||
|
||||
# Build user attributes from Kubernetes user info
|
||||
user_attributes = get_attributes_from_claims(user_info, self.config.claims_mapping)
|
||||
|
||||
return User(
|
||||
principal=username,
|
||||
attributes=user_attributes,
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("Kubernetes SelfSubjectReview API request timed out")
|
||||
raise ValueError("Token validation timeout") from None
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during token validation: {str(e)}")
|
||||
raise ValueError(f"Token validation error: {str(e)}") from e
|
||||
|
||||
async def close(self):
|
||||
"""Close any resources."""
|
||||
pass
|
||||
|
||||
|
||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||
"""Factory function to create the appropriate auth provider."""
|
||||
provider_config = config.provider_config
|
||||
|
@ -384,5 +469,7 @@ def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
|||
return OAuth2TokenAuthProvider(provider_config)
|
||||
elif isinstance(provider_config, GitHubTokenAuthConfig):
|
||||
return GitHubTokenAuthProvider(provider_config)
|
||||
elif isinstance(provider_config, KubernetesAuthProviderConfig):
|
||||
return KubernetesAuthProvider(provider_config)
|
||||
else:
|
||||
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue