This commit is contained in:
ehhuang 2025-06-27 11:39:51 +02:00 committed by GitHub
commit 12265d4222
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 962 additions and 43 deletions

View file

@ -6,7 +6,7 @@
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Any
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field, field_validator
@ -166,20 +166,90 @@ class AuthProviderType(StrEnum):
OAUTH2_TOKEN = "oauth2_token"
CUSTOM = "custom"
GITHUB_OAUTH = "github_oauth"
class AuthenticationConfig(BaseModel):
provider_type: AuthProviderType = Field(
...,
description="Type of authentication provider",
)
class OAuth2TokenAuthConfig(BaseModel):
    """Configuration for OAuth2 token authentication."""

    # Discriminator tag: pinned to the OAuth2 token provider type.
    type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
    # Required, free-form provider settings. Exactly one ``description`` is
    # passed: repeating the keyword in a call is a SyntaxError.
    config: dict[str, Any] = Field(
        ...,
        description="OAuth2 token validation configuration",
    )
    # Optional access rules; empty by default.
    access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
class CustomAuthConfig(BaseModel):
    """Configuration for custom authentication."""

    # Discriminator tag: pinned to the custom provider type.
    type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM

    # Required, free-form settings for the custom authentication endpoint.
    config: dict[str, Any] = Field(..., description="Custom authentication endpoint configuration")

    # Optional access rules; empty by default.
    access_policy: list[AccessRule] = Field(
        default=[],
        description="Rules for determining access to resources",
    )
class GitHubAuthConfig(BaseModel):
    """Configuration for GitHub OAuth authentication.

    This provider implements GitHub OAuth authentication with the following endpoints:
    - /auth/github/login: Initiates OAuth flow
      - Accepts optional redirect_url parameter (must be in allowed_redirect_urls)
      - Redirects to GitHub for authorization
    - /auth/github/callback: Handles OAuth callback (should be used as the callback URL for your GitHub OAuth App)
      - Exchanges authorization code for access token
      - Creates JWT with user info
      - With redirect_url: Redirects to specified URL with JWT in fragment (#token=...)
      - Without redirect_url: Returns JSON response with JWT

    Ensure your GitHub OAuth App's callback URL matches your server's URL + /auth/github/callback.
    """

    # Discriminator tag: pinned to the GitHub OAuth provider type.
    type: Literal[AuthProviderType.GITHUB_OAUTH] = AuthProviderType.GITHUB_OAUTH

    # --- GitHub OAuth App credentials (both required) ---
    github_client_id: str = Field(description="GitHub OAuth App Client ID")
    github_client_secret: str = Field(description="GitHub OAuth App Client Secret")

    # Exact-match allow-list for the optional post-login redirect target.
    allowed_redirect_urls: list[str] = Field(
        default=[], description="Allowed redirect URLs for OAuth callback, e.g. frontend URL"
    )

    # --- Parameters of the JWTs this provider issues ---
    jwt_secret: str = Field(description="Secret for signing JWT tokens")
    jwt_algorithm: str = Field(default="HS256", description="JWT signing algorithm")
    jwt_audience: str = Field(default="llama-stack", description="JWT audience")
    jwt_issuer: str = Field(default="llama-stack-github", description="JWT issuer")
    token_expiry: int = Field(default=86400, description="JWT token expiry in seconds")

    # Extra settings for validating the issued JWTs; empty dict when omitted.
    oauth2_config: dict[str, Any] = Field(
        default_factory=dict, description="Additional OAuth2 token validation configuration"
    )

    # --- Access control ---
    access_policy: list[AccessRule] = Field(
        default=[], description="Rules for determining access to resources"
    )

    # By default the "sub" claim (the GitHub username) is exposed as "roles".
    claims_mapping: dict[str, str] = Field(
        default_factory=lambda: {"sub": "roles"},
        description="Mapping from JWT claims to Llama Stack attributes",
    )
# Tagged union of the supported authentication configurations; the concrete
# model is selected by the value of each member's ``type`` field.
AuthenticationConfig = Annotated[
    OAuth2TokenAuthConfig | CustomAuthConfig | GitHubAuthConfig,
    Field(discriminator="type"),
]
class AuthenticationRequiredError(Exception):
    """Exception type signalling that authentication is required.

    Adds no behaviour to ``Exception``; the sites that raise it are not
    part of this excerpt.
    """

View file

@ -81,13 +81,23 @@ class AuthenticationMiddleware:
def __init__(self, app, auth_config: AuthenticationConfig):
    """Wrap ``app`` with authentication driven by ``auth_config``.

    Builds the provider via ``create_auth_provider`` and caches the path
    prefixes that the provider reports as public.
    """
    provider = create_auth_provider(auth_config)
    self.app = app
    self.auth_provider = provider
    self.public_paths = provider.get_public_paths()
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
# Skip authentication for public paths defined by the provider
path = scope.get("path", "")
if any(path.startswith(public_path) for public_path in self.public_paths):
return await self.app(scope, receive, send)
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
if not auth_header or not auth_header.startswith("Bearer "):
if not auth_header:
error_msg = self.auth_provider.get_auth_error_message(scope)
return await self._send_auth_error(send, error_msg)
if not auth_header.startswith("Bearer "):
return await self._send_auth_error(send, "Missing or invalid Authorization header")
token = auth_header.split("Bearer ", 1)[1]

View file

@ -10,13 +10,19 @@ from abc import ABC, abstractmethod
from asyncio import Lock
from pathlib import Path
from typing import Self
from urllib.parse import parse_qs
from urllib.parse import parse_qs, urlparse
import httpx
from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
from llama_stack.distribution.datatypes import (
AuthenticationConfig,
CustomAuthConfig,
GitHubAuthConfig,
OAuth2TokenAuthConfig,
User,
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -62,6 +68,24 @@ class AuthProvider(ABC):
"""Clean up any resources."""
pass
def setup_routes(self, app):
    """Register provider-specific routes (e.g. OAuth callbacks) on ``app``.

    The base implementation registers nothing; providers that need extra
    routes override this hook.
    """
    return None
def get_public_paths(self) -> list[str]:
    """List the path prefixes that should bypass authentication.

    The base implementation exposes no public paths and returns a fresh
    empty list on every call.
    """
    public_prefixes: list[str] = []
    return public_prefixes
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return provider-specific authentication error message."""
return "Authentication required"
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
attributes: dict[str, list[str]] = {}
@ -232,6 +256,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
async def close(self):
    """No-op cleanup hook: this implementation releases nothing."""
    return None
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return OAuth2-specific authentication error message."""
if self.config.issuer:
return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
elif self.config.introspection:
# Extract domain from introspection URL for a cleaner message
domain = urlparse(self.config.introspection.url).netloc
return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
else:
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
async def _refresh_jwks(self) -> None:
"""
Refresh the JWKS cache.
@ -338,15 +373,25 @@ class CustomAuthProvider(AuthProvider):
await self._client.aclose()
self._client = None
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return custom auth provider-specific authentication error message."""
# Extract domain from endpoint URL for a cleaner message
domain = urlparse(self.config.endpoint).netloc
if domain:
return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
else:
return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
    """Factory function to create the appropriate auth provider.

    Dispatches on the concrete configuration class rather than on a
    ``provider_type`` string.

    Args:
        config: One of ``CustomAuthConfig``, ``OAuth2TokenAuthConfig`` or
            ``GitHubAuthConfig``.

    Returns:
        The provider instance matching ``config``.

    Raises:
        ValueError: If ``config`` is none of the supported classes.
    """
    if isinstance(config, CustomAuthConfig):
        # The free-form ``config`` dict is validated into the provider's own model.
        return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
    if isinstance(config, OAuth2TokenAuthConfig):
        return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
    if isinstance(config, GitHubAuthConfig):
        # Imported lazily, as in the original; presumably to avoid an import
        # cycle with the GitHub provider module - verify before hoisting.
        from .github_oauth_auth_provider import GitHubAuthProvider

        # The GitHub provider consumes the typed config object directly.
        return GitHubAuthProvider(config)
    raise ValueError(f"Unknown authentication config type: {type(config)}")

View file

@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import secrets
from datetime import UTC, datetime
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse, RedirectResponse
from llama_stack.distribution.datatypes import GitHubAuthConfig
from llama_stack.log import get_logger
from .github_oauth_auth_provider import (
GITHUB_CALLBACK_PATH,
GITHUB_LOGIN_PATH,
GitHubAuthProvider,
)
logger = get_logger(name=__name__, category="auth_routes")
def create_github_auth_router(config: GitHubAuthConfig) -> APIRouter:
    """Create and configure GitHub authentication router.

    The router exposes two GET endpoints:

    * ``GITHUB_LOGIN_PATH``: validates the optional ``redirect_url`` against
      the configured allow-list, records a random ``state`` token and
      redirects the browser to GitHub.
    * ``GITHUB_CALLBACK_PATH``: validates and consumes the ``state``,
      completes the OAuth flow, then either redirects to the stored
      ``redirect_url`` with the token in the URL fragment or returns it as JSON.

    Pending states live in a dict local to this router (in-process memory)
    and are considered valid for five minutes.

    Args:
        config: GitHub OAuth provider configuration.

    Returns:
        The configured ``APIRouter``.
    """
    import urllib.parse

    auth_provider = GitHubAuthProvider(config)
    router = APIRouter()
    # state token -> {"created_at": aware datetime, "redirect_url": str | None}
    oauth_states: dict[str, dict] = {}
    state_ttl_seconds = 300  # 5 minutes

    def is_expired(created_at: datetime, now: datetime) -> bool:
        """Tell whether a state created at ``created_at`` is past its TTL."""
        # total_seconds() measures the whole interval; timedelta.seconds drops
        # whole days, so a state older than 24h could look fresh again.
        return (now - created_at).total_seconds() > state_ttl_seconds

    def cleanup_expired_states() -> None:
        """Remove expired OAuth states."""
        now = datetime.now(UTC)
        # Collect first: the dict must not change size while being iterated.
        expired_states = [state for state, data in oauth_states.items() if is_expired(data["created_at"], now)]
        for state in expired_states:
            oauth_states.pop(state, None)

    @router.get(GITHUB_LOGIN_PATH)
    async def github_login(request: Request, redirect_url: str | None = None):
        """Initiate GitHub OAuth flow."""
        cleanup_expired_states()

        # Validate redirect URL if provided
        if redirect_url:
            if not auth_provider.config.allowed_redirect_urls:
                raise HTTPException(status_code=400, detail="Redirect URLs not configured")
            if redirect_url not in auth_provider.config.allowed_redirect_urls:
                raise HTTPException(status_code=400, detail="Invalid redirect URL")

        state = secrets.token_urlsafe(32)
        oauth_states[state] = {
            "created_at": datetime.now(UTC),
            "redirect_url": redirect_url,
        }

        # Get authorization URL
        auth_url = auth_provider.get_authorization_url(state, request)
        logger.debug(f"Redirecting to GitHub OAuth: {auth_url}")
        return RedirectResponse(url=auth_url)

    # NOTE: the function name doubles as the route name resolved by
    # request.url_for("github_callback"); do not rename it.
    @router.get(GITHUB_CALLBACK_PATH)
    async def github_callback(code: str, state: str, request: Request):
        """Handle GitHub OAuth callback."""
        # Validate state parameter; popping makes every state single-use.
        state_data = oauth_states.pop(state, None)
        if state_data is None:
            logger.warning(f"Invalid OAuth state received: {state}")
            raise HTTPException(status_code=400, detail="Invalid state parameter")

        # Check state expiry
        if is_expired(state_data["created_at"], datetime.now(UTC)):
            logger.warning("OAuth state expired")
            raise HTTPException(status_code=400, detail="State expired")

        # Only the OAuth exchange is guarded, so the HTTPException raised
        # below for a bad redirect URL is not rewritten into a 500 by the
        # catch-all handler.
        try:
            access_token = await auth_provider.complete_oauth_flow(code, request)
        except ValueError as e:
            logger.error(f"GitHub OAuth error: {e}")
            raise HTTPException(status_code=400, detail=str(e)) from e
        except Exception as e:
            logger.exception("Unexpected error during GitHub OAuth")
            raise HTTPException(status_code=500, detail="Authentication failed") from e

        logger.info("GitHub OAuth successful")

        redirect_url = state_data.get("redirect_url")
        if not redirect_url:
            # For API calls, only return the access token
            return JSONResponse(content={"access_token": access_token})

        # Re-validate the redirect URL stored with the state.
        if redirect_url not in auth_provider.config.allowed_redirect_urls:
            raise HTTPException(status_code=400, detail="Invalid redirect URL in state")

        # The token is delivered in the URL fragment (#token=...).
        params = urllib.parse.urlencode({"token": access_token})
        return RedirectResponse(url=f"{redirect_url}#{params}")

    return router

View file

@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from datetime import UTC, datetime, timedelta
from typing import Any
from urllib.parse import urlencode
import httpx
from fastapi import Request
from jose import jwt
from llama_stack.distribution.datatypes import GitHubAuthConfig, User
from llama_stack.log import get_logger
from .auth_providers import AuthProvider, get_attributes_from_claims
logger = get_logger(name=__name__, category="github_auth")
# GitHub API constants
# Base for the REST calls that fetch the user profile and organizations.
GITHUB_API_BASE_URL = "https://api.github.com"
# Base for the OAuth authorize and access-token endpoints.
GITHUB_OAUTH_BASE_URL = "https://github.com"
# GitHub OAuth route paths
# Path of the endpoint that starts the OAuth flow.
GITHUB_LOGIN_PATH = "/auth/github/login"
# Path of the endpoint that handles the OAuth callback.
GITHUB_CALLBACK_PATH = "/auth/github/callback"
class GitHubAuthProvider(AuthProvider):
    """Auth provider that runs the GitHub OAuth flow and validates the JWTs it issues."""

    def __init__(self, config: GitHubAuthConfig):
        """Store the GitHub OAuth configuration used by every other method."""
        self.config = config

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a JWT issued by this provider and build the matching ``User``.

        The principal is the ``sub`` claim; attributes are derived from the
        claims through ``config.claims_mapping``.

        Raises:
            ValueError: On any failure; the failure is logged with traceback first.
        """
        try:
            payload = self.verify_jwt(token)
            return User(
                principal=payload["sub"],
                attributes=get_attributes_from_claims(payload, self.config.claims_mapping),
            )
        except Exception as e:
            logger.exception("Error validating GitHub JWT")
            raise ValueError(f"Invalid GitHub JWT token: {str(e)}") from e

    async def close(self):
        """No-op: the provider holds no long-lived resources."""
        return None

    def setup_routes(self, app):
        """Mount the GitHub OAuth login/callback router on ``app``."""
        # Imported here rather than at module level, as in the original.
        from .auth_routes import create_github_auth_router

        app.include_router(create_github_auth_router(self.config))

    def get_public_paths(self) -> list[str]:
        """Return the path prefix of the GitHub OAuth endpoints, which skip authentication."""
        return ["/auth/github/"]

    def get_auth_error_message(self, scope: dict | None = None) -> str:
        """Tell the caller where to log in.

        When ``scope`` carries a ``host`` header, the message contains an
        absolute login URL; otherwise it only names the login path.
        """
        fallback = f"Authentication required. Please authenticate by visiting {GITHUB_LOGIN_PATH} to start the authentication flow."
        if not scope:
            return fallback

        header_map = dict(scope.get("headers", []))
        host = header_map.get(b"host", b"").decode()
        if not host:
            return fallback

        scheme = scope.get("scheme", "http")
        login_url = f"{scheme}://{host}{GITHUB_LOGIN_PATH}"
        return f"Authentication required. Please authenticate via GitHub at {login_url}"

    # OAuth flow methods

    def get_authorization_url(self, state: str, request: Request) -> str:
        """Build the GitHub authorize URL carrying ``state`` and the callback URL."""
        query = urlencode(
            {
                "client_id": self.config.github_client_id,
                "redirect_uri": _build_callback_url(request),
                "scope": "read:user read:org",
                "state": state,
            }
        )
        return f"{GITHUB_OAUTH_BASE_URL}/login/oauth/authorize?" + query

    async def complete_oauth_flow(self, code: str, request: Request) -> Any:
        """Trade ``code`` for a GitHub token, load the user, and return a signed JWT.

        Raises:
            ValueError: If the token response contains an ``error`` entry.
        """
        logger.debug("Exchanging code for GitHub access token")
        token_data = await self._exchange_code_for_token(code, request)

        if "error" in token_data:
            # Prefer the human-readable description when the response has one.
            reason = token_data.get("error_description", token_data["error"])
            raise ValueError(f"GitHub OAuth error: {reason}")

        github_token = token_data["access_token"]

        logger.debug("Fetching GitHub user info")
        github_info = await self._get_user_info(github_token)

        logger.debug(f"Creating JWT for user: {github_info['user']['login']}")
        return self._create_jwt_token(github_info)

    def verify_jwt(self, token: str) -> Any:
        """Decode ``token``, checking signature, audience and issuer against the config.

        Raises:
            ValueError: If decoding raises ``jwt.JWTError``.
        """
        try:
            return jwt.decode(
                token,
                self.config.jwt_secret,
                algorithms=[self.config.jwt_algorithm],
                audience=self.config.jwt_audience,
                issuer=self.config.jwt_issuer,
            )
        except jwt.JWTError as e:
            raise ValueError(f"Invalid JWT token: {e}") from e

    # Private helper methods

    async def _exchange_code_for_token(self, code: str, request: Request) -> Any:
        """POST the authorization code to GitHub and return the decoded JSON reply."""
        token_endpoint = f"{GITHUB_OAUTH_BASE_URL}/login/oauth/access_token"
        async with httpx.AsyncClient() as client:
            payload = {
                "client_id": self.config.github_client_id,
                "client_secret": self.config.github_client_secret,
                "code": code,
                "redirect_uri": _build_callback_url(request),
            }
            response = await client.post(
                token_endpoint,
                json=payload,
                headers={"Accept": "application/json"},
                timeout=10.0,
            )
            response.raise_for_status()
            return response.json()

    async def _get_user_info(self, access_token: str) -> dict:
        """Fetch the user profile and organization logins for ``access_token``.

        Returns:
            ``{"user": <profile JSON>, "organizations": [<org login>, ...]}``.
        """
        request_headers = {
            "Authorization": f"Bearer {access_token}",
            "Accept": "application/vnd.github.v3+json",
        }

        async with httpx.AsyncClient() as client:
            # Both requests are issued concurrently.
            user_response, orgs_response = await asyncio.gather(
                client.get(f"{GITHUB_API_BASE_URL}/user", headers=request_headers, timeout=10.0),
                client.get(f"{GITHUB_API_BASE_URL}/user/orgs", headers=request_headers, timeout=10.0),
            )

            user_response.raise_for_status()
            orgs_response.raise_for_status()

            profile = user_response.json()
            org_entries = orgs_response.json()

            return {
                "user": profile,
                "organizations": [entry["login"] for entry in org_entries],
            }

    def _create_jwt_token(self, github_info: dict) -> Any:
        """Sign a JWT describing the GitHub user in ``github_info``.

        ``teams`` is optional in ``github_info`` and defaults to an empty list.
        """
        profile = github_info["user"]
        org_logins = github_info["organizations"]
        team_names = github_info.get("teams", [])

        issued_at = datetime.now(UTC)
        claims = {
            # Registered claims
            "sub": profile["login"],
            "aud": self.config.jwt_audience,
            "iss": self.config.jwt_issuer,
            "exp": issued_at + timedelta(seconds=self.config.token_expiry),
            "iat": issued_at,
            "nbf": issued_at,
            # GitHub-specific claims available to claims_mapping
            "github_username": profile["login"],
            "github_user_id": str(profile["id"]),
            "github_orgs": org_logins,
            "github_teams": team_names,
            "email": profile.get("email"),
            "name": profile.get("name"),
            "avatar_url": profile.get("avatar_url"),
        }

        return jwt.encode(claims, self.config.jwt_secret, algorithm=self.config.jwt_algorithm)
def _build_callback_url(request: Request) -> str:
    """Resolve the absolute URL of the ``github_callback`` route for this request."""
    return str(request.url_for("github_callback"))

View file

@ -31,7 +31,11 @@ from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.datatypes import (
AuthenticationRequiredError,
LoggingConfig,
StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
from llama_stack.distribution.resolver import InvalidProviderError
@ -61,6 +65,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from .auth import AuthenticationMiddleware
from .auth_providers import create_auth_provider
from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -448,9 +453,12 @@ def main(args: argparse.Namespace | None = None):
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
# Add authentication middleware if configured
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
logger.info(f"Enabling authentication with provider: {config.server.auth.type.value}")
auth_provider = create_auth_provider(config.server.auth)
auth_provider.setup_routes(app)
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
else:
if config.server.quota: