mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Merge 94282d3482
into 40fdce79b3
This commit is contained in:
commit
12265d4222
9 changed files with 962 additions and 43 deletions
2
.github/workflows/integration-auth-tests.yml
vendored
2
.github/workflows/integration-auth-tests.yml
vendored
|
@ -73,7 +73,7 @@ jobs:
|
||||||
server:
|
server:
|
||||||
port: 8321
|
port: 8321
|
||||||
EOF
|
EOF
|
||||||
yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
|
yq eval '.server.auth = {"type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
|
||||||
yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
|
yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
|
||||||
yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml
|
yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml
|
||||||
cat $run_dir/run.yaml
|
cat $run_dir/run.yaml
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
@ -166,20 +166,90 @@ class AuthProviderType(StrEnum):
|
||||||
|
|
||||||
OAUTH2_TOKEN = "oauth2_token"
|
OAUTH2_TOKEN = "oauth2_token"
|
||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
|
GITHUB_OAUTH = "github_oauth"
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationConfig(BaseModel):
|
class OAuth2TokenAuthConfig(BaseModel):
|
||||||
provider_type: AuthProviderType = Field(
|
"""Configuration for OAuth2 token authentication."""
|
||||||
...,
|
|
||||||
description="Type of authentication provider",
|
type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
|
||||||
)
|
|
||||||
config: dict[str, Any] = Field(
|
config: dict[str, Any] = Field(
|
||||||
...,
|
...,
|
||||||
description="Provider-specific configuration",
|
description="OAuth2 token validation configuration",
|
||||||
)
|
)
|
||||||
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
|
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
|
||||||
|
|
||||||
|
|
||||||
|
class CustomAuthConfig(BaseModel):
|
||||||
|
"""Configuration for custom authentication."""
|
||||||
|
|
||||||
|
type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
|
||||||
|
config: dict[str, Any] = Field(
|
||||||
|
...,
|
||||||
|
description="Custom authentication endpoint configuration",
|
||||||
|
)
|
||||||
|
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubAuthConfig(BaseModel):
|
||||||
|
"""Configuration for GitHub OAuth authentication.
|
||||||
|
|
||||||
|
This provider implements GitHub OAuth authentication with the following endpoints:
|
||||||
|
|
||||||
|
- /auth/github/login: Initiates OAuth flow
|
||||||
|
- Accepts optional redirect_url parameter (must be in allowed_redirect_urls)
|
||||||
|
- Redirects to GitHub for authorization
|
||||||
|
|
||||||
|
- /auth/github/callback: Handles OAuth callback (should be used as the callback URL for your GitHub OAuth App)
|
||||||
|
- Exchanges authorization code for access token
|
||||||
|
- Creates JWT with user info
|
||||||
|
- With redirect_url: Redirects to specified URL with JWT in fragment (#token=...)
|
||||||
|
- Without redirect_url: Returns JSON response with JWT
|
||||||
|
|
||||||
|
Ensure your GitHub OAuth App's callback URL matches your server's URL + /auth/github/callback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal[AuthProviderType.GITHUB_OAUTH] = AuthProviderType.GITHUB_OAUTH
|
||||||
|
|
||||||
|
# GitHub OAuth settings
|
||||||
|
github_client_id: str = Field(description="GitHub OAuth App Client ID")
|
||||||
|
github_client_secret: str = Field(description="GitHub OAuth App Client Secret")
|
||||||
|
allowed_redirect_urls: list[str] = Field(
|
||||||
|
default=[],
|
||||||
|
description="Allowed redirect URLs for OAuth callback, e.g. frontend URL",
|
||||||
|
)
|
||||||
|
|
||||||
|
# JWT configuration for token generation
|
||||||
|
jwt_secret: str = Field(description="Secret for signing JWT tokens")
|
||||||
|
jwt_algorithm: str = Field(default="HS256", description="JWT signing algorithm")
|
||||||
|
jwt_audience: str = Field(default="llama-stack", description="JWT audience")
|
||||||
|
jwt_issuer: str = Field(default="llama-stack-github", description="JWT issuer")
|
||||||
|
token_expiry: int = Field(default=86400, description="JWT token expiry in seconds")
|
||||||
|
|
||||||
|
# OAuth2 token validation config (for validating the JWTs we issue)
|
||||||
|
oauth2_config: dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Additional OAuth2 token validation configuration",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Access control
|
||||||
|
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
|
||||||
|
|
||||||
|
# Claims mapping for GitHub attributes
|
||||||
|
claims_mapping: dict[str, str] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"sub": "roles", # GitHub username as role
|
||||||
|
},
|
||||||
|
description="Mapping from JWT claims to Llama Stack attributes",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Discriminated union for authentication configurations
|
||||||
|
AuthenticationConfig = Annotated[
|
||||||
|
OAuth2TokenAuthConfig | CustomAuthConfig | GitHubAuthConfig, Field(discriminator="type")
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationRequiredError(Exception):
|
class AuthenticationRequiredError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -81,13 +81,23 @@ class AuthenticationMiddleware:
|
||||||
def __init__(self, app, auth_config: AuthenticationConfig):
|
def __init__(self, app, auth_config: AuthenticationConfig):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.auth_provider = create_auth_provider(auth_config)
|
self.auth_provider = create_auth_provider(auth_config)
|
||||||
|
self.public_paths = self.auth_provider.get_public_paths()
|
||||||
|
|
||||||
async def __call__(self, scope, receive, send):
|
async def __call__(self, scope, receive, send):
|
||||||
if scope["type"] == "http":
|
if scope["type"] == "http":
|
||||||
|
# Skip authentication for public paths defined by the provider
|
||||||
|
path = scope.get("path", "")
|
||||||
|
if any(path.startswith(public_path) for public_path in self.public_paths):
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
headers = dict(scope.get("headers", []))
|
headers = dict(scope.get("headers", []))
|
||||||
auth_header = headers.get(b"authorization", b"").decode()
|
auth_header = headers.get(b"authorization", b"").decode()
|
||||||
|
|
||||||
if not auth_header or not auth_header.startswith("Bearer "):
|
if not auth_header:
|
||||||
|
error_msg = self.auth_provider.get_auth_error_message(scope)
|
||||||
|
return await self._send_auth_error(send, error_msg)
|
||||||
|
|
||||||
|
if not auth_header.startswith("Bearer "):
|
||||||
return await self._send_auth_error(send, "Missing or invalid Authorization header")
|
return await self._send_auth_error(send, "Missing or invalid Authorization header")
|
||||||
|
|
||||||
token = auth_header.split("Bearer ", 1)[1]
|
token = auth_header.split("Bearer ", 1)[1]
|
||||||
|
|
|
@ -10,13 +10,19 @@ from abc import ABC, abstractmethod
|
||||||
from asyncio import Lock
|
from asyncio import Lock
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Self
|
from typing import Self
|
||||||
from urllib.parse import parse_qs
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
|
from llama_stack.distribution.datatypes import (
|
||||||
|
AuthenticationConfig,
|
||||||
|
CustomAuthConfig,
|
||||||
|
GitHubAuthConfig,
|
||||||
|
OAuth2TokenAuthConfig,
|
||||||
|
User,
|
||||||
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="auth")
|
logger = get_logger(name=__name__, category="auth")
|
||||||
|
@ -62,6 +68,24 @@ class AuthProvider(ABC):
|
||||||
"""Clean up any resources."""
|
"""Clean up any resources."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def setup_routes(self, app):
|
||||||
|
"""Setup any provider-specific routes (e.g., OAuth callbacks).
|
||||||
|
|
||||||
|
This is optional - providers that don't need special routes can skip this.
|
||||||
|
"""
|
||||||
|
return
|
||||||
|
|
||||||
|
def get_public_paths(self) -> list[str]:
|
||||||
|
"""Return a list of path prefixes that should bypass authentication.
|
||||||
|
|
||||||
|
This is optional - providers that don't have public paths return empty list.
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||||
|
"""Return provider-specific authentication error message."""
|
||||||
|
return "Authentication required"
|
||||||
|
|
||||||
|
|
||||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
|
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
|
||||||
attributes: dict[str, list[str]] = {}
|
attributes: dict[str, list[str]] = {}
|
||||||
|
@ -232,6 +256,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
async def close(self):
|
async def close(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||||
|
"""Return OAuth2-specific authentication error message."""
|
||||||
|
if self.config.issuer:
|
||||||
|
return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
|
||||||
|
elif self.config.introspection:
|
||||||
|
# Extract domain from introspection URL for a cleaner message
|
||||||
|
domain = urlparse(self.config.introspection.url).netloc
|
||||||
|
return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
|
||||||
|
else:
|
||||||
|
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
|
||||||
|
|
||||||
async def _refresh_jwks(self) -> None:
|
async def _refresh_jwks(self) -> None:
|
||||||
"""
|
"""
|
||||||
Refresh the JWKS cache.
|
Refresh the JWKS cache.
|
||||||
|
@ -338,15 +373,25 @@ class CustomAuthProvider(AuthProvider):
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
|
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||||
|
"""Return custom auth provider-specific authentication error message."""
|
||||||
|
# Extract domain from endpoint URL for a cleaner message
|
||||||
|
domain = urlparse(self.config.endpoint).netloc
|
||||||
|
if domain:
|
||||||
|
return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
|
||||||
|
else:
|
||||||
|
return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
|
||||||
|
|
||||||
|
|
||||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||||
"""Factory function to create the appropriate auth provider."""
|
"""Factory function to create the appropriate auth provider."""
|
||||||
provider_type = config.provider_type.lower()
|
if isinstance(config, CustomAuthConfig):
|
||||||
|
|
||||||
if provider_type == "custom":
|
|
||||||
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
||||||
elif provider_type == "oauth2_token":
|
elif isinstance(config, OAuth2TokenAuthConfig):
|
||||||
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
||||||
|
elif isinstance(config, GitHubAuthConfig):
|
||||||
|
from .github_oauth_auth_provider import GitHubAuthProvider
|
||||||
|
|
||||||
|
return GitHubAuthProvider(config)
|
||||||
else:
|
else:
|
||||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
raise ValueError(f"Unknown authentication config type: {type(config)}")
|
||||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
|
||||||
|
|
109
llama_stack/distribution/server/auth_routes.py
Normal file
109
llama_stack/distribution/server/auth_routes.py
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import JSONResponse, RedirectResponse
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import GitHubAuthConfig
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
from .github_oauth_auth_provider import (
|
||||||
|
GITHUB_CALLBACK_PATH,
|
||||||
|
GITHUB_LOGIN_PATH,
|
||||||
|
GitHubAuthProvider,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="auth_routes")
|
||||||
|
|
||||||
|
|
||||||
|
def create_github_auth_router(config: GitHubAuthConfig) -> APIRouter:
|
||||||
|
"""Create and configure GitHub authentication router."""
|
||||||
|
auth_provider = GitHubAuthProvider(config)
|
||||||
|
router = APIRouter()
|
||||||
|
oauth_states: dict[str, dict] = {}
|
||||||
|
|
||||||
|
def cleanup_expired_states() -> None:
|
||||||
|
"""Remove expired OAuth states."""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
expired_states = [
|
||||||
|
state
|
||||||
|
for state, data in oauth_states.items()
|
||||||
|
if (now - data["created_at"]).seconds > 300 # 5 minutes
|
||||||
|
]
|
||||||
|
for state in expired_states:
|
||||||
|
oauth_states.pop(state, None)
|
||||||
|
|
||||||
|
@router.get(GITHUB_LOGIN_PATH)
|
||||||
|
async def github_login(request: Request, redirect_url: str | None = None):
|
||||||
|
"""Initiate GitHub OAuth flow."""
|
||||||
|
cleanup_expired_states()
|
||||||
|
|
||||||
|
# Validate redirect URL if provided
|
||||||
|
if redirect_url:
|
||||||
|
if not auth_provider.config.allowed_redirect_urls:
|
||||||
|
raise HTTPException(status_code=400, detail="Redirect URLs not configured")
|
||||||
|
if redirect_url not in auth_provider.config.allowed_redirect_urls:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid redirect URL")
|
||||||
|
|
||||||
|
state = secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
oauth_states[state] = {
|
||||||
|
"created_at": datetime.now(UTC),
|
||||||
|
"redirect_url": redirect_url,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get authorization URL
|
||||||
|
auth_url = auth_provider.get_authorization_url(state, request)
|
||||||
|
|
||||||
|
logger.debug(f"Redirecting to GitHub OAuth: {auth_url}")
|
||||||
|
return RedirectResponse(url=auth_url)
|
||||||
|
|
||||||
|
@router.get(GITHUB_CALLBACK_PATH)
|
||||||
|
async def github_callback(code: str, state: str, request: Request):
|
||||||
|
"""Handle GitHub OAuth callback."""
|
||||||
|
# Validate state parameter
|
||||||
|
if state not in oauth_states:
|
||||||
|
logger.warning(f"Invalid OAuth state received: {state}")
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid state parameter")
|
||||||
|
|
||||||
|
state_data = oauth_states.pop(state)
|
||||||
|
|
||||||
|
# Check state expiry
|
||||||
|
if (datetime.now(UTC) - state_data["created_at"]).seconds > 300:
|
||||||
|
logger.warning("OAuth state expired")
|
||||||
|
raise HTTPException(status_code=400, detail="State expired")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Complete OAuth flow
|
||||||
|
access_token = await auth_provider.complete_oauth_flow(code, request)
|
||||||
|
|
||||||
|
logger.info("GitHub OAuth successful")
|
||||||
|
|
||||||
|
redirect_url = state_data.get("redirect_url")
|
||||||
|
if redirect_url:
|
||||||
|
# Validate redirect URL from state
|
||||||
|
if redirect_url not in auth_provider.config.allowed_redirect_urls:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid redirect URL in state")
|
||||||
|
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
|
params = urllib.parse.urlencode({"token": access_token})
|
||||||
|
return RedirectResponse(url=f"{redirect_url}#{params}")
|
||||||
|
else:
|
||||||
|
# For API calls, only return the access token
|
||||||
|
return JSONResponse(content={"access_token": access_token})
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"GitHub OAuth error: {e}")
|
||||||
|
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Unexpected error during GitHub OAuth")
|
||||||
|
raise HTTPException(status_code=500, detail="Authentication failed") from e
|
||||||
|
|
||||||
|
return router
|
203
llama_stack/distribution/server/github_oauth_auth_provider.py
Normal file
203
llama_stack/distribution/server/github_oauth_auth_provider.py
Normal file
|
@ -0,0 +1,203 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import Request
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import GitHubAuthConfig, User
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
from .auth_providers import AuthProvider, get_attributes_from_claims
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="github_auth")
|
||||||
|
|
||||||
|
# GitHub API constants
|
||||||
|
GITHUB_API_BASE_URL = "https://api.github.com"
|
||||||
|
GITHUB_OAUTH_BASE_URL = "https://github.com"
|
||||||
|
|
||||||
|
# GitHub OAuth route paths
|
||||||
|
GITHUB_LOGIN_PATH = "/auth/github/login"
|
||||||
|
GITHUB_CALLBACK_PATH = "/auth/github/callback"
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubAuthProvider(AuthProvider):
|
||||||
|
"""Authentication provider for GitHub OAuth flow and JWT validation."""
|
||||||
|
|
||||||
|
def __init__(self, config: GitHubAuthConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
|
"""Validate a GitHub-issued JWT token."""
|
||||||
|
try:
|
||||||
|
claims = self.verify_jwt(token)
|
||||||
|
|
||||||
|
principal = claims["sub"]
|
||||||
|
attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
||||||
|
|
||||||
|
return User(principal=principal, attributes=attributes)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error validating GitHub JWT")
|
||||||
|
raise ValueError(f"Invalid GitHub JWT token: {str(e)}") from e
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Clean up any resources."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def setup_routes(self, app):
|
||||||
|
"""Setup GitHub OAuth routes."""
|
||||||
|
from .auth_routes import create_github_auth_router
|
||||||
|
|
||||||
|
github_router = create_github_auth_router(self.config)
|
||||||
|
app.include_router(github_router)
|
||||||
|
|
||||||
|
def get_public_paths(self) -> list[str]:
|
||||||
|
"""GitHub OAuth paths that don't require authentication."""
|
||||||
|
return ["/auth/github/"]
|
||||||
|
|
||||||
|
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||||
|
"""Return GitHub-specific authentication error message."""
|
||||||
|
if scope:
|
||||||
|
headers = dict(scope.get("headers", []))
|
||||||
|
host = headers.get(b"host", b"").decode()
|
||||||
|
scheme = scope.get("scheme", "http")
|
||||||
|
|
||||||
|
if host:
|
||||||
|
auth_url = f"{scheme}://{host}{GITHUB_LOGIN_PATH}"
|
||||||
|
return f"Authentication required. Please authenticate via GitHub at {auth_url}"
|
||||||
|
|
||||||
|
return f"Authentication required. Please authenticate by visiting {GITHUB_LOGIN_PATH} to start the authentication flow."
|
||||||
|
|
||||||
|
# OAuth flow methods
|
||||||
|
def get_authorization_url(self, state: str, request: Request) -> str:
|
||||||
|
"""Generate GitHub OAuth authorization URL."""
|
||||||
|
params = {
|
||||||
|
"client_id": self.config.github_client_id,
|
||||||
|
"redirect_uri": _build_callback_url(request),
|
||||||
|
"scope": "read:user read:org",
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
return f"{GITHUB_OAUTH_BASE_URL}/login/oauth/authorize?" + urlencode(params)
|
||||||
|
|
||||||
|
async def complete_oauth_flow(self, code: str, request: Request) -> Any:
|
||||||
|
"""Complete the GitHub OAuth flow and return JWT access token."""
|
||||||
|
# Exchange code for token
|
||||||
|
logger.debug("Exchanging code for GitHub access token")
|
||||||
|
token_data = await self._exchange_code_for_token(code, request)
|
||||||
|
|
||||||
|
if "error" in token_data:
|
||||||
|
raise ValueError(f"GitHub OAuth error: {token_data.get('error_description', token_data['error'])}")
|
||||||
|
|
||||||
|
access_token = token_data["access_token"]
|
||||||
|
|
||||||
|
# Get user info
|
||||||
|
logger.debug("Fetching GitHub user info")
|
||||||
|
github_info = await self._get_user_info(access_token)
|
||||||
|
|
||||||
|
# Create JWT
|
||||||
|
logger.debug(f"Creating JWT for user: {github_info['user']['login']}")
|
||||||
|
jwt_token = self._create_jwt_token(github_info)
|
||||||
|
|
||||||
|
return jwt_token
|
||||||
|
|
||||||
|
def verify_jwt(self, token: str) -> Any:
|
||||||
|
"""Verify and decode a GitHub-issued JWT token."""
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
self.config.jwt_secret,
|
||||||
|
algorithms=[self.config.jwt_algorithm],
|
||||||
|
audience=self.config.jwt_audience,
|
||||||
|
issuer=self.config.jwt_issuer,
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
except jwt.JWTError as e:
|
||||||
|
raise ValueError(f"Invalid JWT token: {e}") from e
|
||||||
|
|
||||||
|
# Private helper methods
|
||||||
|
async def _exchange_code_for_token(self, code: str, request: Request) -> Any:
|
||||||
|
"""Exchange authorization code for GitHub access token."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{GITHUB_OAUTH_BASE_URL}/login/oauth/access_token",
|
||||||
|
json={
|
||||||
|
"client_id": self.config.github_client_id,
|
||||||
|
"client_secret": self.config.github_client_secret,
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": _build_callback_url(request),
|
||||||
|
},
|
||||||
|
headers={"Accept": "application/json"},
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def _get_user_info(self, access_token: str) -> dict:
|
||||||
|
"""Fetch user info and organizations from GitHub."""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Accept": "application/vnd.github.v3+json",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
# Fetch user and orgs in parallel
|
||||||
|
user_task = client.get(f"{GITHUB_API_BASE_URL}/user", headers=headers, timeout=10.0)
|
||||||
|
orgs_task = client.get(f"{GITHUB_API_BASE_URL}/user/orgs", headers=headers, timeout=10.0)
|
||||||
|
|
||||||
|
user_response, orgs_response = await asyncio.gather(user_task, orgs_task)
|
||||||
|
|
||||||
|
user_response.raise_for_status()
|
||||||
|
orgs_response.raise_for_status()
|
||||||
|
|
||||||
|
user_data = user_response.json()
|
||||||
|
orgs_data = orgs_response.json()
|
||||||
|
|
||||||
|
# Extract organization names
|
||||||
|
organizations = [org["login"] for org in orgs_data]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"user": user_data,
|
||||||
|
"organizations": organizations,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_jwt_token(self, github_info: dict) -> Any:
|
||||||
|
"""Create JWT token with GitHub user information."""
|
||||||
|
user = github_info["user"]
|
||||||
|
orgs = github_info["organizations"]
|
||||||
|
teams = github_info.get("teams", [])
|
||||||
|
|
||||||
|
# Create JWT claims that map to Llama Stack attributes
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
claims = {
|
||||||
|
"sub": user["login"],
|
||||||
|
"aud": self.config.jwt_audience,
|
||||||
|
"iss": self.config.jwt_issuer,
|
||||||
|
"exp": now + timedelta(seconds=self.config.token_expiry),
|
||||||
|
"iat": now,
|
||||||
|
"nbf": now,
|
||||||
|
# Custom claims that will be mapped by claims_mapping
|
||||||
|
"github_username": user["login"],
|
||||||
|
"github_user_id": str(user["id"]),
|
||||||
|
"github_orgs": orgs,
|
||||||
|
"github_teams": teams,
|
||||||
|
"email": user.get("email"),
|
||||||
|
"name": user.get("name"),
|
||||||
|
"avatar_url": user.get("avatar_url"),
|
||||||
|
}
|
||||||
|
|
||||||
|
return jwt.encode(claims, self.config.jwt_secret, algorithm=self.config.jwt_algorithm)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_callback_url(request: Request) -> str:
|
||||||
|
"""Build the GitHub OAuth callback URL from the current request."""
|
||||||
|
callback_path = str(request.url_for("github_callback"))
|
||||||
|
return callback_path
|
|
@ -31,7 +31,11 @@ from openai import BadRequestError
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
from llama_stack.distribution.datatypes import (
|
||||||
|
AuthenticationRequiredError,
|
||||||
|
LoggingConfig,
|
||||||
|
StackRunConfig,
|
||||||
|
)
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
||||||
from llama_stack.distribution.resolver import InvalidProviderError
|
from llama_stack.distribution.resolver import InvalidProviderError
|
||||||
|
@ -61,6 +65,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from .auth import AuthenticationMiddleware
|
from .auth import AuthenticationMiddleware
|
||||||
|
from .auth_providers import create_auth_provider
|
||||||
from .quota import QuotaMiddleware
|
from .quota import QuotaMiddleware
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
@ -448,9 +453,12 @@ def main(args: argparse.Namespace | None = None):
|
||||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||||
app.add_middleware(ClientVersionMiddleware)
|
app.add_middleware(ClientVersionMiddleware)
|
||||||
|
|
||||||
# Add authentication middleware if configured
|
|
||||||
if config.server.auth:
|
if config.server.auth:
|
||||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
|
logger.info(f"Enabling authentication with provider: {config.server.auth.type.value}")
|
||||||
|
|
||||||
|
auth_provider = create_auth_provider(config.server.auth)
|
||||||
|
auth_provider.setup_routes(app)
|
||||||
|
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
||||||
else:
|
else:
|
||||||
if config.server.quota:
|
if config.server.quota:
|
||||||
|
|
|
@ -11,10 +11,13 @@ import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
from llama_stack.distribution.datatypes import (
|
||||||
|
AuthProviderType,
|
||||||
|
CustomAuthConfig,
|
||||||
|
OAuth2TokenAuthConfig,
|
||||||
|
)
|
||||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||||
from llama_stack.distribution.server.auth_providers import (
|
from llama_stack.distribution.server.auth_providers import (
|
||||||
AuthProviderType,
|
|
||||||
get_attributes_from_claims,
|
get_attributes_from_claims,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -60,8 +63,8 @@ def invalid_token():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def http_app(mock_auth_endpoint):
|
def http_app(mock_auth_endpoint):
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = CustomAuthConfig(
|
||||||
provider_type=AuthProviderType.CUSTOM,
|
type=AuthProviderType.CUSTOM,
|
||||||
config={"endpoint": mock_auth_endpoint},
|
config={"endpoint": mock_auth_endpoint},
|
||||||
)
|
)
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||||
|
@ -76,9 +79,9 @@ def http_app(mock_auth_endpoint):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def k8s_app():
|
def k8s_app():
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = CustomAuthConfig(
|
||||||
provider_type=AuthProviderType.KUBERNETES,
|
type=AuthProviderType.CUSTOM,
|
||||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
config={"provider_type": "kubernetes", "api_server_url": "https://kubernetes.default.svc"},
|
||||||
)
|
)
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||||
|
|
||||||
|
@ -116,8 +119,8 @@ def mock_scope():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_http_middleware(mock_auth_endpoint):
|
def mock_http_middleware(mock_auth_endpoint):
|
||||||
mock_app = AsyncMock()
|
mock_app = AsyncMock()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = CustomAuthConfig(
|
||||||
provider_type=AuthProviderType.CUSTOM,
|
type=AuthProviderType.CUSTOM,
|
||||||
config={"endpoint": mock_auth_endpoint},
|
config={"endpoint": mock_auth_endpoint},
|
||||||
)
|
)
|
||||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||||
|
@ -126,9 +129,9 @@ def mock_http_middleware(mock_auth_endpoint):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_k8s_middleware():
|
def mock_k8s_middleware():
|
||||||
mock_app = AsyncMock()
|
mock_app = AsyncMock()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = CustomAuthConfig(
|
||||||
provider_type=AuthProviderType.KUBERNETES,
|
type=AuthProviderType.CUSTOM,
|
||||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
config={"provider_type": "kubernetes", "api_server_url": "https://kubernetes.default.svc"},
|
||||||
)
|
)
|
||||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||||
|
|
||||||
|
@ -161,7 +164,8 @@ async def mock_post_exception(*args, **kwargs):
|
||||||
def test_missing_auth_header(http_client):
|
def test_missing_auth_header(http_client):
|
||||||
response = http_client.get("/test")
|
response = http_client.get("/test")
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
assert "Authentication required" in response.json()["error"]["message"]
|
||||||
|
assert "validated by mock-auth-service" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_auth_header_format(http_client):
|
def test_invalid_auth_header_format(http_client):
|
||||||
|
@ -261,8 +265,8 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def oauth2_app():
|
def oauth2_app():
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = OAuth2TokenAuthConfig(
|
||||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
type=AuthProviderType.OAUTH2_TOKEN,
|
||||||
config={
|
config={
|
||||||
"jwks": {
|
"jwks": {
|
||||||
"uri": "http://mock-authz-service/token/introspect",
|
"uri": "http://mock-authz-service/token/introspect",
|
||||||
|
@ -288,7 +292,8 @@ def oauth2_client(oauth2_app):
|
||||||
def test_missing_auth_header_oauth2(oauth2_client):
|
def test_missing_auth_header_oauth2(oauth2_client):
|
||||||
response = oauth2_client.get("/test")
|
response = oauth2_client.get("/test")
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
assert "Authentication required" in response.json()["error"]["message"]
|
||||||
|
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_auth_header_format_oauth2(oauth2_client):
|
def test_invalid_auth_header_format_oauth2(oauth2_client):
|
||||||
|
@ -357,8 +362,8 @@ async def mock_auth_jwks_response(*args, **kwargs):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def oauth2_app_with_jwks_token():
|
def oauth2_app_with_jwks_token():
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = OAuth2TokenAuthConfig(
|
||||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
type=AuthProviderType.OAUTH2_TOKEN,
|
||||||
config={
|
config={
|
||||||
"jwks": {
|
"jwks": {
|
||||||
"uri": "http://mock-authz-service/token/introspect",
|
"uri": "http://mock-authz-service/token/introspect",
|
||||||
|
@ -448,8 +453,8 @@ def mock_introspection_endpoint():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def introspection_app(mock_introspection_endpoint):
|
def introspection_app(mock_introspection_endpoint):
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = OAuth2TokenAuthConfig(
|
||||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
type=AuthProviderType.OAUTH2_TOKEN,
|
||||||
config={
|
config={
|
||||||
"jwks": None,
|
"jwks": None,
|
||||||
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
|
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
|
||||||
|
@ -467,8 +472,8 @@ def introspection_app(mock_introspection_endpoint):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = OAuth2TokenAuthConfig(
|
||||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
type=AuthProviderType.OAUTH2_TOKEN,
|
||||||
config={
|
config={
|
||||||
"jwks": None,
|
"jwks": None,
|
||||||
"introspection": {
|
"introspection": {
|
||||||
|
@ -507,7 +512,8 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi
|
||||||
def test_missing_auth_header_introspection(introspection_client):
|
def test_missing_auth_header_introspection(introspection_client):
|
||||||
response = introspection_client.get("/test")
|
response = introspection_client.get("/test")
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
assert "Authentication required" in response.json()["error"]["message"]
|
||||||
|
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_auth_header_format_introspection(introspection_client):
|
def test_invalid_auth_header_format_introspection(introspection_client):
|
||||||
|
|
468
tests/unit/server/test_auth_github.py
Normal file
468
tests/unit/server/test_auth_github.py
Normal file
|
@ -0,0 +1,468 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import GitHubAuthConfig
|
||||||
|
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||||
|
from llama_stack.distribution.server.auth_providers import get_attributes_from_claims
|
||||||
|
from llama_stack.distribution.server.auth_routes import create_github_auth_router
|
||||||
|
from llama_stack.distribution.server.github_oauth_auth_provider import GitHubAuthProvider
|
||||||
|
|
||||||
|
|
||||||
|
class MockResponse:
|
||||||
|
def __init__(self, status_code, json_data):
|
||||||
|
self.status_code = status_code
|
||||||
|
self._json_data = json_data
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return self._json_data
|
||||||
|
|
||||||
|
def raise_for_status(self):
|
||||||
|
if self.status_code != 200:
|
||||||
|
raise Exception(f"HTTP error: {self.status_code}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def github_oauth_app():
|
||||||
|
app = FastAPI()
|
||||||
|
auth_config = {
|
||||||
|
"type": "github_oauth",
|
||||||
|
"github_client_id": "test_client_id",
|
||||||
|
"github_client_secret": "test_client_secret",
|
||||||
|
"jwt_secret": "test_jwt_secret",
|
||||||
|
"jwt_algorithm": "HS256",
|
||||||
|
"jwt_audience": "llama-stack",
|
||||||
|
"jwt_issuer": "llama-stack-github",
|
||||||
|
"token_expiry": 86400,
|
||||||
|
}
|
||||||
|
|
||||||
|
github_config = GitHubAuthConfig(**auth_config)
|
||||||
|
|
||||||
|
# Add auth routes BEFORE middleware so they're not protected
|
||||||
|
auth_router = create_github_auth_router(github_config)
|
||||||
|
app.include_router(auth_router)
|
||||||
|
|
||||||
|
# Then add auth middleware for other routes
|
||||||
|
app.add_middleware(AuthenticationMiddleware, auth_config=github_config)
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
def test_endpoint():
|
||||||
|
return {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def github_oauth_client(github_oauth_app):
|
||||||
|
return TestClient(github_oauth_app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_github_login_redirect(github_oauth_client):
|
||||||
|
"""Test that GitHub login endpoint returns redirect to GitHub"""
|
||||||
|
response = github_oauth_client.get("/auth/github/login", follow_redirects=False)
|
||||||
|
assert response.status_code == 307 # Temporary redirect
|
||||||
|
assert "github.com/login/oauth/authorize" in response.headers["location"]
|
||||||
|
assert "client_id=test_client_id" in response.headers["location"]
|
||||||
|
assert "state=" in response.headers["location"]
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_github_token_exchange_success(*args, **kwargs):
|
||||||
|
"""Mock successful GitHub token exchange"""
|
||||||
|
return MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"access_token": "github_access_token_123",
|
||||||
|
"token_type": "bearer",
|
||||||
|
"scope": "read:user,read:org",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncClient:
|
||||||
|
"""Mock httpx.AsyncClient for testing"""
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def post(self, url, **kwargs):
|
||||||
|
if "login/oauth/access_token" in url:
|
||||||
|
return await mock_github_token_exchange_success()
|
||||||
|
return MockResponse(404, {})
|
||||||
|
|
||||||
|
async def get(self, url, **kwargs):
|
||||||
|
if url.endswith("/user"):
|
||||||
|
return MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"login": "test-user",
|
||||||
|
"id": 12345,
|
||||||
|
"email": "test@example.com",
|
||||||
|
"name": "Test User",
|
||||||
|
"avatar_url": "https://avatars.githubusercontent.com/u/12345",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
elif url.endswith("/user/orgs"):
|
||||||
|
return MockResponse(
|
||||||
|
200,
|
||||||
|
[
|
||||||
|
{"login": "test-org-1"},
|
||||||
|
{"login": "test-org-2"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return MockResponse(404, {})
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient", MockAsyncClient)
|
||||||
|
def test_github_callback_success(github_oauth_app):
|
||||||
|
"""Test successful GitHub OAuth callback"""
|
||||||
|
# Get a fresh client for this test
|
||||||
|
client = TestClient(github_oauth_app)
|
||||||
|
|
||||||
|
# First, make a login request to generate a state
|
||||||
|
login_response = client.get("/auth/github/login", follow_redirects=False)
|
||||||
|
assert login_response.status_code == 307
|
||||||
|
|
||||||
|
# Extract the state from the redirect URL
|
||||||
|
location = login_response.headers["location"]
|
||||||
|
import re
|
||||||
|
|
||||||
|
state_match = re.search(r"state=([^&]+)", location)
|
||||||
|
assert state_match
|
||||||
|
state = state_match.group(1)
|
||||||
|
|
||||||
|
# Now use that state in the callback
|
||||||
|
response = client.get(f"/auth/github/callback?code=test_code&state={state}")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
# Verify the JWT token contains the expected claims
|
||||||
|
token = data["access_token"]
|
||||||
|
# Decode without verification since this is a test
|
||||||
|
claims = jwt.decode(token, "test_jwt_secret", algorithms=["HS256"], audience="llama-stack")
|
||||||
|
assert claims["github_username"] == "test-user"
|
||||||
|
assert claims["email"] == "test@example.com"
|
||||||
|
|
||||||
|
|
||||||
|
def test_github_callback_invalid_state(github_oauth_client):
|
||||||
|
"""Test GitHub callback with invalid state"""
|
||||||
|
response = github_oauth_client.get("/auth/github/callback?code=test_code&state=invalid_state")
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "Invalid state parameter" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_github_token_exchange_error(*args, **kwargs):
|
||||||
|
"""Mock GitHub token exchange error"""
|
||||||
|
return MockResponse(
|
||||||
|
200, {"error": "bad_verification_code", "error_description": "The code passed is incorrect or expired."}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncClientError:
|
||||||
|
"""Mock httpx.AsyncClient that returns error for token exchange"""
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def post(self, url, **kwargs):
|
||||||
|
if "login/oauth/access_token" in url:
|
||||||
|
return await mock_github_token_exchange_error()
|
||||||
|
return MockResponse(404, {})
|
||||||
|
|
||||||
|
async def get(self, url, **kwargs):
|
||||||
|
return MockResponse(404, {})
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient", MockAsyncClientError)
|
||||||
|
def test_github_callback_token_exchange_error(github_oauth_app):
|
||||||
|
"""Test GitHub callback with token exchange error"""
|
||||||
|
client = TestClient(github_oauth_app)
|
||||||
|
|
||||||
|
# First, make a login request to generate a state
|
||||||
|
login_response = client.get("/auth/github/login", follow_redirects=False)
|
||||||
|
assert login_response.status_code == 307
|
||||||
|
|
||||||
|
# Extract the state from the redirect URL
|
||||||
|
location = login_response.headers["location"]
|
||||||
|
import re
|
||||||
|
|
||||||
|
state_match = re.search(r"state=([^&]+)", location)
|
||||||
|
assert state_match
|
||||||
|
state = state_match.group(1)
|
||||||
|
|
||||||
|
# Now use that state in the callback with bad code
|
||||||
|
response = client.get(f"/auth/github/callback?code=bad_code&state={state}")
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "The code passed is incorrect or expired" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def github_jwt_token():
|
||||||
|
"""Create a valid GitHub JWT token for testing"""
|
||||||
|
claims = {
|
||||||
|
"sub": "test-user",
|
||||||
|
"aud": "llama-stack",
|
||||||
|
"iss": "llama-stack-github",
|
||||||
|
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||||
|
"iat": datetime.now(UTC),
|
||||||
|
"nbf": datetime.now(UTC),
|
||||||
|
"github_username": "test-user",
|
||||||
|
"github_user_id": "12345",
|
||||||
|
"github_orgs": ["test-org-1", "test-org-2"],
|
||||||
|
"email": "test@example.com",
|
||||||
|
"name": "Test User",
|
||||||
|
}
|
||||||
|
|
||||||
|
return jwt.encode(claims, "test_jwt_secret", algorithm="HS256")
|
||||||
|
|
||||||
|
|
||||||
|
def test_github_jwt_authentication_success(github_oauth_client, github_jwt_token):
|
||||||
|
"""Test API authentication with GitHub-issued JWT"""
|
||||||
|
response = github_oauth_client.get("/test", headers={"Authorization": f"Bearer {github_jwt_token}"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_github_jwt_authentication_invalid_token(github_oauth_client):
|
||||||
|
"""Test API authentication with invalid JWT"""
|
||||||
|
response = github_oauth_client.get("/test", headers={"Authorization": "Bearer invalid.jwt.token"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Invalid GitHub JWT token" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_github_jwt_authentication_expired_token(github_oauth_client):
|
||||||
|
"""Test API authentication with expired JWT"""
|
||||||
|
# Create an expired token
|
||||||
|
claims = {
|
||||||
|
"sub": "test-user",
|
||||||
|
"aud": "llama-stack",
|
||||||
|
"iss": "llama-stack-github",
|
||||||
|
"exp": datetime.now(UTC) - timedelta(hours=1), # Expired
|
||||||
|
"iat": datetime.now(UTC) - timedelta(hours=2),
|
||||||
|
"nbf": datetime.now(UTC) - timedelta(hours=2),
|
||||||
|
"github_username": "test-user",
|
||||||
|
}
|
||||||
|
|
||||||
|
expired_token = jwt.encode(claims, "test_jwt_secret", algorithm="HS256")
|
||||||
|
|
||||||
|
response = github_oauth_client.get("/test", headers={"Authorization": f"Bearer {expired_token}"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Invalid GitHub JWT token" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_github_jwt_authentication_wrong_audience(github_oauth_client):
|
||||||
|
"""Test API authentication with JWT having wrong audience"""
|
||||||
|
claims = {
|
||||||
|
"sub": "test-user",
|
||||||
|
"aud": "wrong-audience", # Wrong audience
|
||||||
|
"iss": "llama-stack-github",
|
||||||
|
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||||
|
"iat": datetime.now(UTC),
|
||||||
|
"nbf": datetime.now(UTC),
|
||||||
|
"github_username": "test-user",
|
||||||
|
}
|
||||||
|
|
||||||
|
wrong_audience_token = jwt.encode(claims, "test_jwt_secret", algorithm="HS256")
|
||||||
|
|
||||||
|
response = github_oauth_client.get("/test", headers={"Authorization": f"Bearer {wrong_audience_token}"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Invalid GitHub JWT token" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_github_claims_mapping():
|
||||||
|
"""Test GitHub claims are properly mapped to attributes"""
|
||||||
|
config = GitHubAuthConfig(
|
||||||
|
type="github_oauth",
|
||||||
|
github_client_id="test",
|
||||||
|
github_client_secret="test",
|
||||||
|
jwt_secret="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
claims = {
|
||||||
|
"sub": "test-user",
|
||||||
|
"github_username": "test-user",
|
||||||
|
"github_orgs": ["org1", "org2"],
|
||||||
|
"github_teams": ["team1", "team2"],
|
||||||
|
"github_user_id": "12345",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Default mapping only maps "sub" to "roles"
|
||||||
|
attributes = get_attributes_from_claims(claims, config.claims_mapping)
|
||||||
|
|
||||||
|
assert "test-user" in attributes["roles"]
|
||||||
|
# No other mappings by default
|
||||||
|
assert len(attributes) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_github_auth_provider_validate_token():
|
||||||
|
"""Test GitHubAuthProvider token validation"""
|
||||||
|
config = GitHubAuthConfig(
|
||||||
|
type="github_oauth",
|
||||||
|
github_client_id="test",
|
||||||
|
github_client_secret="test",
|
||||||
|
jwt_secret="test_secret",
|
||||||
|
jwt_algorithm="HS256",
|
||||||
|
jwt_audience="test-audience",
|
||||||
|
jwt_issuer="test-issuer",
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = GitHubAuthProvider(config)
|
||||||
|
|
||||||
|
# Create a valid token
|
||||||
|
claims = {
|
||||||
|
"sub": "test-user",
|
||||||
|
"aud": "test-audience",
|
||||||
|
"iss": "test-issuer",
|
||||||
|
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||||
|
"iat": datetime.now(UTC),
|
||||||
|
"nbf": datetime.now(UTC),
|
||||||
|
"github_username": "test-user",
|
||||||
|
"github_orgs": ["org1"],
|
||||||
|
"github_user_id": "12345",
|
||||||
|
}
|
||||||
|
|
||||||
|
token = jwt.encode(claims, "test_secret", algorithm="HS256")
|
||||||
|
|
||||||
|
user = await provider.validate_token(token)
|
||||||
|
assert user.principal == "test-user"
|
||||||
|
# Default mapping only maps "sub" to "roles"
|
||||||
|
assert "test-user" in user.attributes["roles"]
|
||||||
|
# No other mappings by default
|
||||||
|
assert len(user.attributes) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_github_auth_provider_custom_claims_mapping():
|
||||||
|
"""Test GitHubAuthProvider with custom claims mapping"""
|
||||||
|
config = GitHubAuthConfig(
|
||||||
|
type="github_oauth",
|
||||||
|
github_client_id="test",
|
||||||
|
github_client_secret="test",
|
||||||
|
jwt_secret="test_secret",
|
||||||
|
jwt_algorithm="HS256",
|
||||||
|
jwt_audience="test-audience",
|
||||||
|
jwt_issuer="test-issuer",
|
||||||
|
claims_mapping={
|
||||||
|
"sub": "roles",
|
||||||
|
"github_orgs": "teams",
|
||||||
|
"github_teams": "teams",
|
||||||
|
"github_user_id": "namespaces",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = GitHubAuthProvider(config)
|
||||||
|
|
||||||
|
# Create a valid token
|
||||||
|
claims = {
|
||||||
|
"sub": "test-user",
|
||||||
|
"aud": "test-audience",
|
||||||
|
"iss": "test-issuer",
|
||||||
|
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||||
|
"iat": datetime.now(UTC),
|
||||||
|
"nbf": datetime.now(UTC),
|
||||||
|
"github_username": "test-user",
|
||||||
|
"github_orgs": ["org1", "org2"],
|
||||||
|
"github_teams": ["team1", "team2"],
|
||||||
|
"github_user_id": "12345",
|
||||||
|
}
|
||||||
|
|
||||||
|
token = jwt.encode(claims, "test_secret", algorithm="HS256")
|
||||||
|
|
||||||
|
user = await provider.validate_token(token)
|
||||||
|
assert user.principal == "test-user"
|
||||||
|
assert "test-user" in user.attributes["roles"]
|
||||||
|
assert set(user.attributes["teams"]) == {"org1", "org2", "team1", "team2"}
|
||||||
|
assert "12345" in user.attributes["namespaces"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_github_auth_provider_invalid_token():
|
||||||
|
"""Test GitHubAuthProvider with invalid token"""
|
||||||
|
config = GitHubAuthConfig(
|
||||||
|
type="github_oauth",
|
||||||
|
github_client_id="test",
|
||||||
|
github_client_secret="test",
|
||||||
|
jwt_secret="test_secret",
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = GitHubAuthProvider(config)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Invalid GitHub JWT token"):
|
||||||
|
await provider.validate_token("invalid.token.here")
|
||||||
|
|
||||||
|
|
||||||
|
def test_github_auth_provider_authorization_url():
|
||||||
|
"""Test GitHubAuthProvider generates correct authorization URL"""
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
config = GitHubAuthConfig(
|
||||||
|
type="github_oauth",
|
||||||
|
github_client_id="test_client_id",
|
||||||
|
github_client_secret="test_secret",
|
||||||
|
jwt_secret="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = GitHubAuthProvider(config)
|
||||||
|
|
||||||
|
# Mock request object
|
||||||
|
mock_request = Mock()
|
||||||
|
mock_request.url_for.return_value = "http://localhost:8321/auth/github/callback"
|
||||||
|
|
||||||
|
url = provider.get_authorization_url("test_state_123", mock_request)
|
||||||
|
|
||||||
|
assert "https://github.com/login/oauth/authorize" in url
|
||||||
|
assert "client_id=test_client_id" in url
|
||||||
|
assert "redirect_uri=http%3A%2F%2Flocalhost%3A8321%2Fauth%2Fgithub%2Fcallback" in url
|
||||||
|
assert "state=test_state_123" in url
|
||||||
|
assert "scope=read%3Auser+read%3Aorg" in url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_github_auth_provider_complete_flow():
|
||||||
|
"""Test complete OAuth flow in GitHubAuthProvider"""
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
config = GitHubAuthConfig(
|
||||||
|
type="github_oauth",
|
||||||
|
github_client_id="test_client_id",
|
||||||
|
github_client_secret="test_secret",
|
||||||
|
jwt_secret="test_jwt_secret",
|
||||||
|
jwt_algorithm="HS256",
|
||||||
|
jwt_audience="llama-stack",
|
||||||
|
jwt_issuer="llama-stack-github",
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = GitHubAuthProvider(config)
|
||||||
|
|
||||||
|
# Mock request object
|
||||||
|
mock_request = Mock()
|
||||||
|
mock_request.url_for.return_value = "/auth/github/callback"
|
||||||
|
mock_request.url.scheme = "http"
|
||||||
|
mock_request.url.netloc = "localhost:8321"
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", MockAsyncClient):
|
||||||
|
# Now returns just the JWT token
|
||||||
|
token = await provider.complete_oauth_flow("test_code", mock_request)
|
||||||
|
|
||||||
|
# Verify it's a JWT token by decoding it
|
||||||
|
claims = jwt.decode(token, "test_jwt_secret", algorithms=["HS256"], audience="llama-stack")
|
||||||
|
assert claims["github_username"] == "test-user"
|
||||||
|
assert claims["github_orgs"] == ["test-org-1", "test-org-2"]
|
||||||
|
assert claims["sub"] == "test-user"
|
Loading…
Add table
Add a link
Reference in a new issue