This commit is contained in:
ehhuang 2025-06-27 11:39:51 +02:00 committed by GitHub
commit 12265d4222
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 962 additions and 43 deletions

View file

@ -73,7 +73,7 @@ jobs:
server:
port: 8321
EOF
yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
yq eval '.server.auth = {"type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml
cat $run_dir/run.yaml

View file

@ -6,7 +6,7 @@
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Any
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field, field_validator
@ -166,20 +166,90 @@ class AuthProviderType(StrEnum):
OAUTH2_TOKEN = "oauth2_token"
CUSTOM = "custom"
GITHUB_OAUTH = "github_oauth"
class AuthenticationConfig(BaseModel):
provider_type: AuthProviderType = Field(
...,
description="Type of authentication provider",
)
class OAuth2TokenAuthConfig(BaseModel):
"""Configuration for OAuth2 token authentication."""
type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
config: dict[str, Any] = Field(
...,
description="Provider-specific configuration",
description="OAuth2 token validation configuration",
)
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
class CustomAuthConfig(BaseModel):
"""Configuration for custom authentication."""
type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
config: dict[str, Any] = Field(
...,
description="Custom authentication endpoint configuration",
)
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
class GitHubAuthConfig(BaseModel):
"""Configuration for GitHub OAuth authentication.
This provider implements GitHub OAuth authentication with the following endpoints:
- /auth/github/login: Initiates OAuth flow
- Accepts optional redirect_url parameter (must be in allowed_redirect_urls)
- Redirects to GitHub for authorization
- /auth/github/callback: Handles OAuth callback (should be used as the callback URL for your GitHub OAuth App)
- Exchanges authorization code for access token
- Creates JWT with user info
- With redirect_url: Redirects to specified URL with JWT in fragment (#token=...)
- Without redirect_url: Returns JSON response with JWT
Ensure your GitHub OAuth App's callback URL matches your server's URL + /auth/github/callback.
"""
type: Literal[AuthProviderType.GITHUB_OAUTH] = AuthProviderType.GITHUB_OAUTH
# GitHub OAuth settings
github_client_id: str = Field(description="GitHub OAuth App Client ID")
github_client_secret: str = Field(description="GitHub OAuth App Client Secret")
allowed_redirect_urls: list[str] = Field(
default=[],
description="Allowed redirect URLs for OAuth callback, e.g. frontend URL",
)
# JWT configuration for token generation
jwt_secret: str = Field(description="Secret for signing JWT tokens")
jwt_algorithm: str = Field(default="HS256", description="JWT signing algorithm")
jwt_audience: str = Field(default="llama-stack", description="JWT audience")
jwt_issuer: str = Field(default="llama-stack-github", description="JWT issuer")
token_expiry: int = Field(default=86400, description="JWT token expiry in seconds")
# OAuth2 token validation config (for validating the JWTs we issue)
oauth2_config: dict[str, Any] = Field(
default_factory=dict,
description="Additional OAuth2 token validation configuration",
)
# Access control
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
# Claims mapping for GitHub attributes
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles", # GitHub username as role
},
description="Mapping from JWT claims to Llama Stack attributes",
)
# Discriminated union for authentication configurations
AuthenticationConfig = Annotated[
OAuth2TokenAuthConfig | CustomAuthConfig | GitHubAuthConfig, Field(discriminator="type")
]
class AuthenticationRequiredError(Exception):
pass

View file

@ -81,13 +81,23 @@ class AuthenticationMiddleware:
def __init__(self, app, auth_config: AuthenticationConfig):
self.app = app
self.auth_provider = create_auth_provider(auth_config)
self.public_paths = self.auth_provider.get_public_paths()
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
# Skip authentication for public paths defined by the provider
path = scope.get("path", "")
if any(path.startswith(public_path) for public_path in self.public_paths):
return await self.app(scope, receive, send)
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
if not auth_header or not auth_header.startswith("Bearer "):
if not auth_header:
error_msg = self.auth_provider.get_auth_error_message(scope)
return await self._send_auth_error(send, error_msg)
if not auth_header.startswith("Bearer "):
return await self._send_auth_error(send, "Missing or invalid Authorization header")
token = auth_header.split("Bearer ", 1)[1]

View file

@ -10,13 +10,19 @@ from abc import ABC, abstractmethod
from asyncio import Lock
from pathlib import Path
from typing import Self
from urllib.parse import parse_qs
from urllib.parse import parse_qs, urlparse
import httpx
from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
from llama_stack.distribution.datatypes import (
AuthenticationConfig,
CustomAuthConfig,
GitHubAuthConfig,
OAuth2TokenAuthConfig,
User,
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -62,6 +68,24 @@ class AuthProvider(ABC):
"""Clean up any resources."""
pass
def setup_routes(self, app):
"""Setup any provider-specific routes (e.g., OAuth callbacks).
This is optional - providers that don't need special routes can skip this.
"""
return
def get_public_paths(self) -> list[str]:
"""Return a list of path prefixes that should bypass authentication.
This is optional - providers that don't have public paths return empty list.
"""
return []
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return provider-specific authentication error message."""
return "Authentication required"
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
attributes: dict[str, list[str]] = {}
@ -232,6 +256,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
async def close(self):
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return OAuth2-specific authentication error message."""
if self.config.issuer:
return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
elif self.config.introspection:
# Extract domain from introspection URL for a cleaner message
domain = urlparse(self.config.introspection.url).netloc
return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
else:
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
async def _refresh_jwks(self) -> None:
"""
Refresh the JWKS cache.
@ -338,15 +373,25 @@ class CustomAuthProvider(AuthProvider):
await self._client.aclose()
self._client = None
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return custom auth provider-specific authentication error message."""
# Extract domain from endpoint URL for a cleaner message
domain = urlparse(self.config.endpoint).netloc
if domain:
return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
else:
return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
"""Factory function to create the appropriate auth provider."""
provider_type = config.provider_type.lower()
if provider_type == "custom":
if isinstance(config, CustomAuthConfig):
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
elif provider_type == "oauth2_token":
elif isinstance(config, OAuth2TokenAuthConfig):
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
elif isinstance(config, GitHubAuthConfig):
from .github_oauth_auth_provider import GitHubAuthProvider
return GitHubAuthProvider(config)
else:
supported_providers = ", ".join([t.value for t in AuthProviderType])
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
raise ValueError(f"Unknown authentication config type: {type(config)}")

View file

@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import secrets
from datetime import UTC, datetime
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse, RedirectResponse
from llama_stack.distribution.datatypes import GitHubAuthConfig
from llama_stack.log import get_logger
from .github_oauth_auth_provider import (
GITHUB_CALLBACK_PATH,
GITHUB_LOGIN_PATH,
GitHubAuthProvider,
)
logger = get_logger(name=__name__, category="auth_routes")
def create_github_auth_router(config: GitHubAuthConfig) -> APIRouter:
"""Create and configure GitHub authentication router."""
auth_provider = GitHubAuthProvider(config)
router = APIRouter()
oauth_states: dict[str, dict] = {}
def cleanup_expired_states() -> None:
"""Remove expired OAuth states."""
now = datetime.now(UTC)
expired_states = [
state
for state, data in oauth_states.items()
if (now - data["created_at"]).seconds > 300 # 5 minutes
]
for state in expired_states:
oauth_states.pop(state, None)
@router.get(GITHUB_LOGIN_PATH)
async def github_login(request: Request, redirect_url: str | None = None):
"""Initiate GitHub OAuth flow."""
cleanup_expired_states()
# Validate redirect URL if provided
if redirect_url:
if not auth_provider.config.allowed_redirect_urls:
raise HTTPException(status_code=400, detail="Redirect URLs not configured")
if redirect_url not in auth_provider.config.allowed_redirect_urls:
raise HTTPException(status_code=400, detail="Invalid redirect URL")
state = secrets.token_urlsafe(32)
oauth_states[state] = {
"created_at": datetime.now(UTC),
"redirect_url": redirect_url,
}
# Get authorization URL
auth_url = auth_provider.get_authorization_url(state, request)
logger.debug(f"Redirecting to GitHub OAuth: {auth_url}")
return RedirectResponse(url=auth_url)
@router.get(GITHUB_CALLBACK_PATH)
async def github_callback(code: str, state: str, request: Request):
"""Handle GitHub OAuth callback."""
# Validate state parameter
if state not in oauth_states:
logger.warning(f"Invalid OAuth state received: {state}")
raise HTTPException(status_code=400, detail="Invalid state parameter")
state_data = oauth_states.pop(state)
# Check state expiry
if (datetime.now(UTC) - state_data["created_at"]).seconds > 300:
logger.warning("OAuth state expired")
raise HTTPException(status_code=400, detail="State expired")
try:
# Complete OAuth flow
access_token = await auth_provider.complete_oauth_flow(code, request)
logger.info("GitHub OAuth successful")
redirect_url = state_data.get("redirect_url")
if redirect_url:
# Validate redirect URL from state
if redirect_url not in auth_provider.config.allowed_redirect_urls:
raise HTTPException(status_code=400, detail="Invalid redirect URL in state")
import urllib.parse
params = urllib.parse.urlencode({"token": access_token})
return RedirectResponse(url=f"{redirect_url}#{params}")
else:
# For API calls, only return the access token
return JSONResponse(content={"access_token": access_token})
except ValueError as e:
logger.error(f"GitHub OAuth error: {e}")
raise HTTPException(status_code=400, detail=str(e)) from e
except Exception as e:
logger.exception("Unexpected error during GitHub OAuth")
raise HTTPException(status_code=500, detail="Authentication failed") from e
return router

View file

@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from datetime import UTC, datetime, timedelta
from typing import Any
from urllib.parse import urlencode
import httpx
from fastapi import Request
from jose import jwt
from llama_stack.distribution.datatypes import GitHubAuthConfig, User
from llama_stack.log import get_logger
from .auth_providers import AuthProvider, get_attributes_from_claims
logger = get_logger(name=__name__, category="github_auth")
# GitHub API constants
GITHUB_API_BASE_URL = "https://api.github.com"
GITHUB_OAUTH_BASE_URL = "https://github.com"
# GitHub OAuth route paths
GITHUB_LOGIN_PATH = "/auth/github/login"
GITHUB_CALLBACK_PATH = "/auth/github/callback"
class GitHubAuthProvider(AuthProvider):
"""Authentication provider for GitHub OAuth flow and JWT validation."""
def __init__(self, config: GitHubAuthConfig):
self.config = config
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a GitHub-issued JWT token."""
try:
claims = self.verify_jwt(token)
principal = claims["sub"]
attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
return User(principal=principal, attributes=attributes)
except Exception as e:
logger.exception("Error validating GitHub JWT")
raise ValueError(f"Invalid GitHub JWT token: {str(e)}") from e
async def close(self):
"""Clean up any resources."""
pass
def setup_routes(self, app):
"""Setup GitHub OAuth routes."""
from .auth_routes import create_github_auth_router
github_router = create_github_auth_router(self.config)
app.include_router(github_router)
def get_public_paths(self) -> list[str]:
"""GitHub OAuth paths that don't require authentication."""
return ["/auth/github/"]
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return GitHub-specific authentication error message."""
if scope:
headers = dict(scope.get("headers", []))
host = headers.get(b"host", b"").decode()
scheme = scope.get("scheme", "http")
if host:
auth_url = f"{scheme}://{host}{GITHUB_LOGIN_PATH}"
return f"Authentication required. Please authenticate via GitHub at {auth_url}"
return f"Authentication required. Please authenticate by visiting {GITHUB_LOGIN_PATH} to start the authentication flow."
# OAuth flow methods
def get_authorization_url(self, state: str, request: Request) -> str:
"""Generate GitHub OAuth authorization URL."""
params = {
"client_id": self.config.github_client_id,
"redirect_uri": _build_callback_url(request),
"scope": "read:user read:org",
"state": state,
}
return f"{GITHUB_OAUTH_BASE_URL}/login/oauth/authorize?" + urlencode(params)
async def complete_oauth_flow(self, code: str, request: Request) -> Any:
"""Complete the GitHub OAuth flow and return JWT access token."""
# Exchange code for token
logger.debug("Exchanging code for GitHub access token")
token_data = await self._exchange_code_for_token(code, request)
if "error" in token_data:
raise ValueError(f"GitHub OAuth error: {token_data.get('error_description', token_data['error'])}")
access_token = token_data["access_token"]
# Get user info
logger.debug("Fetching GitHub user info")
github_info = await self._get_user_info(access_token)
# Create JWT
logger.debug(f"Creating JWT for user: {github_info['user']['login']}")
jwt_token = self._create_jwt_token(github_info)
return jwt_token
def verify_jwt(self, token: str) -> Any:
"""Verify and decode a GitHub-issued JWT token."""
try:
payload = jwt.decode(
token,
self.config.jwt_secret,
algorithms=[self.config.jwt_algorithm],
audience=self.config.jwt_audience,
issuer=self.config.jwt_issuer,
)
return payload
except jwt.JWTError as e:
raise ValueError(f"Invalid JWT token: {e}") from e
# Private helper methods
async def _exchange_code_for_token(self, code: str, request: Request) -> Any:
"""Exchange authorization code for GitHub access token."""
async with httpx.AsyncClient() as client:
response = await client.post(
f"{GITHUB_OAUTH_BASE_URL}/login/oauth/access_token",
json={
"client_id": self.config.github_client_id,
"client_secret": self.config.github_client_secret,
"code": code,
"redirect_uri": _build_callback_url(request),
},
headers={"Accept": "application/json"},
timeout=10.0,
)
response.raise_for_status()
return response.json()
async def _get_user_info(self, access_token: str) -> dict:
"""Fetch user info and organizations from GitHub."""
headers = {
"Authorization": f"Bearer {access_token}",
"Accept": "application/vnd.github.v3+json",
}
async with httpx.AsyncClient() as client:
# Fetch user and orgs in parallel
user_task = client.get(f"{GITHUB_API_BASE_URL}/user", headers=headers, timeout=10.0)
orgs_task = client.get(f"{GITHUB_API_BASE_URL}/user/orgs", headers=headers, timeout=10.0)
user_response, orgs_response = await asyncio.gather(user_task, orgs_task)
user_response.raise_for_status()
orgs_response.raise_for_status()
user_data = user_response.json()
orgs_data = orgs_response.json()
# Extract organization names
organizations = [org["login"] for org in orgs_data]
return {
"user": user_data,
"organizations": organizations,
}
def _create_jwt_token(self, github_info: dict) -> Any:
"""Create JWT token with GitHub user information."""
user = github_info["user"]
orgs = github_info["organizations"]
teams = github_info.get("teams", [])
# Create JWT claims that map to Llama Stack attributes
now = datetime.now(UTC)
claims = {
"sub": user["login"],
"aud": self.config.jwt_audience,
"iss": self.config.jwt_issuer,
"exp": now + timedelta(seconds=self.config.token_expiry),
"iat": now,
"nbf": now,
# Custom claims that will be mapped by claims_mapping
"github_username": user["login"],
"github_user_id": str(user["id"]),
"github_orgs": orgs,
"github_teams": teams,
"email": user.get("email"),
"name": user.get("name"),
"avatar_url": user.get("avatar_url"),
}
return jwt.encode(claims, self.config.jwt_secret, algorithm=self.config.jwt_algorithm)
def _build_callback_url(request: Request) -> str:
"""Build the GitHub OAuth callback URL from the current request."""
callback_path = str(request.url_for("github_callback"))
return callback_path

View file

@ -31,7 +31,11 @@ from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.datatypes import (
AuthenticationRequiredError,
LoggingConfig,
StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
from llama_stack.distribution.resolver import InvalidProviderError
@ -61,6 +65,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from .auth import AuthenticationMiddleware
from .auth_providers import create_auth_provider
from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -448,9 +453,12 @@ def main(args: argparse.Namespace | None = None):
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
# Add authentication middleware if configured
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
logger.info(f"Enabling authentication with provider: {config.server.auth.type.value}")
auth_provider = create_auth_provider(config.server.auth)
auth_provider.setup_routes(app)
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
else:
if config.server.quota:

View file

@ -11,10 +11,13 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from llama_stack.distribution.datatypes import AuthenticationConfig
from llama_stack.distribution.datatypes import (
AuthProviderType,
CustomAuthConfig,
OAuth2TokenAuthConfig,
)
from llama_stack.distribution.server.auth import AuthenticationMiddleware
from llama_stack.distribution.server.auth_providers import (
AuthProviderType,
get_attributes_from_claims,
)
@ -60,8 +63,8 @@ def invalid_token():
@pytest.fixture
def http_app(mock_auth_endpoint):
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.CUSTOM,
auth_config = CustomAuthConfig(
type=AuthProviderType.CUSTOM,
config={"endpoint": mock_auth_endpoint},
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@ -76,9 +79,9 @@ def http_app(mock_auth_endpoint):
@pytest.fixture
def k8s_app():
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.KUBERNETES,
config={"api_server_url": "https://kubernetes.default.svc"},
auth_config = CustomAuthConfig(
type=AuthProviderType.CUSTOM,
config={"provider_type": "kubernetes", "api_server_url": "https://kubernetes.default.svc"},
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@ -116,8 +119,8 @@ def mock_scope():
@pytest.fixture
def mock_http_middleware(mock_auth_endpoint):
mock_app = AsyncMock()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.CUSTOM,
auth_config = CustomAuthConfig(
type=AuthProviderType.CUSTOM,
config={"endpoint": mock_auth_endpoint},
)
return AuthenticationMiddleware(mock_app, auth_config), mock_app
@ -126,9 +129,9 @@ def mock_http_middleware(mock_auth_endpoint):
@pytest.fixture
def mock_k8s_middleware():
mock_app = AsyncMock()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.KUBERNETES,
config={"api_server_url": "https://kubernetes.default.svc"},
auth_config = CustomAuthConfig(
type=AuthProviderType.CUSTOM,
config={"provider_type": "kubernetes", "api_server_url": "https://kubernetes.default.svc"},
)
return AuthenticationMiddleware(mock_app, auth_config), mock_app
@ -161,7 +164,8 @@ async def mock_post_exception(*args, **kwargs):
def test_missing_auth_header(http_client):
response = http_client.get("/test")
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
assert "Authentication required" in response.json()["error"]["message"]
assert "validated by mock-auth-service" in response.json()["error"]["message"]
def test_invalid_auth_header_format(http_client):
@ -261,8 +265,8 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
@pytest.fixture
def oauth2_app():
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
auth_config = OAuth2TokenAuthConfig(
type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": {
"uri": "http://mock-authz-service/token/introspect",
@ -288,7 +292,8 @@ def oauth2_client(oauth2_app):
def test_missing_auth_header_oauth2(oauth2_client):
response = oauth2_client.get("/test")
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
assert "Authentication required" in response.json()["error"]["message"]
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
def test_invalid_auth_header_format_oauth2(oauth2_client):
@ -357,8 +362,8 @@ async def mock_auth_jwks_response(*args, **kwargs):
@pytest.fixture
def oauth2_app_with_jwks_token():
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
auth_config = OAuth2TokenAuthConfig(
type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": {
"uri": "http://mock-authz-service/token/introspect",
@ -448,8 +453,8 @@ def mock_introspection_endpoint():
@pytest.fixture
def introspection_app(mock_introspection_endpoint):
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
auth_config = OAuth2TokenAuthConfig(
type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": None,
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
@ -467,8 +472,8 @@ def introspection_app(mock_introspection_endpoint):
@pytest.fixture
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
auth_config = OAuth2TokenAuthConfig(
type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": None,
"introspection": {
@ -507,7 +512,8 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi
def test_missing_auth_header_introspection(introspection_client):
response = introspection_client.get("/test")
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
assert "Authentication required" in response.json()["error"]["message"]
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
def test_invalid_auth_header_format_introspection(introspection_client):

View file

@ -0,0 +1,468 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from jose import jwt
from llama_stack.distribution.datatypes import GitHubAuthConfig
from llama_stack.distribution.server.auth import AuthenticationMiddleware
from llama_stack.distribution.server.auth_providers import get_attributes_from_claims
from llama_stack.distribution.server.auth_routes import create_github_auth_router
from llama_stack.distribution.server.github_oauth_auth_provider import GitHubAuthProvider
class MockResponse:
def __init__(self, status_code, json_data):
self.status_code = status_code
self._json_data = json_data
def json(self):
return self._json_data
def raise_for_status(self):
if self.status_code != 200:
raise Exception(f"HTTP error: {self.status_code}")
@pytest.fixture
def github_oauth_app():
app = FastAPI()
auth_config = {
"type": "github_oauth",
"github_client_id": "test_client_id",
"github_client_secret": "test_client_secret",
"jwt_secret": "test_jwt_secret",
"jwt_algorithm": "HS256",
"jwt_audience": "llama-stack",
"jwt_issuer": "llama-stack-github",
"token_expiry": 86400,
}
github_config = GitHubAuthConfig(**auth_config)
# Add auth routes BEFORE middleware so they're not protected
auth_router = create_github_auth_router(github_config)
app.include_router(auth_router)
# Then add auth middleware for other routes
app.add_middleware(AuthenticationMiddleware, auth_config=github_config)
@app.get("/test")
def test_endpoint():
return {"message": "Authentication successful"}
return app
@pytest.fixture
def github_oauth_client(github_oauth_app):
return TestClient(github_oauth_app)
def test_github_login_redirect(github_oauth_client):
"""Test that GitHub login endpoint returns redirect to GitHub"""
response = github_oauth_client.get("/auth/github/login", follow_redirects=False)
assert response.status_code == 307 # Temporary redirect
assert "github.com/login/oauth/authorize" in response.headers["location"]
assert "client_id=test_client_id" in response.headers["location"]
assert "state=" in response.headers["location"]
async def mock_github_token_exchange_success(*args, **kwargs):
"""Mock successful GitHub token exchange"""
return MockResponse(
200,
{
"access_token": "github_access_token_123",
"token_type": "bearer",
"scope": "read:user,read:org",
},
)
class MockAsyncClient:
"""Mock httpx.AsyncClient for testing"""
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
async def post(self, url, **kwargs):
if "login/oauth/access_token" in url:
return await mock_github_token_exchange_success()
return MockResponse(404, {})
async def get(self, url, **kwargs):
if url.endswith("/user"):
return MockResponse(
200,
{
"login": "test-user",
"id": 12345,
"email": "test@example.com",
"name": "Test User",
"avatar_url": "https://avatars.githubusercontent.com/u/12345",
},
)
elif url.endswith("/user/orgs"):
return MockResponse(
200,
[
{"login": "test-org-1"},
{"login": "test-org-2"},
],
)
return MockResponse(404, {})
@patch("httpx.AsyncClient", MockAsyncClient)
def test_github_callback_success(github_oauth_app):
"""Test successful GitHub OAuth callback"""
# Get a fresh client for this test
client = TestClient(github_oauth_app)
# First, make a login request to generate a state
login_response = client.get("/auth/github/login", follow_redirects=False)
assert login_response.status_code == 307
# Extract the state from the redirect URL
location = login_response.headers["location"]
import re
state_match = re.search(r"state=([^&]+)", location)
assert state_match
state = state_match.group(1)
# Now use that state in the callback
response = client.get(f"/auth/github/callback?code=test_code&state={state}")
assert response.status_code == 200
data = response.json()
assert "access_token" in data
# Verify the JWT token contains the expected claims
token = data["access_token"]
# Decode without verification since this is a test
claims = jwt.decode(token, "test_jwt_secret", algorithms=["HS256"], audience="llama-stack")
assert claims["github_username"] == "test-user"
assert claims["email"] == "test@example.com"
def test_github_callback_invalid_state(github_oauth_client):
"""Test GitHub callback with invalid state"""
response = github_oauth_client.get("/auth/github/callback?code=test_code&state=invalid_state")
assert response.status_code == 400
assert "Invalid state parameter" in response.json()["detail"]
async def mock_github_token_exchange_error(*args, **kwargs):
"""Mock GitHub token exchange error"""
return MockResponse(
200, {"error": "bad_verification_code", "error_description": "The code passed is incorrect or expired."}
)
class MockAsyncClientError:
"""Mock httpx.AsyncClient that returns error for token exchange"""
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
async def post(self, url, **kwargs):
if "login/oauth/access_token" in url:
return await mock_github_token_exchange_error()
return MockResponse(404, {})
async def get(self, url, **kwargs):
return MockResponse(404, {})
@patch("httpx.AsyncClient", MockAsyncClientError)
def test_github_callback_token_exchange_error(github_oauth_app):
"""Test GitHub callback with token exchange error"""
client = TestClient(github_oauth_app)
# First, make a login request to generate a state
login_response = client.get("/auth/github/login", follow_redirects=False)
assert login_response.status_code == 307
# Extract the state from the redirect URL
location = login_response.headers["location"]
import re
state_match = re.search(r"state=([^&]+)", location)
assert state_match
state = state_match.group(1)
# Now use that state in the callback with bad code
response = client.get(f"/auth/github/callback?code=bad_code&state={state}")
assert response.status_code == 400
assert "The code passed is incorrect or expired" in response.json()["detail"]
@pytest.fixture
def github_jwt_token():
"""Create a valid GitHub JWT token for testing"""
claims = {
"sub": "test-user",
"aud": "llama-stack",
"iss": "llama-stack-github",
"exp": datetime.now(UTC) + timedelta(hours=1),
"iat": datetime.now(UTC),
"nbf": datetime.now(UTC),
"github_username": "test-user",
"github_user_id": "12345",
"github_orgs": ["test-org-1", "test-org-2"],
"email": "test@example.com",
"name": "Test User",
}
return jwt.encode(claims, "test_jwt_secret", algorithm="HS256")
def test_github_jwt_authentication_success(github_oauth_client, github_jwt_token):
"""Test API authentication with GitHub-issued JWT"""
response = github_oauth_client.get("/test", headers={"Authorization": f"Bearer {github_jwt_token}"})
assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"}
def test_github_jwt_authentication_invalid_token(github_oauth_client):
"""Test API authentication with invalid JWT"""
response = github_oauth_client.get("/test", headers={"Authorization": "Bearer invalid.jwt.token"})
assert response.status_code == 401
assert "Invalid GitHub JWT token" in response.json()["error"]["message"]
def test_github_jwt_authentication_expired_token(github_oauth_client):
"""Test API authentication with expired JWT"""
# Create an expired token
claims = {
"sub": "test-user",
"aud": "llama-stack",
"iss": "llama-stack-github",
"exp": datetime.now(UTC) - timedelta(hours=1), # Expired
"iat": datetime.now(UTC) - timedelta(hours=2),
"nbf": datetime.now(UTC) - timedelta(hours=2),
"github_username": "test-user",
}
expired_token = jwt.encode(claims, "test_jwt_secret", algorithm="HS256")
response = github_oauth_client.get("/test", headers={"Authorization": f"Bearer {expired_token}"})
assert response.status_code == 401
assert "Invalid GitHub JWT token" in response.json()["error"]["message"]
def test_github_jwt_authentication_wrong_audience(github_oauth_client):
"""Test API authentication with JWT having wrong audience"""
claims = {
"sub": "test-user",
"aud": "wrong-audience", # Wrong audience
"iss": "llama-stack-github",
"exp": datetime.now(UTC) + timedelta(hours=1),
"iat": datetime.now(UTC),
"nbf": datetime.now(UTC),
"github_username": "test-user",
}
wrong_audience_token = jwt.encode(claims, "test_jwt_secret", algorithm="HS256")
response = github_oauth_client.get("/test", headers={"Authorization": f"Bearer {wrong_audience_token}"})
assert response.status_code == 401
assert "Invalid GitHub JWT token" in response.json()["error"]["message"]
def test_github_claims_mapping():
"""Test GitHub claims are properly mapped to attributes"""
config = GitHubAuthConfig(
type="github_oauth",
github_client_id="test",
github_client_secret="test",
jwt_secret="test",
)
claims = {
"sub": "test-user",
"github_username": "test-user",
"github_orgs": ["org1", "org2"],
"github_teams": ["team1", "team2"],
"github_user_id": "12345",
}
# Default mapping only maps "sub" to "roles"
attributes = get_attributes_from_claims(claims, config.claims_mapping)
assert "test-user" in attributes["roles"]
# No other mappings by default
assert len(attributes) == 1
@pytest.mark.asyncio
async def test_github_auth_provider_validate_token():
"""Test GitHubAuthProvider token validation"""
config = GitHubAuthConfig(
type="github_oauth",
github_client_id="test",
github_client_secret="test",
jwt_secret="test_secret",
jwt_algorithm="HS256",
jwt_audience="test-audience",
jwt_issuer="test-issuer",
)
provider = GitHubAuthProvider(config)
# Create a valid token
claims = {
"sub": "test-user",
"aud": "test-audience",
"iss": "test-issuer",
"exp": datetime.now(UTC) + timedelta(hours=1),
"iat": datetime.now(UTC),
"nbf": datetime.now(UTC),
"github_username": "test-user",
"github_orgs": ["org1"],
"github_user_id": "12345",
}
token = jwt.encode(claims, "test_secret", algorithm="HS256")
user = await provider.validate_token(token)
assert user.principal == "test-user"
# Default mapping only maps "sub" to "roles"
assert "test-user" in user.attributes["roles"]
# No other mappings by default
assert len(user.attributes) == 1
@pytest.mark.asyncio
async def test_github_auth_provider_custom_claims_mapping():
"""Test GitHubAuthProvider with custom claims mapping"""
config = GitHubAuthConfig(
type="github_oauth",
github_client_id="test",
github_client_secret="test",
jwt_secret="test_secret",
jwt_algorithm="HS256",
jwt_audience="test-audience",
jwt_issuer="test-issuer",
claims_mapping={
"sub": "roles",
"github_orgs": "teams",
"github_teams": "teams",
"github_user_id": "namespaces",
},
)
provider = GitHubAuthProvider(config)
# Create a valid token
claims = {
"sub": "test-user",
"aud": "test-audience",
"iss": "test-issuer",
"exp": datetime.now(UTC) + timedelta(hours=1),
"iat": datetime.now(UTC),
"nbf": datetime.now(UTC),
"github_username": "test-user",
"github_orgs": ["org1", "org2"],
"github_teams": ["team1", "team2"],
"github_user_id": "12345",
}
token = jwt.encode(claims, "test_secret", algorithm="HS256")
user = await provider.validate_token(token)
assert user.principal == "test-user"
assert "test-user" in user.attributes["roles"]
assert set(user.attributes["teams"]) == {"org1", "org2", "team1", "team2"}
assert "12345" in user.attributes["namespaces"]
@pytest.mark.asyncio
async def test_github_auth_provider_invalid_token():
"""Test GitHubAuthProvider with invalid token"""
config = GitHubAuthConfig(
type="github_oauth",
github_client_id="test",
github_client_secret="test",
jwt_secret="test_secret",
)
provider = GitHubAuthProvider(config)
with pytest.raises(ValueError, match="Invalid GitHub JWT token"):
await provider.validate_token("invalid.token.here")
def test_github_auth_provider_authorization_url():
"""Test GitHubAuthProvider generates correct authorization URL"""
from unittest.mock import Mock
config = GitHubAuthConfig(
type="github_oauth",
github_client_id="test_client_id",
github_client_secret="test_secret",
jwt_secret="test",
)
provider = GitHubAuthProvider(config)
# Mock request object
mock_request = Mock()
mock_request.url_for.return_value = "http://localhost:8321/auth/github/callback"
url = provider.get_authorization_url("test_state_123", mock_request)
assert "https://github.com/login/oauth/authorize" in url
assert "client_id=test_client_id" in url
assert "redirect_uri=http%3A%2F%2Flocalhost%3A8321%2Fauth%2Fgithub%2Fcallback" in url
assert "state=test_state_123" in url
assert "scope=read%3Auser+read%3Aorg" in url
@pytest.mark.asyncio
async def test_github_auth_provider_complete_flow():
"""Test complete OAuth flow in GitHubAuthProvider"""
from unittest.mock import Mock
config = GitHubAuthConfig(
type="github_oauth",
github_client_id="test_client_id",
github_client_secret="test_secret",
jwt_secret="test_jwt_secret",
jwt_algorithm="HS256",
jwt_audience="llama-stack",
jwt_issuer="llama-stack-github",
)
provider = GitHubAuthProvider(config)
# Mock request object
mock_request = Mock()
mock_request.url_for.return_value = "/auth/github/callback"
mock_request.url.scheme = "http"
mock_request.url.netloc = "localhost:8321"
with patch("httpx.AsyncClient", MockAsyncClient):
# Now returns just the JWT token
token = await provider.complete_oauth_flow("test_code", mock_request)
# Verify it's a JWT token by decoding it
claims = jwt.decode(token, "test_jwt_secret", algorithms=["HS256"], audience="llama-stack")
assert claims["github_username"] == "test-user"
assert claims["github_orgs"] == ["test-org-1", "test-org-2"]
assert claims["sub"] == "test-user"