mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Merge 94282d3482
into 40fdce79b3
This commit is contained in:
commit
12265d4222
9 changed files with 962 additions and 43 deletions
2
.github/workflows/integration-auth-tests.yml
vendored
2
.github/workflows/integration-auth-tests.yml
vendored
|
@ -73,7 +73,7 @@ jobs:
|
|||
server:
|
||||
port: 8321
|
||||
EOF
|
||||
yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth = {"type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml
|
||||
cat $run_dir/run.yaml
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
@ -166,20 +166,90 @@ class AuthProviderType(StrEnum):
|
|||
|
||||
OAUTH2_TOKEN = "oauth2_token"
|
||||
CUSTOM = "custom"
|
||||
GITHUB_OAUTH = "github_oauth"
|
||||
|
||||
|
||||
class AuthenticationConfig(BaseModel):
|
||||
provider_type: AuthProviderType = Field(
|
||||
...,
|
||||
description="Type of authentication provider",
|
||||
)
|
||||
class OAuth2TokenAuthConfig(BaseModel):
|
||||
"""Configuration for OAuth2 token authentication."""
|
||||
|
||||
type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
|
||||
config: dict[str, Any] = Field(
|
||||
...,
|
||||
description="Provider-specific configuration",
|
||||
description="OAuth2 token validation configuration",
|
||||
)
|
||||
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
|
||||
|
||||
|
||||
class CustomAuthConfig(BaseModel):
|
||||
"""Configuration for custom authentication."""
|
||||
|
||||
type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
|
||||
config: dict[str, Any] = Field(
|
||||
...,
|
||||
description="Custom authentication endpoint configuration",
|
||||
)
|
||||
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
|
||||
|
||||
|
||||
class GitHubAuthConfig(BaseModel):
|
||||
"""Configuration for GitHub OAuth authentication.
|
||||
|
||||
This provider implements GitHub OAuth authentication with the following endpoints:
|
||||
|
||||
- /auth/github/login: Initiates OAuth flow
|
||||
- Accepts optional redirect_url parameter (must be in allowed_redirect_urls)
|
||||
- Redirects to GitHub for authorization
|
||||
|
||||
- /auth/github/callback: Handles OAuth callback (should be used as the callback URL for your GitHub OAuth App)
|
||||
- Exchanges authorization code for access token
|
||||
- Creates JWT with user info
|
||||
- With redirect_url: Redirects to specified URL with JWT in fragment (#token=...)
|
||||
- Without redirect_url: Returns JSON response with JWT
|
||||
|
||||
Ensure your GitHub OAuth App's callback URL matches your server's URL + /auth/github/callback.
|
||||
"""
|
||||
|
||||
type: Literal[AuthProviderType.GITHUB_OAUTH] = AuthProviderType.GITHUB_OAUTH
|
||||
|
||||
# GitHub OAuth settings
|
||||
github_client_id: str = Field(description="GitHub OAuth App Client ID")
|
||||
github_client_secret: str = Field(description="GitHub OAuth App Client Secret")
|
||||
allowed_redirect_urls: list[str] = Field(
|
||||
default=[],
|
||||
description="Allowed redirect URLs for OAuth callback, e.g. frontend URL",
|
||||
)
|
||||
|
||||
# JWT configuration for token generation
|
||||
jwt_secret: str = Field(description="Secret for signing JWT tokens")
|
||||
jwt_algorithm: str = Field(default="HS256", description="JWT signing algorithm")
|
||||
jwt_audience: str = Field(default="llama-stack", description="JWT audience")
|
||||
jwt_issuer: str = Field(default="llama-stack-github", description="JWT issuer")
|
||||
token_expiry: int = Field(default=86400, description="JWT token expiry in seconds")
|
||||
|
||||
# OAuth2 token validation config (for validating the JWTs we issue)
|
||||
oauth2_config: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional OAuth2 token validation configuration",
|
||||
)
|
||||
|
||||
# Access control
|
||||
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
|
||||
|
||||
# Claims mapping for GitHub attributes
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"sub": "roles", # GitHub username as role
|
||||
},
|
||||
description="Mapping from JWT claims to Llama Stack attributes",
|
||||
)
|
||||
|
||||
|
||||
# Discriminated union for authentication configurations
|
||||
AuthenticationConfig = Annotated[
|
||||
OAuth2TokenAuthConfig | CustomAuthConfig | GitHubAuthConfig, Field(discriminator="type")
|
||||
]
|
||||
|
||||
|
||||
class AuthenticationRequiredError(Exception):
|
||||
pass
|
||||
|
||||
|
|
|
@ -81,13 +81,23 @@ class AuthenticationMiddleware:
|
|||
def __init__(self, app, auth_config: AuthenticationConfig):
|
||||
self.app = app
|
||||
self.auth_provider = create_auth_provider(auth_config)
|
||||
self.public_paths = self.auth_provider.get_public_paths()
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
# Skip authentication for public paths defined by the provider
|
||||
path = scope.get("path", "")
|
||||
if any(path.startswith(public_path) for public_path in self.public_paths):
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
headers = dict(scope.get("headers", []))
|
||||
auth_header = headers.get(b"authorization", b"").decode()
|
||||
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
if not auth_header:
|
||||
error_msg = self.auth_provider.get_auth_error_message(scope)
|
||||
return await self._send_auth_error(send, error_msg)
|
||||
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return await self._send_auth_error(send, "Missing or invalid Authorization header")
|
||||
|
||||
token = auth_header.split("Bearer ", 1)[1]
|
||||
|
|
|
@ -10,13 +10,19 @@ from abc import ABC, abstractmethod
|
|||
from asyncio import Lock
|
||||
from pathlib import Path
|
||||
from typing import Self
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import httpx
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationConfig,
|
||||
CustomAuthConfig,
|
||||
GitHubAuthConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
User,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
@ -62,6 +68,24 @@ class AuthProvider(ABC):
|
|||
"""Clean up any resources."""
|
||||
pass
|
||||
|
||||
def setup_routes(self, app):
|
||||
"""Setup any provider-specific routes (e.g., OAuth callbacks).
|
||||
|
||||
This is optional - providers that don't need special routes can skip this.
|
||||
"""
|
||||
return
|
||||
|
||||
def get_public_paths(self) -> list[str]:
|
||||
"""Return a list of path prefixes that should bypass authentication.
|
||||
|
||||
This is optional - providers that don't have public paths return empty list.
|
||||
"""
|
||||
return []
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return provider-specific authentication error message."""
|
||||
return "Authentication required"
|
||||
|
||||
|
||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
|
||||
attributes: dict[str, list[str]] = {}
|
||||
|
@ -232,6 +256,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
async def close(self):
|
||||
pass
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return OAuth2-specific authentication error message."""
|
||||
if self.config.issuer:
|
||||
return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
|
||||
elif self.config.introspection:
|
||||
# Extract domain from introspection URL for a cleaner message
|
||||
domain = urlparse(self.config.introspection.url).netloc
|
||||
return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
|
||||
else:
|
||||
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
|
||||
|
||||
async def _refresh_jwks(self) -> None:
|
||||
"""
|
||||
Refresh the JWKS cache.
|
||||
|
@ -338,15 +373,25 @@ class CustomAuthProvider(AuthProvider):
|
|||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return custom auth provider-specific authentication error message."""
|
||||
# Extract domain from endpoint URL for a cleaner message
|
||||
domain = urlparse(self.config.endpoint).netloc
|
||||
if domain:
|
||||
return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
|
||||
else:
|
||||
return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
|
||||
|
||||
|
||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||
"""Factory function to create the appropriate auth provider."""
|
||||
provider_type = config.provider_type.lower()
|
||||
|
||||
if provider_type == "custom":
|
||||
if isinstance(config, CustomAuthConfig):
|
||||
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
||||
elif provider_type == "oauth2_token":
|
||||
elif isinstance(config, OAuth2TokenAuthConfig):
|
||||
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
||||
elif isinstance(config, GitHubAuthConfig):
|
||||
from .github_oauth_auth_provider import GitHubAuthProvider
|
||||
|
||||
return GitHubAuthProvider(config)
|
||||
else:
|
||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
||||
raise ValueError(f"Unknown authentication config type: {type(config)}")
|
||||
|
|
109
llama_stack/distribution/server/auth_routes.py
Normal file
109
llama_stack/distribution/server/auth_routes.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
|
||||
from llama_stack.distribution.datatypes import GitHubAuthConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .github_oauth_auth_provider import (
|
||||
GITHUB_CALLBACK_PATH,
|
||||
GITHUB_LOGIN_PATH,
|
||||
GitHubAuthProvider,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="auth_routes")
|
||||
|
||||
|
||||
def create_github_auth_router(config: GitHubAuthConfig) -> APIRouter:
|
||||
"""Create and configure GitHub authentication router."""
|
||||
auth_provider = GitHubAuthProvider(config)
|
||||
router = APIRouter()
|
||||
oauth_states: dict[str, dict] = {}
|
||||
|
||||
def cleanup_expired_states() -> None:
|
||||
"""Remove expired OAuth states."""
|
||||
now = datetime.now(UTC)
|
||||
expired_states = [
|
||||
state
|
||||
for state, data in oauth_states.items()
|
||||
if (now - data["created_at"]).seconds > 300 # 5 minutes
|
||||
]
|
||||
for state in expired_states:
|
||||
oauth_states.pop(state, None)
|
||||
|
||||
@router.get(GITHUB_LOGIN_PATH)
|
||||
async def github_login(request: Request, redirect_url: str | None = None):
|
||||
"""Initiate GitHub OAuth flow."""
|
||||
cleanup_expired_states()
|
||||
|
||||
# Validate redirect URL if provided
|
||||
if redirect_url:
|
||||
if not auth_provider.config.allowed_redirect_urls:
|
||||
raise HTTPException(status_code=400, detail="Redirect URLs not configured")
|
||||
if redirect_url not in auth_provider.config.allowed_redirect_urls:
|
||||
raise HTTPException(status_code=400, detail="Invalid redirect URL")
|
||||
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
oauth_states[state] = {
|
||||
"created_at": datetime.now(UTC),
|
||||
"redirect_url": redirect_url,
|
||||
}
|
||||
|
||||
# Get authorization URL
|
||||
auth_url = auth_provider.get_authorization_url(state, request)
|
||||
|
||||
logger.debug(f"Redirecting to GitHub OAuth: {auth_url}")
|
||||
return RedirectResponse(url=auth_url)
|
||||
|
||||
@router.get(GITHUB_CALLBACK_PATH)
|
||||
async def github_callback(code: str, state: str, request: Request):
|
||||
"""Handle GitHub OAuth callback."""
|
||||
# Validate state parameter
|
||||
if state not in oauth_states:
|
||||
logger.warning(f"Invalid OAuth state received: {state}")
|
||||
raise HTTPException(status_code=400, detail="Invalid state parameter")
|
||||
|
||||
state_data = oauth_states.pop(state)
|
||||
|
||||
# Check state expiry
|
||||
if (datetime.now(UTC) - state_data["created_at"]).seconds > 300:
|
||||
logger.warning("OAuth state expired")
|
||||
raise HTTPException(status_code=400, detail="State expired")
|
||||
|
||||
try:
|
||||
# Complete OAuth flow
|
||||
access_token = await auth_provider.complete_oauth_flow(code, request)
|
||||
|
||||
logger.info("GitHub OAuth successful")
|
||||
|
||||
redirect_url = state_data.get("redirect_url")
|
||||
if redirect_url:
|
||||
# Validate redirect URL from state
|
||||
if redirect_url not in auth_provider.config.allowed_redirect_urls:
|
||||
raise HTTPException(status_code=400, detail="Invalid redirect URL in state")
|
||||
|
||||
import urllib.parse
|
||||
|
||||
params = urllib.parse.urlencode({"token": access_token})
|
||||
return RedirectResponse(url=f"{redirect_url}#{params}")
|
||||
else:
|
||||
# For API calls, only return the access token
|
||||
return JSONResponse(content={"access_token": access_token})
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"GitHub OAuth error: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error during GitHub OAuth")
|
||||
raise HTTPException(status_code=500, detail="Authentication failed") from e
|
||||
|
||||
return router
|
203
llama_stack/distribution/server/github_oauth_auth_provider.py
Normal file
203
llama_stack/distribution/server/github_oauth_auth_provider.py
Normal file
|
@ -0,0 +1,203 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
from jose import jwt
|
||||
|
||||
from llama_stack.distribution.datatypes import GitHubAuthConfig, User
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .auth_providers import AuthProvider, get_attributes_from_claims
|
||||
|
||||
logger = get_logger(name=__name__, category="github_auth")
|
||||
|
||||
# GitHub API constants
|
||||
GITHUB_API_BASE_URL = "https://api.github.com"
|
||||
GITHUB_OAUTH_BASE_URL = "https://github.com"
|
||||
|
||||
# GitHub OAuth route paths
|
||||
GITHUB_LOGIN_PATH = "/auth/github/login"
|
||||
GITHUB_CALLBACK_PATH = "/auth/github/callback"
|
||||
|
||||
|
||||
class GitHubAuthProvider(AuthProvider):
|
||||
"""Authentication provider for GitHub OAuth flow and JWT validation."""
|
||||
|
||||
def __init__(self, config: GitHubAuthConfig):
|
||||
self.config = config
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a GitHub-issued JWT token."""
|
||||
try:
|
||||
claims = self.verify_jwt(token)
|
||||
|
||||
principal = claims["sub"]
|
||||
attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
||||
|
||||
return User(principal=principal, attributes=attributes)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error validating GitHub JWT")
|
||||
raise ValueError(f"Invalid GitHub JWT token: {str(e)}") from e
|
||||
|
||||
async def close(self):
|
||||
"""Clean up any resources."""
|
||||
pass
|
||||
|
||||
def setup_routes(self, app):
|
||||
"""Setup GitHub OAuth routes."""
|
||||
from .auth_routes import create_github_auth_router
|
||||
|
||||
github_router = create_github_auth_router(self.config)
|
||||
app.include_router(github_router)
|
||||
|
||||
def get_public_paths(self) -> list[str]:
|
||||
"""GitHub OAuth paths that don't require authentication."""
|
||||
return ["/auth/github/"]
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return GitHub-specific authentication error message."""
|
||||
if scope:
|
||||
headers = dict(scope.get("headers", []))
|
||||
host = headers.get(b"host", b"").decode()
|
||||
scheme = scope.get("scheme", "http")
|
||||
|
||||
if host:
|
||||
auth_url = f"{scheme}://{host}{GITHUB_LOGIN_PATH}"
|
||||
return f"Authentication required. Please authenticate via GitHub at {auth_url}"
|
||||
|
||||
return f"Authentication required. Please authenticate by visiting {GITHUB_LOGIN_PATH} to start the authentication flow."
|
||||
|
||||
# OAuth flow methods
|
||||
def get_authorization_url(self, state: str, request: Request) -> str:
|
||||
"""Generate GitHub OAuth authorization URL."""
|
||||
params = {
|
||||
"client_id": self.config.github_client_id,
|
||||
"redirect_uri": _build_callback_url(request),
|
||||
"scope": "read:user read:org",
|
||||
"state": state,
|
||||
}
|
||||
return f"{GITHUB_OAUTH_BASE_URL}/login/oauth/authorize?" + urlencode(params)
|
||||
|
||||
async def complete_oauth_flow(self, code: str, request: Request) -> Any:
|
||||
"""Complete the GitHub OAuth flow and return JWT access token."""
|
||||
# Exchange code for token
|
||||
logger.debug("Exchanging code for GitHub access token")
|
||||
token_data = await self._exchange_code_for_token(code, request)
|
||||
|
||||
if "error" in token_data:
|
||||
raise ValueError(f"GitHub OAuth error: {token_data.get('error_description', token_data['error'])}")
|
||||
|
||||
access_token = token_data["access_token"]
|
||||
|
||||
# Get user info
|
||||
logger.debug("Fetching GitHub user info")
|
||||
github_info = await self._get_user_info(access_token)
|
||||
|
||||
# Create JWT
|
||||
logger.debug(f"Creating JWT for user: {github_info['user']['login']}")
|
||||
jwt_token = self._create_jwt_token(github_info)
|
||||
|
||||
return jwt_token
|
||||
|
||||
def verify_jwt(self, token: str) -> Any:
|
||||
"""Verify and decode a GitHub-issued JWT token."""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self.config.jwt_secret,
|
||||
algorithms=[self.config.jwt_algorithm],
|
||||
audience=self.config.jwt_audience,
|
||||
issuer=self.config.jwt_issuer,
|
||||
)
|
||||
return payload
|
||||
except jwt.JWTError as e:
|
||||
raise ValueError(f"Invalid JWT token: {e}") from e
|
||||
|
||||
# Private helper methods
|
||||
async def _exchange_code_for_token(self, code: str, request: Request) -> Any:
|
||||
"""Exchange authorization code for GitHub access token."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{GITHUB_OAUTH_BASE_URL}/login/oauth/access_token",
|
||||
json={
|
||||
"client_id": self.config.github_client_id,
|
||||
"client_secret": self.config.github_client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": _build_callback_url(request),
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
timeout=10.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _get_user_info(self, access_token: str) -> dict:
|
||||
"""Fetch user info and organizations from GitHub."""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Fetch user and orgs in parallel
|
||||
user_task = client.get(f"{GITHUB_API_BASE_URL}/user", headers=headers, timeout=10.0)
|
||||
orgs_task = client.get(f"{GITHUB_API_BASE_URL}/user/orgs", headers=headers, timeout=10.0)
|
||||
|
||||
user_response, orgs_response = await asyncio.gather(user_task, orgs_task)
|
||||
|
||||
user_response.raise_for_status()
|
||||
orgs_response.raise_for_status()
|
||||
|
||||
user_data = user_response.json()
|
||||
orgs_data = orgs_response.json()
|
||||
|
||||
# Extract organization names
|
||||
organizations = [org["login"] for org in orgs_data]
|
||||
|
||||
return {
|
||||
"user": user_data,
|
||||
"organizations": organizations,
|
||||
}
|
||||
|
||||
def _create_jwt_token(self, github_info: dict) -> Any:
|
||||
"""Create JWT token with GitHub user information."""
|
||||
user = github_info["user"]
|
||||
orgs = github_info["organizations"]
|
||||
teams = github_info.get("teams", [])
|
||||
|
||||
# Create JWT claims that map to Llama Stack attributes
|
||||
now = datetime.now(UTC)
|
||||
claims = {
|
||||
"sub": user["login"],
|
||||
"aud": self.config.jwt_audience,
|
||||
"iss": self.config.jwt_issuer,
|
||||
"exp": now + timedelta(seconds=self.config.token_expiry),
|
||||
"iat": now,
|
||||
"nbf": now,
|
||||
# Custom claims that will be mapped by claims_mapping
|
||||
"github_username": user["login"],
|
||||
"github_user_id": str(user["id"]),
|
||||
"github_orgs": orgs,
|
||||
"github_teams": teams,
|
||||
"email": user.get("email"),
|
||||
"name": user.get("name"),
|
||||
"avatar_url": user.get("avatar_url"),
|
||||
}
|
||||
|
||||
return jwt.encode(claims, self.config.jwt_secret, algorithm=self.config.jwt_algorithm)
|
||||
|
||||
|
||||
def _build_callback_url(request: Request) -> str:
|
||||
"""Build the GitHub OAuth callback URL from the current request."""
|
||||
callback_path = str(request.url_for("github_callback"))
|
||||
return callback_path
|
|
@ -31,7 +31,11 @@ from openai import BadRequestError
|
|||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationRequiredError,
|
||||
LoggingConfig,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
|
@ -61,6 +65,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
)
|
||||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .auth_providers import create_auth_provider
|
||||
from .quota import QuotaMiddleware
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
@ -448,9 +453,12 @@ def main(args: argparse.Namespace | None = None):
|
|||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
# Add authentication middleware if configured
|
||||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.type.value}")
|
||||
|
||||
auth_provider = create_auth_provider(config.server.auth)
|
||||
auth_provider.setup_routes(app)
|
||||
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
||||
else:
|
||||
if config.server.quota:
|
||||
|
|
|
@ -11,10 +11,13 @@ import pytest
|
|||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthProviderType,
|
||||
CustomAuthConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
)
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
from llama_stack.distribution.server.auth_providers import (
|
||||
AuthProviderType,
|
||||
get_attributes_from_claims,
|
||||
)
|
||||
|
||||
|
@ -60,8 +63,8 @@ def invalid_token():
|
|||
@pytest.fixture
|
||||
def http_app(mock_auth_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
auth_config = CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
@ -76,9 +79,9 @@ def http_app(mock_auth_endpoint):
|
|||
@pytest.fixture
|
||||
def k8s_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
auth_config = CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
config={"provider_type": "kubernetes", "api_server_url": "https://kubernetes.default.svc"},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -116,8 +119,8 @@ def mock_scope():
|
|||
@pytest.fixture
|
||||
def mock_http_middleware(mock_auth_endpoint):
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
auth_config = CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
@ -126,9 +129,9 @@ def mock_http_middleware(mock_auth_endpoint):
|
|||
@pytest.fixture
|
||||
def mock_k8s_middleware():
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
auth_config = CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
config={"provider_type": "kubernetes", "api_server_url": "https://kubernetes.default.svc"},
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
||||
|
@ -161,7 +164,8 @@ async def mock_post_exception(*args, **kwargs):
|
|||
def test_missing_auth_header(http_client):
|
||||
response = http_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "validated by mock-auth-service" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format(http_client):
|
||||
|
@ -261,8 +265,8 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
|||
@pytest.fixture
|
||||
def oauth2_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
auth_config = OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
|
@ -288,7 +292,8 @@ def oauth2_client(oauth2_app):
|
|||
def test_missing_auth_header_oauth2(oauth2_client):
|
||||
response = oauth2_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_oauth2(oauth2_client):
|
||||
|
@ -357,8 +362,8 @@ async def mock_auth_jwks_response(*args, **kwargs):
|
|||
@pytest.fixture
|
||||
def oauth2_app_with_jwks_token():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
auth_config = OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
|
@ -448,8 +453,8 @@ def mock_introspection_endpoint():
|
|||
@pytest.fixture
|
||||
def introspection_app(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
auth_config = OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
|
||||
|
@ -467,8 +472,8 @@ def introspection_app(mock_introspection_endpoint):
|
|||
@pytest.fixture
|
||||
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
auth_config = OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {
|
||||
|
@ -507,7 +512,8 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi
|
|||
def test_missing_auth_header_introspection(introspection_client):
|
||||
response = introspection_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_introspection(introspection_client):
|
||||
|
|
468
tests/unit/server/test_auth_github.py
Normal file
468
tests/unit/server/test_auth_github.py
Normal file
|
@ -0,0 +1,468 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from jose import jwt
|
||||
|
||||
from llama_stack.distribution.datatypes import GitHubAuthConfig
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
from llama_stack.distribution.server.auth_providers import get_attributes_from_claims
|
||||
from llama_stack.distribution.server.auth_routes import create_github_auth_router
|
||||
from llama_stack.distribution.server.github_oauth_auth_provider import GitHubAuthProvider
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status_code, json_data):
|
||||
self.status_code = status_code
|
||||
self._json_data = json_data
|
||||
|
||||
def json(self):
|
||||
return self._json_data
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code != 200:
|
||||
raise Exception(f"HTTP error: {self.status_code}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_oauth_app():
|
||||
app = FastAPI()
|
||||
auth_config = {
|
||||
"type": "github_oauth",
|
||||
"github_client_id": "test_client_id",
|
||||
"github_client_secret": "test_client_secret",
|
||||
"jwt_secret": "test_jwt_secret",
|
||||
"jwt_algorithm": "HS256",
|
||||
"jwt_audience": "llama-stack",
|
||||
"jwt_issuer": "llama-stack-github",
|
||||
"token_expiry": 86400,
|
||||
}
|
||||
|
||||
github_config = GitHubAuthConfig(**auth_config)
|
||||
|
||||
# Add auth routes BEFORE middleware so they're not protected
|
||||
auth_router = create_github_auth_router(github_config)
|
||||
app.include_router(auth_router)
|
||||
|
||||
# Then add auth middleware for other routes
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=github_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_oauth_client(github_oauth_app):
|
||||
return TestClient(github_oauth_app)
|
||||
|
||||
|
||||
def test_github_login_redirect(github_oauth_client):
|
||||
"""Test that GitHub login endpoint returns redirect to GitHub"""
|
||||
response = github_oauth_client.get("/auth/github/login", follow_redirects=False)
|
||||
assert response.status_code == 307 # Temporary redirect
|
||||
assert "github.com/login/oauth/authorize" in response.headers["location"]
|
||||
assert "client_id=test_client_id" in response.headers["location"]
|
||||
assert "state=" in response.headers["location"]
|
||||
|
||||
|
||||
async def mock_github_token_exchange_success(*args, **kwargs):
|
||||
"""Mock successful GitHub token exchange"""
|
||||
return MockResponse(
|
||||
200,
|
||||
{
|
||||
"access_token": "github_access_token_123",
|
||||
"token_type": "bearer",
|
||||
"scope": "read:user,read:org",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class MockAsyncClient:
|
||||
"""Mock httpx.AsyncClient for testing"""
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
async def post(self, url, **kwargs):
|
||||
if "login/oauth/access_token" in url:
|
||||
return await mock_github_token_exchange_success()
|
||||
return MockResponse(404, {})
|
||||
|
||||
async def get(self, url, **kwargs):
|
||||
if url.endswith("/user"):
|
||||
return MockResponse(
|
||||
200,
|
||||
{
|
||||
"login": "test-user",
|
||||
"id": 12345,
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
"avatar_url": "https://avatars.githubusercontent.com/u/12345",
|
||||
},
|
||||
)
|
||||
elif url.endswith("/user/orgs"):
|
||||
return MockResponse(
|
||||
200,
|
||||
[
|
||||
{"login": "test-org-1"},
|
||||
{"login": "test-org-2"},
|
||||
],
|
||||
)
|
||||
return MockResponse(404, {})
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient", MockAsyncClient)
|
||||
def test_github_callback_success(github_oauth_app):
|
||||
"""Test successful GitHub OAuth callback"""
|
||||
# Get a fresh client for this test
|
||||
client = TestClient(github_oauth_app)
|
||||
|
||||
# First, make a login request to generate a state
|
||||
login_response = client.get("/auth/github/login", follow_redirects=False)
|
||||
assert login_response.status_code == 307
|
||||
|
||||
# Extract the state from the redirect URL
|
||||
location = login_response.headers["location"]
|
||||
import re
|
||||
|
||||
state_match = re.search(r"state=([^&]+)", location)
|
||||
assert state_match
|
||||
state = state_match.group(1)
|
||||
|
||||
# Now use that state in the callback
|
||||
response = client.get(f"/auth/github/callback?code=test_code&state={state}")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
# Verify the JWT token contains the expected claims
|
||||
token = data["access_token"]
|
||||
# Decode without verification since this is a test
|
||||
claims = jwt.decode(token, "test_jwt_secret", algorithms=["HS256"], audience="llama-stack")
|
||||
assert claims["github_username"] == "test-user"
|
||||
assert claims["email"] == "test@example.com"
|
||||
|
||||
|
||||
def test_github_callback_invalid_state(github_oauth_client):
|
||||
"""Test GitHub callback with invalid state"""
|
||||
response = github_oauth_client.get("/auth/github/callback?code=test_code&state=invalid_state")
|
||||
assert response.status_code == 400
|
||||
assert "Invalid state parameter" in response.json()["detail"]
|
||||
|
||||
|
||||
async def mock_github_token_exchange_error(*args, **kwargs):
|
||||
"""Mock GitHub token exchange error"""
|
||||
return MockResponse(
|
||||
200, {"error": "bad_verification_code", "error_description": "The code passed is incorrect or expired."}
|
||||
)
|
||||
|
||||
|
||||
class MockAsyncClientError:
|
||||
"""Mock httpx.AsyncClient that returns error for token exchange"""
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
async def post(self, url, **kwargs):
|
||||
if "login/oauth/access_token" in url:
|
||||
return await mock_github_token_exchange_error()
|
||||
return MockResponse(404, {})
|
||||
|
||||
async def get(self, url, **kwargs):
|
||||
return MockResponse(404, {})
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient", MockAsyncClientError)
|
||||
def test_github_callback_token_exchange_error(github_oauth_app):
|
||||
"""Test GitHub callback with token exchange error"""
|
||||
client = TestClient(github_oauth_app)
|
||||
|
||||
# First, make a login request to generate a state
|
||||
login_response = client.get("/auth/github/login", follow_redirects=False)
|
||||
assert login_response.status_code == 307
|
||||
|
||||
# Extract the state from the redirect URL
|
||||
location = login_response.headers["location"]
|
||||
import re
|
||||
|
||||
state_match = re.search(r"state=([^&]+)", location)
|
||||
assert state_match
|
||||
state = state_match.group(1)
|
||||
|
||||
# Now use that state in the callback with bad code
|
||||
response = client.get(f"/auth/github/callback?code=bad_code&state={state}")
|
||||
assert response.status_code == 400
|
||||
assert "The code passed is incorrect or expired" in response.json()["detail"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_jwt_token():
|
||||
"""Create a valid GitHub JWT token for testing"""
|
||||
claims = {
|
||||
"sub": "test-user",
|
||||
"aud": "llama-stack",
|
||||
"iss": "llama-stack-github",
|
||||
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||
"iat": datetime.now(UTC),
|
||||
"nbf": datetime.now(UTC),
|
||||
"github_username": "test-user",
|
||||
"github_user_id": "12345",
|
||||
"github_orgs": ["test-org-1", "test-org-2"],
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
return jwt.encode(claims, "test_jwt_secret", algorithm="HS256")
|
||||
|
||||
|
||||
def test_github_jwt_authentication_success(github_oauth_client, github_jwt_token):
|
||||
"""Test API authentication with GitHub-issued JWT"""
|
||||
response = github_oauth_client.get("/test", headers={"Authorization": f"Bearer {github_jwt_token}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
def test_github_jwt_authentication_invalid_token(github_oauth_client):
|
||||
"""Test API authentication with invalid JWT"""
|
||||
response = github_oauth_client.get("/test", headers={"Authorization": "Bearer invalid.jwt.token"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid GitHub JWT token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_github_jwt_authentication_expired_token(github_oauth_client):
|
||||
"""Test API authentication with expired JWT"""
|
||||
# Create an expired token
|
||||
claims = {
|
||||
"sub": "test-user",
|
||||
"aud": "llama-stack",
|
||||
"iss": "llama-stack-github",
|
||||
"exp": datetime.now(UTC) - timedelta(hours=1), # Expired
|
||||
"iat": datetime.now(UTC) - timedelta(hours=2),
|
||||
"nbf": datetime.now(UTC) - timedelta(hours=2),
|
||||
"github_username": "test-user",
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(claims, "test_jwt_secret", algorithm="HS256")
|
||||
|
||||
response = github_oauth_client.get("/test", headers={"Authorization": f"Bearer {expired_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid GitHub JWT token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_github_jwt_authentication_wrong_audience(github_oauth_client):
|
||||
"""Test API authentication with JWT having wrong audience"""
|
||||
claims = {
|
||||
"sub": "test-user",
|
||||
"aud": "wrong-audience", # Wrong audience
|
||||
"iss": "llama-stack-github",
|
||||
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||
"iat": datetime.now(UTC),
|
||||
"nbf": datetime.now(UTC),
|
||||
"github_username": "test-user",
|
||||
}
|
||||
|
||||
wrong_audience_token = jwt.encode(claims, "test_jwt_secret", algorithm="HS256")
|
||||
|
||||
response = github_oauth_client.get("/test", headers={"Authorization": f"Bearer {wrong_audience_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid GitHub JWT token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_github_claims_mapping():
|
||||
"""Test GitHub claims are properly mapped to attributes"""
|
||||
config = GitHubAuthConfig(
|
||||
type="github_oauth",
|
||||
github_client_id="test",
|
||||
github_client_secret="test",
|
||||
jwt_secret="test",
|
||||
)
|
||||
|
||||
claims = {
|
||||
"sub": "test-user",
|
||||
"github_username": "test-user",
|
||||
"github_orgs": ["org1", "org2"],
|
||||
"github_teams": ["team1", "team2"],
|
||||
"github_user_id": "12345",
|
||||
}
|
||||
|
||||
# Default mapping only maps "sub" to "roles"
|
||||
attributes = get_attributes_from_claims(claims, config.claims_mapping)
|
||||
|
||||
assert "test-user" in attributes["roles"]
|
||||
# No other mappings by default
|
||||
assert len(attributes) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_auth_provider_validate_token():
|
||||
"""Test GitHubAuthProvider token validation"""
|
||||
config = GitHubAuthConfig(
|
||||
type="github_oauth",
|
||||
github_client_id="test",
|
||||
github_client_secret="test",
|
||||
jwt_secret="test_secret",
|
||||
jwt_algorithm="HS256",
|
||||
jwt_audience="test-audience",
|
||||
jwt_issuer="test-issuer",
|
||||
)
|
||||
|
||||
provider = GitHubAuthProvider(config)
|
||||
|
||||
# Create a valid token
|
||||
claims = {
|
||||
"sub": "test-user",
|
||||
"aud": "test-audience",
|
||||
"iss": "test-issuer",
|
||||
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||
"iat": datetime.now(UTC),
|
||||
"nbf": datetime.now(UTC),
|
||||
"github_username": "test-user",
|
||||
"github_orgs": ["org1"],
|
||||
"github_user_id": "12345",
|
||||
}
|
||||
|
||||
token = jwt.encode(claims, "test_secret", algorithm="HS256")
|
||||
|
||||
user = await provider.validate_token(token)
|
||||
assert user.principal == "test-user"
|
||||
# Default mapping only maps "sub" to "roles"
|
||||
assert "test-user" in user.attributes["roles"]
|
||||
# No other mappings by default
|
||||
assert len(user.attributes) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_auth_provider_custom_claims_mapping():
|
||||
"""Test GitHubAuthProvider with custom claims mapping"""
|
||||
config = GitHubAuthConfig(
|
||||
type="github_oauth",
|
||||
github_client_id="test",
|
||||
github_client_secret="test",
|
||||
jwt_secret="test_secret",
|
||||
jwt_algorithm="HS256",
|
||||
jwt_audience="test-audience",
|
||||
jwt_issuer="test-issuer",
|
||||
claims_mapping={
|
||||
"sub": "roles",
|
||||
"github_orgs": "teams",
|
||||
"github_teams": "teams",
|
||||
"github_user_id": "namespaces",
|
||||
},
|
||||
)
|
||||
|
||||
provider = GitHubAuthProvider(config)
|
||||
|
||||
# Create a valid token
|
||||
claims = {
|
||||
"sub": "test-user",
|
||||
"aud": "test-audience",
|
||||
"iss": "test-issuer",
|
||||
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||
"iat": datetime.now(UTC),
|
||||
"nbf": datetime.now(UTC),
|
||||
"github_username": "test-user",
|
||||
"github_orgs": ["org1", "org2"],
|
||||
"github_teams": ["team1", "team2"],
|
||||
"github_user_id": "12345",
|
||||
}
|
||||
|
||||
token = jwt.encode(claims, "test_secret", algorithm="HS256")
|
||||
|
||||
user = await provider.validate_token(token)
|
||||
assert user.principal == "test-user"
|
||||
assert "test-user" in user.attributes["roles"]
|
||||
assert set(user.attributes["teams"]) == {"org1", "org2", "team1", "team2"}
|
||||
assert "12345" in user.attributes["namespaces"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_auth_provider_invalid_token():
|
||||
"""Test GitHubAuthProvider with invalid token"""
|
||||
config = GitHubAuthConfig(
|
||||
type="github_oauth",
|
||||
github_client_id="test",
|
||||
github_client_secret="test",
|
||||
jwt_secret="test_secret",
|
||||
)
|
||||
|
||||
provider = GitHubAuthProvider(config)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid GitHub JWT token"):
|
||||
await provider.validate_token("invalid.token.here")
|
||||
|
||||
|
||||
def test_github_auth_provider_authorization_url():
|
||||
"""Test GitHubAuthProvider generates correct authorization URL"""
|
||||
from unittest.mock import Mock
|
||||
|
||||
config = GitHubAuthConfig(
|
||||
type="github_oauth",
|
||||
github_client_id="test_client_id",
|
||||
github_client_secret="test_secret",
|
||||
jwt_secret="test",
|
||||
)
|
||||
|
||||
provider = GitHubAuthProvider(config)
|
||||
|
||||
# Mock request object
|
||||
mock_request = Mock()
|
||||
mock_request.url_for.return_value = "http://localhost:8321/auth/github/callback"
|
||||
|
||||
url = provider.get_authorization_url("test_state_123", mock_request)
|
||||
|
||||
assert "https://github.com/login/oauth/authorize" in url
|
||||
assert "client_id=test_client_id" in url
|
||||
assert "redirect_uri=http%3A%2F%2Flocalhost%3A8321%2Fauth%2Fgithub%2Fcallback" in url
|
||||
assert "state=test_state_123" in url
|
||||
assert "scope=read%3Auser+read%3Aorg" in url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_auth_provider_complete_flow():
|
||||
"""Test complete OAuth flow in GitHubAuthProvider"""
|
||||
from unittest.mock import Mock
|
||||
|
||||
config = GitHubAuthConfig(
|
||||
type="github_oauth",
|
||||
github_client_id="test_client_id",
|
||||
github_client_secret="test_secret",
|
||||
jwt_secret="test_jwt_secret",
|
||||
jwt_algorithm="HS256",
|
||||
jwt_audience="llama-stack",
|
||||
jwt_issuer="llama-stack-github",
|
||||
)
|
||||
|
||||
provider = GitHubAuthProvider(config)
|
||||
|
||||
# Mock request object
|
||||
mock_request = Mock()
|
||||
mock_request.url_for.return_value = "/auth/github/callback"
|
||||
mock_request.url.scheme = "http"
|
||||
mock_request.url.netloc = "localhost:8321"
|
||||
|
||||
with patch("httpx.AsyncClient", MockAsyncClient):
|
||||
# Now returns just the JWT token
|
||||
token = await provider.complete_oauth_flow("test_code", mock_request)
|
||||
|
||||
# Verify it's a JWT token by decoding it
|
||||
claims = jwt.decode(token, "test_jwt_secret", algorithms=["HS256"], audience="llama-stack")
|
||||
assert claims["github_username"] == "test-user"
|
||||
assert claims["github_orgs"] == ["test-org-1", "test-org-2"]
|
||||
assert claims["sub"] == "test-user"
|
Loading…
Add table
Add a link
Reference in a new issue