From 94282d3482961a99a5b5c1c13b489e6e535d5675 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Tue, 24 Jun 2025 16:28:17 -0700 Subject: [PATCH] gh auth, commit # What does this PR do? ## Test Plan # What does this PR do? ## Test Plan --- .github/workflows/integration-auth-tests.yml | 2 +- llama_stack/distribution/datatypes.py | 84 +++- llama_stack/distribution/server/auth.py | 12 +- .../distribution/server/auth_providers.py | 61 ++- .../distribution/server/auth_routes.py | 109 ++++ .../server/github_oauth_auth_provider.py | 203 ++++++++ llama_stack/distribution/server/server.py | 14 +- tests/unit/server/test_auth.py | 52 +- tests/unit/server/test_auth_github.py | 468 ++++++++++++++++++ 9 files changed, 962 insertions(+), 43 deletions(-) create mode 100644 llama_stack/distribution/server/auth_routes.py create mode 100644 llama_stack/distribution/server/github_oauth_auth_provider.py create mode 100644 tests/unit/server/test_auth_github.py diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index 4139d09ca..adc9fe847 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -73,7 +73,7 @@ jobs: server: port: 8321 EOF - yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml + yq eval '.server.auth = {"type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml cat $run_dir/run.yaml diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index abc3f0065..8724b3360 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -6,7 +6,7 @@ from enum import Enum from pathlib import Path -from typing import Annotated, Any +from typing import Annotated, Any, Literal from pydantic import BaseModel, Field, field_validator @@ -164,20 +164,90 @@ class AuthProviderType(str, Enum): OAUTH2_TOKEN = "oauth2_token" CUSTOM = "custom" + GITHUB_OAUTH = "github_oauth" -class AuthenticationConfig(BaseModel): - provider_type: AuthProviderType = Field( - ..., - description="Type of authentication provider", - ) +class OAuth2TokenAuthConfig(BaseModel): + """Configuration for OAuth2 token authentication.""" + + type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN config: dict[str, Any] = Field( ..., - description="Provider-specific configuration", + description="OAuth2 token validation configuration", ) access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources") +class CustomAuthConfig(BaseModel): + """Configuration for custom authentication.""" + + type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM + config: dict[str, Any] = Field( + ..., + description="Custom authentication endpoint configuration", + ) + access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources") + + +class GitHubAuthConfig(BaseModel): + """Configuration for GitHub OAuth authentication. + + This provider implements GitHub OAuth authentication with the following endpoints: + + - /auth/github/login: Initiates OAuth flow + - Accepts optional redirect_url parameter (must be in allowed_redirect_urls) + - Redirects to GitHub for authorization + + - /auth/github/callback: Handles OAuth callback (should be used as the callback URL for your GitHub OAuth App) + - Exchanges authorization code for access token + - Creates JWT with user info + - With redirect_url: Redirects to specified URL with JWT in fragment (#token=...) + - Without redirect_url: Returns JSON response with JWT + + Ensure your GitHub OAuth App's callback URL matches your server's URL + /auth/github/callback. + """ + + type: Literal[AuthProviderType.GITHUB_OAUTH] = AuthProviderType.GITHUB_OAUTH + + # GitHub OAuth settings + github_client_id: str = Field(description="GitHub OAuth App Client ID") + github_client_secret: str = Field(description="GitHub OAuth App Client Secret") + allowed_redirect_urls: list[str] = Field( + default=[], + description="Allowed redirect URLs for OAuth callback, e.g. frontend URL", + ) + + # JWT configuration for token generation + jwt_secret: str = Field(description="Secret for signing JWT tokens") + jwt_algorithm: str = Field(default="HS256", description="JWT signing algorithm") + jwt_audience: str = Field(default="llama-stack", description="JWT audience") + jwt_issuer: str = Field(default="llama-stack-github", description="JWT issuer") + token_expiry: int = Field(default=86400, description="JWT token expiry in seconds") + + # OAuth2 token validation config (for validating the JWTs we issue) + oauth2_config: dict[str, Any] = Field( + default_factory=dict, + description="Additional OAuth2 token validation configuration", + ) + + # Access control + access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources") + + # Claims mapping for GitHub attributes + claims_mapping: dict[str, str] = Field( + default_factory=lambda: { + "sub": "roles", # GitHub username as role + }, + description="Mapping from JWT claims to Llama Stack attributes", + ) + + +# Discriminated union for authentication configurations +AuthenticationConfig = Annotated[ + OAuth2TokenAuthConfig | CustomAuthConfig | GitHubAuthConfig, Field(discriminator="type") +] + + class AuthenticationRequiredError(Exception): pass diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index 81b1ffd37..3b322cb4c 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -81,13 +81,23 @@ class AuthenticationMiddleware: def __init__(self, app, auth_config: AuthenticationConfig): self.app = app self.auth_provider = create_auth_provider(auth_config) + self.public_paths = self.auth_provider.get_public_paths() async def __call__(self, scope, receive, send): if scope["type"] == "http": + # Skip authentication for public paths defined by the provider + path = scope.get("path", "") + if any(path.startswith(public_path) for public_path in self.public_paths): + return await self.app(scope, receive, send) + headers = dict(scope.get("headers", [])) auth_header = headers.get(b"authorization", b"").decode() - if not auth_header or not auth_header.startswith("Bearer "): + if not auth_header: + error_msg = self.auth_provider.get_auth_error_message(scope) + return await self._send_auth_error(send, error_msg) + + if not auth_header.startswith("Bearer "): return await self._send_auth_error(send, "Missing or invalid Authorization header") token = auth_header.split("Bearer ", 1)[1] diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 173434652..557e64547 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -10,13 +10,19 @@ from abc import ABC, abstractmethod from asyncio import Lock from pathlib import Path from typing import Self -from urllib.parse import parse_qs +from urllib.parse import parse_qs, urlparse import httpx from jose import jwt from pydantic import BaseModel, Field, field_validator, model_validator -from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User +from llama_stack.distribution.datatypes import ( + AuthenticationConfig, + CustomAuthConfig, + GitHubAuthConfig, + OAuth2TokenAuthConfig, + User, +) from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") @@ -62,6 +68,24 @@ class AuthProvider(ABC): """Clean up any resources.""" pass + def setup_routes(self, app): + """Setup any provider-specific routes (e.g., OAuth callbacks). + + This is optional - providers that don't need special routes can skip this. + """ + return + + def get_public_paths(self) -> list[str]: + """Return a list of path prefixes that should bypass authentication. + + This is optional - providers that don't have public paths return empty list. + """ + return [] + + def get_auth_error_message(self, scope: dict | None = None) -> str: + """Return provider-specific authentication error message.""" + return "Authentication required" + def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]: attributes: dict[str, list[str]] = {} @@ -232,6 +256,17 @@ class OAuth2TokenAuthProvider(AuthProvider): async def close(self): pass + def get_auth_error_message(self, scope: dict | None = None) -> str: + """Return OAuth2-specific authentication error message.""" + if self.config.issuer: + return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}" + elif self.config.introspection: + # Extract domain from introspection URL for a cleaner message + domain = urlparse(self.config.introspection.url).netloc + return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}" + else: + return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header" + async def _refresh_jwks(self) -> None: """ Refresh the JWKS cache. @@ -338,15 +373,25 @@ class CustomAuthProvider(AuthProvider): await self._client.aclose() self._client = None + def get_auth_error_message(self, scope: dict | None = None) -> str: + """Return custom auth provider-specific authentication error message.""" + # Extract domain from endpoint URL for a cleaner message + domain = urlparse(self.config.endpoint).netloc + if domain: + return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})" + else: + return "Authentication required. Please provide your API key as a Bearer token in the Authorization header" + def create_auth_provider(config: AuthenticationConfig) -> AuthProvider: """Factory function to create the appropriate auth provider.""" - provider_type = config.provider_type.lower() - - if provider_type == "custom": + if isinstance(config, CustomAuthConfig): return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config)) - elif provider_type == "oauth2_token": + elif isinstance(config, OAuth2TokenAuthConfig): return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config)) + elif isinstance(config, GitHubAuthConfig): + from .github_oauth_auth_provider import GitHubAuthProvider + + return GitHubAuthProvider(config) else: - supported_providers = ", ".join([t.value for t in AuthProviderType]) - raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}") + raise ValueError(f"Unknown authentication config type: {type(config)}") diff --git a/llama_stack/distribution/server/auth_routes.py b/llama_stack/distribution/server/auth_routes.py new file mode 100644 index 000000000..86f988bb7 --- /dev/null +++ b/llama_stack/distribution/server/auth_routes.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import secrets +from datetime import UTC, datetime + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse, RedirectResponse + +from llama_stack.distribution.datatypes import GitHubAuthConfig +from llama_stack.log import get_logger + +from .github_oauth_auth_provider import ( + GITHUB_CALLBACK_PATH, + GITHUB_LOGIN_PATH, + GitHubAuthProvider, +) + +logger = get_logger(name=__name__, category="auth_routes") + + +def create_github_auth_router(config: GitHubAuthConfig) -> APIRouter: + """Create and configure GitHub authentication router.""" + auth_provider = GitHubAuthProvider(config) + router = APIRouter() + oauth_states: dict[str, dict] = {} + + def cleanup_expired_states() -> None: + """Remove expired OAuth states.""" + now = datetime.now(UTC) + expired_states = [ + state + for state, data in oauth_states.items() + if (now - data["created_at"]).seconds > 300 # 5 minutes + ] + for state in expired_states: + oauth_states.pop(state, None) + + @router.get(GITHUB_LOGIN_PATH) + async def github_login(request: Request, redirect_url: str | None = None): + """Initiate GitHub OAuth flow.""" + cleanup_expired_states() + + # Validate redirect URL if provided + if redirect_url: + if not auth_provider.config.allowed_redirect_urls: + raise HTTPException(status_code=400, detail="Redirect URLs not configured") + if redirect_url not in auth_provider.config.allowed_redirect_urls: + raise HTTPException(status_code=400, detail="Invalid redirect URL") + + state = secrets.token_urlsafe(32) + + oauth_states[state] = { + "created_at": datetime.now(UTC), + "redirect_url": redirect_url, + } + + # Get authorization URL + auth_url = auth_provider.get_authorization_url(state, request) + + logger.debug(f"Redirecting to GitHub OAuth: {auth_url}") + return RedirectResponse(url=auth_url) + + @router.get(GITHUB_CALLBACK_PATH) + async def github_callback(code: str, state: str, request: Request): + """Handle GitHub OAuth callback.""" + # Validate state parameter + if state not in oauth_states: + logger.warning(f"Invalid OAuth state received: {state}") + raise HTTPException(status_code=400, detail="Invalid state parameter") + + state_data = oauth_states.pop(state) + + # Check state expiry + if (datetime.now(UTC) - state_data["created_at"]).seconds > 300: + logger.warning("OAuth state expired") + raise HTTPException(status_code=400, detail="State expired") + + try: + # Complete OAuth flow + access_token = await auth_provider.complete_oauth_flow(code, request) + + logger.info("GitHub OAuth successful") + + redirect_url = state_data.get("redirect_url") + if redirect_url: + # Validate redirect URL from state + if redirect_url not in auth_provider.config.allowed_redirect_urls: + raise HTTPException(status_code=400, detail="Invalid redirect URL in state") + + import urllib.parse + + params = urllib.parse.urlencode({"token": access_token}) + return RedirectResponse(url=f"{redirect_url}#{params}") + else: + # For API calls, only return the access token + return JSONResponse(content={"access_token": access_token}) + + except ValueError as e: + logger.error(f"GitHub OAuth error: {e}") + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + logger.exception("Unexpected error during GitHub OAuth") + raise HTTPException(status_code=500, detail="Authentication failed") from e + + return router diff --git a/llama_stack/distribution/server/github_oauth_auth_provider.py b/llama_stack/distribution/server/github_oauth_auth_provider.py new file mode 100644 index 000000000..580bef261 --- /dev/null +++ b/llama_stack/distribution/server/github_oauth_auth_provider.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +from datetime import UTC, datetime, timedelta +from typing import Any +from urllib.parse import urlencode + +import httpx +from fastapi import Request +from jose import jwt + +from llama_stack.distribution.datatypes import GitHubAuthConfig, User +from llama_stack.log import get_logger + +from .auth_providers import AuthProvider, get_attributes_from_claims + +logger = get_logger(name=__name__, category="github_auth") + +# GitHub API constants +GITHUB_API_BASE_URL = "https://api.github.com" +GITHUB_OAUTH_BASE_URL = "https://github.com" + +# GitHub OAuth route paths +GITHUB_LOGIN_PATH = "/auth/github/login" +GITHUB_CALLBACK_PATH = "/auth/github/callback" + + +class GitHubAuthProvider(AuthProvider): + """Authentication provider for GitHub OAuth flow and JWT validation.""" + + def __init__(self, config: GitHubAuthConfig): + self.config = config + + async def validate_token(self, token: str, scope: dict | None = None) -> User: + """Validate a GitHub-issued JWT token.""" + try: + claims = self.verify_jwt(token) + + principal = claims["sub"] + attributes = get_attributes_from_claims(claims, self.config.claims_mapping) + + return User(principal=principal, attributes=attributes) + + except Exception as e: + logger.exception("Error validating GitHub JWT") + raise ValueError(f"Invalid GitHub JWT token: {str(e)}") from e + + async def close(self): + """Clean up any resources.""" + pass + + def setup_routes(self, app): + """Setup GitHub OAuth routes.""" + from .auth_routes import create_github_auth_router + + github_router = create_github_auth_router(self.config) + app.include_router(github_router) + + def get_public_paths(self) -> list[str]: + """GitHub OAuth paths that don't require authentication.""" + return ["/auth/github/"] + + def get_auth_error_message(self, scope: dict | None = None) -> str: + """Return GitHub-specific authentication error message.""" + if scope: + headers = dict(scope.get("headers", [])) + host = headers.get(b"host", b"").decode() + scheme = scope.get("scheme", "http") + + if host: + auth_url = f"{scheme}://{host}{GITHUB_LOGIN_PATH}" + return f"Authentication required. Please authenticate via GitHub at {auth_url}" + + return f"Authentication required. Please authenticate by visiting {GITHUB_LOGIN_PATH} to start the authentication flow." + + # OAuth flow methods + def get_authorization_url(self, state: str, request: Request) -> str: + """Generate GitHub OAuth authorization URL.""" + params = { + "client_id": self.config.github_client_id, + "redirect_uri": _build_callback_url(request), + "scope": "read:user read:org", + "state": state, + } + return f"{GITHUB_OAUTH_BASE_URL}/login/oauth/authorize?" + urlencode(params) + + async def complete_oauth_flow(self, code: str, request: Request) -> Any: + """Complete the GitHub OAuth flow and return JWT access token.""" + # Exchange code for token + logger.debug("Exchanging code for GitHub access token") + token_data = await self._exchange_code_for_token(code, request) + + if "error" in token_data: + raise ValueError(f"GitHub OAuth error: {token_data.get('error_description', token_data['error'])}") + + access_token = token_data["access_token"] + + # Get user info + logger.debug("Fetching GitHub user info") + github_info = await self._get_user_info(access_token) + + # Create JWT + logger.debug(f"Creating JWT for user: {github_info['user']['login']}") + jwt_token = self._create_jwt_token(github_info) + + return jwt_token + + def verify_jwt(self, token: str) -> Any: + """Verify and decode a GitHub-issued JWT token.""" + try: + payload = jwt.decode( + token, + self.config.jwt_secret, + algorithms=[self.config.jwt_algorithm], + audience=self.config.jwt_audience, + issuer=self.config.jwt_issuer, + ) + return payload + except jwt.JWTError as e: + raise ValueError(f"Invalid JWT token: {e}") from e + + # Private helper methods + async def _exchange_code_for_token(self, code: str, request: Request) -> Any: + """Exchange authorization code for GitHub access token.""" + async with httpx.AsyncClient() as client: + response = await client.post( + f"{GITHUB_OAUTH_BASE_URL}/login/oauth/access_token", + json={ + "client_id": self.config.github_client_id, + "client_secret": self.config.github_client_secret, + "code": code, + "redirect_uri": _build_callback_url(request), + }, + headers={"Accept": "application/json"}, + timeout=10.0, + ) + response.raise_for_status() + return response.json() + + async def _get_user_info(self, access_token: str) -> dict: + """Fetch user info and organizations from GitHub.""" + headers = { + "Authorization": f"Bearer {access_token}", + "Accept": "application/vnd.github.v3+json", + } + + async with httpx.AsyncClient() as client: + # Fetch user and orgs in parallel + user_task = client.get(f"{GITHUB_API_BASE_URL}/user", headers=headers, timeout=10.0) + orgs_task = client.get(f"{GITHUB_API_BASE_URL}/user/orgs", headers=headers, timeout=10.0) + + user_response, orgs_response = await asyncio.gather(user_task, orgs_task) + + user_response.raise_for_status() + orgs_response.raise_for_status() + + user_data = user_response.json() + orgs_data = orgs_response.json() + + # Extract organization names + organizations = [org["login"] for org in orgs_data] + + return { + "user": user_data, + "organizations": organizations, + } + + def _create_jwt_token(self, github_info: dict) -> Any: + """Create JWT token with GitHub user information.""" + user = github_info["user"] + orgs = github_info["organizations"] + teams = github_info.get("teams", []) + + # Create JWT claims that map to Llama Stack attributes + now = datetime.now(UTC) + claims = { + "sub": user["login"], + "aud": self.config.jwt_audience, + "iss": self.config.jwt_issuer, + "exp": now + timedelta(seconds=self.config.token_expiry), + "iat": now, + "nbf": now, + # Custom claims that will be mapped by claims_mapping + "github_username": user["login"], + "github_user_id": str(user["id"]), + "github_orgs": orgs, + "github_teams": teams, + "email": user.get("email"), + "name": user.get("name"), + "avatar_url": user.get("avatar_url"), + } + + return jwt.encode(claims, self.config.jwt_secret, algorithm=self.config.jwt_algorithm) + + +def _build_callback_url(request: Request) -> str: + """Build the GitHub OAuth callback URL from the current request.""" + callback_path = str(request.url_for("github_callback")) + return callback_path diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 83407a25f..381d6a8c1 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -31,7 +31,11 @@ from openai import BadRequestError from pydantic import BaseModel, ValidationError from llama_stack.apis.common.responses import PaginatedResponse -from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig +from llama_stack.distribution.datatypes import ( + AuthenticationRequiredError, + LoggingConfig, + StackRunConfig, +) from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context from llama_stack.distribution.resolver import InvalidProviderError @@ -61,6 +65,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( ) from .auth import AuthenticationMiddleware +from .auth_providers import create_auth_provider from .quota import QuotaMiddleware REPO_ROOT = Path(__file__).parent.parent.parent.parent @@ -448,9 +453,12 @@ def main(args: argparse.Namespace | None = None): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): app.add_middleware(ClientVersionMiddleware) - # Add authentication middleware if configured if config.server.auth: - logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}") + logger.info(f"Enabling authentication with provider: {config.server.auth.type.value}") + + auth_provider = create_auth_provider(config.server.auth) + auth_provider.setup_routes(app) + app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth) else: if config.server.quota: diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 4410048c5..6f08e7382 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -11,10 +11,13 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from llama_stack.distribution.datatypes import AuthenticationConfig +from llama_stack.distribution.datatypes import ( + AuthProviderType, + CustomAuthConfig, + OAuth2TokenAuthConfig, +) from llama_stack.distribution.server.auth import AuthenticationMiddleware from llama_stack.distribution.server.auth_providers import ( - AuthProviderType, get_attributes_from_claims, ) @@ -60,8 +63,8 @@ def invalid_token(): @pytest.fixture def http_app(mock_auth_endpoint): app = FastAPI() - auth_config = AuthenticationConfig( - provider_type=AuthProviderType.CUSTOM, + auth_config = CustomAuthConfig( + type=AuthProviderType.CUSTOM, config={"endpoint": mock_auth_endpoint}, ) app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) @@ -76,9 +79,9 @@ def http_app(mock_auth_endpoint): @pytest.fixture def k8s_app(): app = FastAPI() - auth_config = AuthenticationConfig( - provider_type=AuthProviderType.KUBERNETES, - config={"api_server_url": "https://kubernetes.default.svc"}, + auth_config = CustomAuthConfig( + type=AuthProviderType.CUSTOM, + config={"provider_type": "kubernetes", "api_server_url": "https://kubernetes.default.svc"}, ) app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) @@ -116,8 +119,8 @@ def mock_scope(): @pytest.fixture def mock_http_middleware(mock_auth_endpoint): mock_app = AsyncMock() - auth_config = AuthenticationConfig( - provider_type=AuthProviderType.CUSTOM, + auth_config = CustomAuthConfig( + type=AuthProviderType.CUSTOM, config={"endpoint": mock_auth_endpoint}, ) return AuthenticationMiddleware(mock_app, auth_config), mock_app @@ -126,9 +129,9 @@ def mock_http_middleware(mock_auth_endpoint): @pytest.fixture def mock_k8s_middleware(): mock_app = AsyncMock() - auth_config = AuthenticationConfig( - provider_type=AuthProviderType.KUBERNETES, - config={"api_server_url": "https://kubernetes.default.svc"}, + auth_config = CustomAuthConfig( + type=AuthProviderType.CUSTOM, + config={"provider_type": "kubernetes", "api_server_url": "https://kubernetes.default.svc"}, ) return AuthenticationMiddleware(mock_app, auth_config), mock_app @@ -161,7 +164,8 @@ async def mock_post_exception(*args, **kwargs): def test_missing_auth_header(http_client): response = http_client.get("/test") assert response.status_code == 401 - assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + assert "Authentication required" in response.json()["error"]["message"] + assert "validated by mock-auth-service" in response.json()["error"]["message"] def test_invalid_auth_header_format(http_client): @@ -261,8 +265,8 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock @pytest.fixture def oauth2_app(): app = FastAPI() - auth_config = AuthenticationConfig( - provider_type=AuthProviderType.OAUTH2_TOKEN, + auth_config = OAuth2TokenAuthConfig( + type=AuthProviderType.OAUTH2_TOKEN, config={ "jwks": { "uri": "http://mock-authz-service/token/introspect", @@ -288,7 +292,8 @@ def oauth2_client(oauth2_app): def test_missing_auth_header_oauth2(oauth2_client): response = oauth2_client.get("/test") assert response.status_code == 401 - assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + assert "Authentication required" in response.json()["error"]["message"] + assert "OAuth2 Bearer token" in response.json()["error"]["message"] def test_invalid_auth_header_format_oauth2(oauth2_client): @@ -357,8 +362,8 @@ async def mock_auth_jwks_response(*args, **kwargs): @pytest.fixture def oauth2_app_with_jwks_token(): app = FastAPI() - auth_config = AuthenticationConfig( - provider_type=AuthProviderType.OAUTH2_TOKEN, + auth_config = OAuth2TokenAuthConfig( + type=AuthProviderType.OAUTH2_TOKEN, config={ "jwks": { "uri": "http://mock-authz-service/token/introspect", @@ -448,8 +453,8 @@ def mock_introspection_endpoint(): @pytest.fixture def introspection_app(mock_introspection_endpoint): app = FastAPI() - auth_config = AuthenticationConfig( - provider_type=AuthProviderType.OAUTH2_TOKEN, + auth_config = OAuth2TokenAuthConfig( + type=AuthProviderType.OAUTH2_TOKEN, config={ "jwks": None, "introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"}, @@ -467,8 +472,8 @@ def introspection_app(mock_introspection_endpoint): @pytest.fixture def introspection_app_with_custom_mapping(mock_introspection_endpoint): app = FastAPI() - auth_config = AuthenticationConfig( - provider_type=AuthProviderType.OAUTH2_TOKEN, + auth_config = OAuth2TokenAuthConfig( + type=AuthProviderType.OAUTH2_TOKEN, config={ "jwks": None, "introspection": { @@ -507,7 +512,8 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi def test_missing_auth_header_introspection(introspection_client): response = introspection_client.get("/test") assert response.status_code == 401 - assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + assert "Authentication required" in response.json()["error"]["message"] + assert "OAuth2 Bearer token" in response.json()["error"]["message"] def test_invalid_auth_header_format_introspection(introspection_client): diff --git a/tests/unit/server/test_auth_github.py b/tests/unit/server/test_auth_github.py new file mode 100644 index 000000000..1361fd392 --- /dev/null +++ b/tests/unit/server/test_auth_github.py @@ -0,0 +1,468 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from datetime import UTC, datetime, timedelta +from unittest.mock import patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from jose import jwt + +from llama_stack.distribution.datatypes import GitHubAuthConfig +from llama_stack.distribution.server.auth import AuthenticationMiddleware +from llama_stack.distribution.server.auth_providers import get_attributes_from_claims +from llama_stack.distribution.server.auth_routes import create_github_auth_router +from llama_stack.distribution.server.github_oauth_auth_provider import GitHubAuthProvider + + +class MockResponse: + def __init__(self, status_code, json_data): + self.status_code = status_code + self._json_data = json_data + + def json(self): + return self._json_data + + def raise_for_status(self): + if self.status_code != 200: + raise Exception(f"HTTP error: {self.status_code}") + + +@pytest.fixture +def github_oauth_app(): + app = FastAPI() + auth_config = { + "type": "github_oauth", + "github_client_id": "test_client_id", + "github_client_secret": "test_client_secret", + "jwt_secret": "test_jwt_secret", + "jwt_algorithm": "HS256", + "jwt_audience": "llama-stack", + "jwt_issuer": "llama-stack-github", + "token_expiry": 86400, + } + + github_config = GitHubAuthConfig(**auth_config) + + # Add auth routes BEFORE middleware so they're not protected + auth_router = create_github_auth_router(github_config) + app.include_router(auth_router) + + # Then add auth middleware for other routes + app.add_middleware(AuthenticationMiddleware, auth_config=github_config) + + @app.get("/test") + def test_endpoint(): + return {"message": "Authentication successful"} + + return app + + +@pytest.fixture +def github_oauth_client(github_oauth_app): + return TestClient(github_oauth_app) + + +def test_github_login_redirect(github_oauth_client): + """Test that GitHub login endpoint returns redirect to GitHub""" + response = github_oauth_client.get("/auth/github/login", follow_redirects=False) + assert response.status_code == 307 # Temporary redirect + assert "github.com/login/oauth/authorize" in response.headers["location"] + assert "client_id=test_client_id" in response.headers["location"] + assert "state=" in response.headers["location"] + + +async def mock_github_token_exchange_success(*args, **kwargs): + """Mock successful GitHub token exchange""" + return MockResponse( + 200, + { + "access_token": "github_access_token_123", + "token_type": "bearer", + "scope": "read:user,read:org", + }, + ) + + +class MockAsyncClient: + """Mock httpx.AsyncClient for testing""" + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + async def post(self, url, **kwargs): + if "login/oauth/access_token" in url: + return await mock_github_token_exchange_success() + return MockResponse(404, {}) + + async def get(self, url, **kwargs): + if url.endswith("/user"): + return MockResponse( + 200, + { + "login": "test-user", + "id": 12345, + "email": "test@example.com", + "name": "Test User", + "avatar_url": "https://avatars.githubusercontent.com/u/12345", + }, + ) + elif url.endswith("/user/orgs"): + return MockResponse( + 200, + [ + {"login": "test-org-1"}, + {"login": "test-org-2"}, + ], + ) + return MockResponse(404, {}) + + +@patch("httpx.AsyncClient", MockAsyncClient) +def test_github_callback_success(github_oauth_app): + """Test successful GitHub OAuth callback""" + # Get a fresh client for this test + client = TestClient(github_oauth_app) + + # First, make a login request to generate a state + login_response = client.get("/auth/github/login", follow_redirects=False) + assert login_response.status_code == 307 + + # Extract the state from the redirect URL + location = login_response.headers["location"] + import re + + state_match = re.search(r"state=([^&]+)", location) + assert state_match + state = state_match.group(1) + + # Now use that state in the callback + response = client.get(f"/auth/github/callback?code=test_code&state={state}") + + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + # Verify the JWT token contains the expected claims + token = data["access_token"] + # Decode without verification since this is a test + claims = jwt.decode(token, "test_jwt_secret", algorithms=["HS256"], audience="llama-stack") + assert claims["github_username"] == "test-user" + assert claims["email"] == "test@example.com" + + +def test_github_callback_invalid_state(github_oauth_client): + """Test GitHub callback with invalid state""" + response = github_oauth_client.get("/auth/github/callback?code=test_code&state=invalid_state") + assert response.status_code == 400 + assert "Invalid state parameter" in response.json()["detail"] + + +async def mock_github_token_exchange_error(*args, **kwargs): + """Mock GitHub token exchange error""" + return MockResponse( + 200, {"error": "bad_verification_code", "error_description": "The code passed is incorrect or expired."} + ) + + +class MockAsyncClientError: + """Mock httpx.AsyncClient that returns error for token exchange""" + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + async def post(self, url, **kwargs): + if "login/oauth/access_token" in url: + return await mock_github_token_exchange_error() + return MockResponse(404, {}) + + async def get(self, url, **kwargs): + return MockResponse(404, {}) + + +@patch("httpx.AsyncClient", MockAsyncClientError) +def test_github_callback_token_exchange_error(github_oauth_app): + """Test GitHub callback with token exchange error""" + client = TestClient(github_oauth_app) + + # First, make a login request to generate a state + login_response = client.get("/auth/github/login", follow_redirects=False) + assert login_response.status_code == 307 + + # Extract the state from the redirect URL + location = login_response.headers["location"] + import re + + state_match = re.search(r"state=([^&]+)", location) + assert state_match + state = state_match.group(1) + + # Now use that state in the callback with bad code + response = client.get(f"/auth/github/callback?code=bad_code&state={state}") + assert response.status_code == 400 + assert "The code passed is incorrect or expired" in response.json()["detail"] + + +@pytest.fixture +def github_jwt_token(): + """Create a valid GitHub JWT token for testing""" + claims = { + "sub": "test-user", + "aud": "llama-stack", + "iss": "llama-stack-github", + "exp": datetime.now(UTC) + timedelta(hours=1), + "iat": datetime.now(UTC), + "nbf": datetime.now(UTC), + "github_username": "test-user", + "github_user_id": "12345", + "github_orgs": ["test-org-1", "test-org-2"], + "email": "test@example.com", + "name": "Test User", + } + + return jwt.encode(claims, "test_jwt_secret", algorithm="HS256") + + +def test_github_jwt_authentication_success(github_oauth_client, github_jwt_token): + """Test API authentication with GitHub-issued JWT""" + response = github_oauth_client.get("/test", headers={"Authorization": f"Bearer {github_jwt_token}"}) + assert response.status_code == 200 + assert response.json() == {"message": "Authentication successful"} + + +def test_github_jwt_authentication_invalid_token(github_oauth_client): + """Test API authentication with invalid JWT""" + response = github_oauth_client.get("/test", headers={"Authorization": "Bearer invalid.jwt.token"}) + assert response.status_code == 401 + assert "Invalid GitHub JWT token" in response.json()["error"]["message"] + + +def test_github_jwt_authentication_expired_token(github_oauth_client): + """Test API authentication with expired JWT""" + # Create an expired token + claims = { + "sub": "test-user", + "aud": "llama-stack", + "iss": "llama-stack-github", + "exp": datetime.now(UTC) - timedelta(hours=1), # Expired + "iat": datetime.now(UTC) - timedelta(hours=2), + "nbf": datetime.now(UTC) - timedelta(hours=2), + "github_username": "test-user", + } + + expired_token = jwt.encode(claims, "test_jwt_secret", algorithm="HS256") + + response = github_oauth_client.get("/test", headers={"Authorization": f"Bearer {expired_token}"}) + assert response.status_code == 401 + assert "Invalid GitHub JWT token" in response.json()["error"]["message"] + + +def test_github_jwt_authentication_wrong_audience(github_oauth_client): + """Test API authentication with JWT having wrong audience""" + claims = { + "sub": "test-user", + "aud": "wrong-audience", # Wrong audience + "iss": "llama-stack-github", + "exp": datetime.now(UTC) + timedelta(hours=1), + "iat": datetime.now(UTC), + "nbf": datetime.now(UTC), + "github_username": "test-user", + } + + wrong_audience_token = jwt.encode(claims, "test_jwt_secret", algorithm="HS256") + + response = github_oauth_client.get("/test", headers={"Authorization": f"Bearer {wrong_audience_token}"}) + assert response.status_code == 401 + assert "Invalid GitHub JWT token" in response.json()["error"]["message"] + + +def test_github_claims_mapping(): + """Test GitHub claims are properly mapped to attributes""" + config = GitHubAuthConfig( + type="github_oauth", + github_client_id="test", + github_client_secret="test", + jwt_secret="test", + ) + + claims = { + "sub": "test-user", + "github_username": "test-user", + "github_orgs": ["org1", "org2"], + "github_teams": ["team1", "team2"], + "github_user_id": "12345", + } + + # Default mapping only maps "sub" to "roles" + attributes = get_attributes_from_claims(claims, config.claims_mapping) + + assert "test-user" in attributes["roles"] + # No other mappings by default + assert len(attributes) == 1 + + +@pytest.mark.asyncio +async def test_github_auth_provider_validate_token(): + """Test GitHubAuthProvider token validation""" + config = GitHubAuthConfig( + type="github_oauth", + github_client_id="test", + github_client_secret="test", + jwt_secret="test_secret", + jwt_algorithm="HS256", + jwt_audience="test-audience", + jwt_issuer="test-issuer", + ) + + provider = GitHubAuthProvider(config) + + # Create a valid token + claims = { + "sub": "test-user", + "aud": "test-audience", + "iss": "test-issuer", + "exp": datetime.now(UTC) + timedelta(hours=1), + "iat": datetime.now(UTC), + "nbf": datetime.now(UTC), + "github_username": "test-user", + "github_orgs": ["org1"], + "github_user_id": "12345", + } + + token = jwt.encode(claims, "test_secret", algorithm="HS256") + + user = await provider.validate_token(token) + assert user.principal == "test-user" + # Default mapping only maps "sub" to "roles" + assert "test-user" in user.attributes["roles"] + # No other mappings by default + assert len(user.attributes) == 1 + + +@pytest.mark.asyncio +async def test_github_auth_provider_custom_claims_mapping(): + """Test GitHubAuthProvider with custom claims mapping""" + config = GitHubAuthConfig( + type="github_oauth", + github_client_id="test", + github_client_secret="test", + jwt_secret="test_secret", + jwt_algorithm="HS256", + jwt_audience="test-audience", + jwt_issuer="test-issuer", + claims_mapping={ + "sub": "roles", + "github_orgs": "teams", + "github_teams": "teams", + "github_user_id": "namespaces", + }, + ) + + provider = GitHubAuthProvider(config) + + # Create a valid token + claims = { + "sub": "test-user", + "aud": "test-audience", + "iss": "test-issuer", + "exp": datetime.now(UTC) + timedelta(hours=1), + "iat": datetime.now(UTC), + "nbf": datetime.now(UTC), + "github_username": "test-user", + "github_orgs": ["org1", "org2"], + "github_teams": ["team1", "team2"], + "github_user_id": "12345", + } + + token = jwt.encode(claims, "test_secret", algorithm="HS256") + + user = await provider.validate_token(token) + assert user.principal == "test-user" + assert "test-user" in user.attributes["roles"] + assert set(user.attributes["teams"]) == {"org1", "org2", "team1", "team2"} + assert "12345" in user.attributes["namespaces"] + + +@pytest.mark.asyncio +async def test_github_auth_provider_invalid_token(): + """Test GitHubAuthProvider with invalid token""" + config = GitHubAuthConfig( + type="github_oauth", + github_client_id="test", + github_client_secret="test", + jwt_secret="test_secret", + ) + + provider = GitHubAuthProvider(config) + + with pytest.raises(ValueError, match="Invalid GitHub JWT token"): + await provider.validate_token("invalid.token.here") + + +def test_github_auth_provider_authorization_url(): + """Test GitHubAuthProvider generates correct authorization URL""" + from unittest.mock import Mock + + config = GitHubAuthConfig( + type="github_oauth", + github_client_id="test_client_id", + github_client_secret="test_secret", + jwt_secret="test", + ) + + provider = GitHubAuthProvider(config) + + # Mock request object + mock_request = Mock() + mock_request.url_for.return_value = "http://localhost:8321/auth/github/callback" + + url = provider.get_authorization_url("test_state_123", mock_request) + + assert "https://github.com/login/oauth/authorize" in url + assert "client_id=test_client_id" in url + assert "redirect_uri=http%3A%2F%2Flocalhost%3A8321%2Fauth%2Fgithub%2Fcallback" in url + assert "state=test_state_123" in url + assert "scope=read%3Auser+read%3Aorg" in url + + +@pytest.mark.asyncio +async def test_github_auth_provider_complete_flow(): + """Test complete OAuth flow in GitHubAuthProvider""" + from unittest.mock import Mock + + config = GitHubAuthConfig( + type="github_oauth", + github_client_id="test_client_id", + github_client_secret="test_secret", + jwt_secret="test_jwt_secret", + jwt_algorithm="HS256", + jwt_audience="llama-stack", + jwt_issuer="llama-stack-github", + ) + + provider = GitHubAuthProvider(config) + + # Mock request object + mock_request = Mock() + mock_request.url_for.return_value = "/auth/github/callback" + mock_request.url.scheme = "http" + mock_request.url.netloc = "localhost:8321" + + with patch("httpx.AsyncClient", MockAsyncClient): + # Now returns just the JWT token + token = await provider.complete_oauth_flow("test_code", mock_request) + + # Verify it's a JWT token by decoding it + claims = jwt.decode(token, "test_jwt_secret", algorithms=["HS256"], audience="llama-stack") + assert claims["github_username"] == "test-user" + assert claims["github_orgs"] == ["test-org-1", "test-org-2"] + assert claims["sub"] == "test-user"