- Change auth config from provider_type + config dict to discriminated union types
- Add GitHub token authentication provider
- Improve auth error messages with provider-specific guidance
- Extract auth datatypes to separate module
- Update tests to use new auth config structure
- Remove unused OAuth2LocalJWTConfig

## Test Plan
- Unit tests pass for all auth providers
- Integration tests verify auth flow works correctly
- GitHub token auth tested with valid/invalid tokens

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Eric Huang 2025-07-03 10:26:51 -07:00
parent 3c43a2f529
commit c2d16c713e
7 changed files with 480 additions and 161 deletions

View file

@ -73,9 +73,12 @@ jobs:
server:
port: 8321
EOF
yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml
yq eval '.server.auth.provider_config.type = "${{ matrix.auth-provider }}"' -i $run_dir/run.yaml
yq eval '.server.auth.provider_config.tls_cafile = "${{ env.KUBERNETES_CA_CERT_PATH }}"' -i $run_dir/run.yaml
yq eval '.server.auth.provider_config.issuer = "${{ env.KUBERNETES_ISSUER }}"' -i $run_dir/run.yaml
yq eval '.server.auth.provider_config.audience = "${{ env.KUBERNETES_AUDIENCE }}"' -i $run_dir/run.yaml
yq eval '.server.auth.provider_config.jwks.uri = "${{ env.KUBERNETES_API_SERVER_URL }}"' -i $run_dir/run.yaml
yq eval '.server.auth.provider_config.jwks.token = "${{ env.TOKEN }}"' -i $run_dir/run.yaml
cat $run_dir/run.yaml
nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &

View file

@ -6,9 +6,9 @@
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Any
from typing import Annotated, Any, Literal, Self
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO
@ -161,23 +161,114 @@ class LoggingConfig(BaseModel):
)
class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys
uri: str
token: str | None = Field(default=None, description="token to authorise access to jwks")
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
class OAuth2IntrospectionConfig(BaseModel):
url: str
client_id: str
client_secret: str
send_secret_in_body: bool = False
class AuthProviderType(StrEnum):
"""Supported authentication provider types."""
OAUTH2_TOKEN = "oauth2_token"
GITHUB_TOKEN = "github_token"
CUSTOM = "custom"
class OAuth2TokenAuthConfig(BaseModel):
"""Configuration for OAuth2 token authentication."""
type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
audience: str = Field(default="llama-stack")
verify_tls: bool = Field(default=True)
tls_cafile: Path | None = Field(default=None)
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
"username": "roles",
"groups": "teams",
"team": "teams",
"project": "projects",
"tenant": "namespaces",
"namespace": "namespaces",
},
)
jwks: OAuth2JWKSConfig | None = Field(default=None, description="JWKS configuration")
introspection: OAuth2IntrospectionConfig | None = Field(
default=None, description="OAuth2 introspection configuration"
)
@classmethod
@field_validator("claims_mapping")
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
return v
@model_validator(mode="after")
def validate_mode(self) -> Self:
if not self.jwks and not self.introspection:
raise ValueError("One of jwks or introspection must be configured")
if self.jwks and self.introspection:
raise ValueError("At present only one of jwks or introspection should be configured")
return self
class CustomAuthConfig(BaseModel):
"""Configuration for custom authentication."""
type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
endpoint: str = Field(
...,
description="Custom authentication endpoint URL",
)
class GitHubTokenAuthConfig(BaseModel):
"""Configuration for GitHub token authentication."""
type: Literal[AuthProviderType.GITHUB_TOKEN] = AuthProviderType.GITHUB_TOKEN
github_api_base_url: str = Field(
default="https://api.github.com",
description="Base URL for GitHub API (use https://api.github.com for public GitHub)",
)
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"login": "username",
"id": "user_id",
"organizations": "teams",
},
description="Mapping from GitHub user fields to access attributes",
)
AuthProviderConfig = Annotated[
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig,
Field(discriminator="type"),
]
class AuthenticationConfig(BaseModel):
provider_type: AuthProviderType = Field(
"""Top-level authentication configuration."""
provider_config: AuthProviderConfig = Field(
...,
description="Type of authentication provider",
description="Authentication provider configuration",
)
config: dict[str, Any] = Field(
...,
description="Provider-specific configuration",
access_policy: list[AccessRule] = Field(
default=[],
description="Rules for determining access to resources",
)
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
class AuthenticationRequiredError(Exception):

View file

@ -87,7 +87,11 @@ class AuthenticationMiddleware:
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
if not auth_header or not auth_header.startswith("Bearer "):
if not auth_header:
error_msg = self.auth_provider.get_auth_error_message(scope)
return await self._send_auth_error(send, error_msg)
if not auth_header.startswith("Bearer "):
return await self._send_auth_error(send, "Missing or invalid Authorization header")
token = auth_header.split("Bearer ", 1)[1]

View file

@ -8,15 +8,19 @@ import ssl
import time
from abc import ABC, abstractmethod
from asyncio import Lock
from pathlib import Path
from typing import Self
from urllib.parse import parse_qs
from urllib.parse import parse_qs, urlparse
import httpx
from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
from llama_stack.distribution.datatypes import (
AuthenticationConfig,
CustomAuthConfig,
GitHubTokenAuthConfig,
OAuth2TokenAuthConfig,
User,
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -38,9 +42,7 @@ class AuthRequestContext(BaseModel):
headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: dict[str, list[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists"
)
params: dict[str, list[str]] = Field(default_factory=dict, description="Query parameters from the original request")
class AuthRequest(BaseModel):
@ -62,6 +64,10 @@ class AuthProvider(ABC):
"""Clean up any resources."""
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return provider-specific authentication error message."""
return "Authentication required"
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
attributes: dict[str, list[str]] = {}
@ -81,56 +87,6 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
return attributes
class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys
uri: str
token: str | None = Field(default=None, description="token to authorise access to jwks")
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
class OAuth2IntrospectionConfig(BaseModel):
url: str
client_id: str
client_secret: str
send_secret_in_body: bool = False
class OAuth2TokenAuthProviderConfig(BaseModel):
audience: str = "llama-stack"
verify_tls: bool = True
tls_cafile: Path | None = None
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
"username": "roles",
"groups": "teams",
"team": "teams",
"project": "projects",
"tenant": "namespaces",
"namespace": "namespaces",
},
)
jwks: OAuth2JWKSConfig | None
introspection: OAuth2IntrospectionConfig | None = None
@classmethod
@field_validator("claims_mapping")
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
return v
@model_validator(mode="after")
def validate_mode(self) -> Self:
if not self.jwks and not self.introspection:
raise ValueError("One of jwks or introspection must be configured")
if self.jwks and self.introspection:
raise ValueError("At present only one of jwks or introspection should be configured")
return self
class OAuth2TokenAuthProvider(AuthProvider):
"""
JWT token authentication provider that validates a JWT token and extracts access attributes.
@ -138,7 +94,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
This should be the standard authentication provider for most use cases.
"""
def __init__(self, config: OAuth2TokenAuthProviderConfig):
def __init__(self, config: OAuth2TokenAuthConfig):
self.config = config
self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {}
@ -170,7 +126,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
issuer=self.config.issuer,
)
except Exception as exc:
raise ValueError(f"Invalid JWT token: {token}") from exc
raise ValueError("Invalid JWT token") from exc
# There are other standard claims, the most relevant of which is `scope`.
# We should incorporate these into the access attributes.
@ -232,6 +188,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
async def close(self):
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return OAuth2-specific authentication error message."""
if self.config.issuer:
return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
elif self.config.introspection:
# Extract domain from introspection URL for a cleaner message
domain = urlparse(self.config.introspection.url).netloc
return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
else:
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
async def _refresh_jwks(self) -> None:
"""
Refresh the JWKS cache.
@ -264,14 +231,10 @@ class OAuth2TokenAuthProvider(AuthProvider):
self._jwks_at = time.time()
class CustomAuthProviderConfig(BaseModel):
endpoint: str
class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint."""
def __init__(self, config: CustomAuthProviderConfig):
def __init__(self, config: CustomAuthConfig):
self.config = config
self._client = None
@ -317,7 +280,7 @@ class CustomAuthProvider(AuthProvider):
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
return User(auth_response.principal, auth_response.attributes)
return User(principal=auth_response.principal, attributes=auth_response.attributes)
except Exception as e:
logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e
@ -338,15 +301,85 @@ class CustomAuthProvider(AuthProvider):
await self._client.aclose()
self._client = None
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return custom auth provider-specific authentication error message."""
domain = urlparse(self.config.endpoint).netloc
if domain:
return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
else:
return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
class GitHubTokenAuthProvider(AuthProvider):
"""
GitHub token authentication provider that validates GitHub access tokens directly.
This provider accepts GitHub personal access tokens or OAuth tokens and verifies
them against the GitHub API to get user information.
"""
def __init__(self, config: GitHubTokenAuthConfig):
self.config = config
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a GitHub token by calling the GitHub API."""
try:
user_info = await self._get_github_user_info(token)
principal = user_info["user"]["login"]
github_data = {
"login": user_info["user"]["login"],
"id": str(user_info["user"]["id"]),
"organizations": user_info.get("organizations", []),
}
access_attributes = get_attributes_from_claims(github_data, self.config.claims_mapping)
return User(
principal=principal,
attributes=access_attributes,
)
except Exception as e:
logger.warning(f"GitHub token validation failed: {e}")
raise ValueError("Invalid GitHub token") from e
async def _get_github_user_info(self, access_token: str) -> dict:
"""Fetch user info and organizations from GitHub API."""
headers = {
"Authorization": f"Bearer {access_token}",
"Accept": "application/vnd.github.v3+json",
"User-Agent": "llama-stack",
}
async with httpx.AsyncClient() as client:
user_response = await client.get(f"{self.config.github_api_base_url}/user", headers=headers, timeout=10.0)
user_response.raise_for_status()
user_data = user_response.json()
return {
"user": user_data,
}
async def close(self):
"""Clean up any resources."""
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return GitHub-specific authentication error message."""
return "Authentication required. Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
"""Factory function to create the appropriate auth provider."""
provider_type = config.provider_type.lower()
provider_config = config.provider_config
if provider_type == "custom":
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
elif provider_type == "oauth2_token":
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
if isinstance(provider_config, CustomAuthConfig):
return CustomAuthProvider(provider_config)
elif isinstance(provider_config, OAuth2TokenAuthConfig):
return OAuth2TokenAuthProvider(provider_config)
elif isinstance(provider_config, GitHubTokenAuthConfig):
return GitHubTokenAuthProvider(provider_config)
else:
supported_providers = ", ".join([t.value for t in AuthProviderType])
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")

View file

@ -31,7 +31,11 @@ from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.datatypes import (
AuthenticationRequiredError,
LoggingConfig,
StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
from llama_stack.distribution.resolver import InvalidProviderError
@ -215,7 +219,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
principal = request.scope.get("principal", "")
user = User(principal, user_attributes)
user = User(principal=principal, attributes=user_attributes)
await log_request_pre_validation(request)
@ -450,7 +454,7 @@ def main(args: argparse.Namespace | None = None):
# Add authentication middleware if configured
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
else:
if config.server.quota:

View file

@ -11,10 +11,16 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from llama_stack.distribution.datatypes import AuthenticationConfig
from llama_stack.distribution.datatypes import (
AuthenticationConfig,
AuthProviderType,
CustomAuthConfig,
OAuth2IntrospectionConfig,
OAuth2JWKSConfig,
OAuth2TokenAuthConfig,
)
from llama_stack.distribution.server.auth import AuthenticationMiddleware
from llama_stack.distribution.server.auth_providers import (
AuthProviderType,
get_attributes_from_claims,
)
@ -61,24 +67,11 @@ def invalid_token():
def http_app(mock_auth_endpoint):
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.CUSTOM,
config={"endpoint": mock_auth_endpoint},
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@app.get("/test")
def test_endpoint():
return {"message": "Authentication successful"}
return app
@pytest.fixture
def k8s_app():
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.KUBERNETES,
config={"api_server_url": "https://kubernetes.default.svc"},
provider_config=CustomAuthConfig(
type=AuthProviderType.CUSTOM,
endpoint=mock_auth_endpoint,
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@ -94,11 +87,6 @@ def http_client(http_app):
return TestClient(http_app)
@pytest.fixture
def k8s_client(k8s_app):
return TestClient(k8s_app)
@pytest.fixture
def mock_scope():
return {
@ -117,18 +105,11 @@ def mock_scope():
def mock_http_middleware(mock_auth_endpoint):
mock_app = AsyncMock()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.CUSTOM,
config={"endpoint": mock_auth_endpoint},
)
return AuthenticationMiddleware(mock_app, auth_config), mock_app
@pytest.fixture
def mock_k8s_middleware():
mock_app = AsyncMock()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.KUBERNETES,
config={"api_server_url": "https://kubernetes.default.svc"},
provider_config=CustomAuthConfig(
type=AuthProviderType.CUSTOM,
endpoint=mock_auth_endpoint,
),
access_policy=[],
)
return AuthenticationMiddleware(mock_app, auth_config), mock_app
@ -161,7 +142,8 @@ async def mock_post_exception(*args, **kwargs):
def test_missing_auth_header(http_client):
response = http_client.get("/test")
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
assert "Authentication required" in response.json()["error"]["message"]
assert "validated by mock-auth-service" in response.json()["error"]["message"]
def test_invalid_auth_header_format(http_client):
@ -262,14 +244,14 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
def oauth2_app():
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": {
"uri": "http://mock-authz-service/token/introspect",
"key_recheck_period": "3600",
},
"audience": "llama-stack",
},
provider_config=OAuth2TokenAuthConfig(
type=AuthProviderType.OAUTH2_TOKEN,
jwks=OAuth2JWKSConfig(
uri="http://mock-authz-service/token/introspect",
),
audience="llama-stack",
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@ -288,7 +270,8 @@ def oauth2_client(oauth2_app):
def test_missing_auth_header_oauth2(oauth2_client):
response = oauth2_client.get("/test")
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
assert "Authentication required" in response.json()["error"]["message"]
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
def test_invalid_auth_header_format_oauth2(oauth2_client):
@ -358,15 +341,16 @@ async def mock_auth_jwks_response(*args, **kwargs):
def oauth2_app_with_jwks_token():
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": {
"uri": "http://mock-authz-service/token/introspect",
"key_recheck_period": "3600",
"token": "my-jwks-token",
},
"audience": "llama-stack",
},
provider_config=OAuth2TokenAuthConfig(
type=AuthProviderType.OAUTH2_TOKEN,
jwks=OAuth2JWKSConfig(
uri="http://mock-authz-service/token/introspect",
key_recheck_period=3600,
token="my-jwks-token",
),
audience="llama-stack",
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@ -449,11 +433,15 @@ def mock_introspection_endpoint():
def introspection_app(mock_introspection_endpoint):
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": None,
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
},
provider_config=OAuth2TokenAuthConfig(
type=AuthProviderType.OAUTH2_TOKEN,
introspection=OAuth2IntrospectionConfig(
url=mock_introspection_endpoint,
client_id="myclient",
client_secret="abcdefg",
),
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@ -468,22 +456,22 @@ def introspection_app(mock_introspection_endpoint):
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": None,
"introspection": {
"url": mock_introspection_endpoint,
"client_id": "myclient",
"client_secret": "abcdefg",
"send_secret_in_body": "true",
},
"claims_mapping": {
provider_config=OAuth2TokenAuthConfig(
type=AuthProviderType.OAUTH2_TOKEN,
introspection=OAuth2IntrospectionConfig(
url=mock_introspection_endpoint,
client_id="myclient",
client_secret="abcdefg",
send_secret_in_body=True,
),
claims_mapping={
"sub": "roles",
"scope": "roles",
"groups": "teams",
"aud": "namespaces",
},
},
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@ -507,7 +495,8 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi
def test_missing_auth_header_introspection(introspection_client):
response = introspection_client.get("/test")
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
assert "Authentication required" in response.json()["error"]["message"]
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
def test_invalid_auth_header_format_introspection(introspection_client):

View file

@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, GitHubTokenAuthConfig
from llama_stack.distribution.server.auth import AuthenticationMiddleware
class MockResponse:
def __init__(self, status_code, json_data):
self.status_code = status_code
self._json_data = json_data
def json(self):
return self._json_data
def raise_for_status(self):
if self.status_code != 200:
raise Exception(f"HTTP error: {self.status_code}")
@pytest.fixture
def github_token_app():
app = FastAPI()
# Configure GitHub token auth
auth_config = AuthenticationConfig(
provider_config=GitHubTokenAuthConfig(
type=AuthProviderType.GITHUB_TOKEN,
github_api_base_url="https://api.github.com",
claims_mapping={
"login": "username",
"id": "user_id",
"organizations": "teams",
},
),
access_policy=[],
)
# Add auth middleware
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@app.get("/test")
def test_endpoint():
return {"message": "Authentication successful"}
return app
@pytest.fixture
def github_token_client(github_token_app):
return TestClient(github_token_app)
def test_authenticated_endpoint_without_token(github_token_client):
"""Test accessing protected endpoint without token"""
response = github_token_client.get("/test")
assert response.status_code == 401
assert "Authentication required" in response.json()["error"]["message"]
assert "GitHub access token" in response.json()["error"]["message"]
def test_authenticated_endpoint_with_invalid_bearer_format(github_token_client):
"""Test accessing protected endpoint with invalid bearer format"""
response = github_token_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
def test_authenticated_endpoint_with_valid_github_token(mock_client_class, github_token_client):
"""Test accessing protected endpoint with valid GitHub token"""
# Mock the GitHub API responses
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
# Mock successful user API response
mock_client.get.side_effect = [
MockResponse(
200,
{
"login": "testuser",
"id": 12345,
"email": "test@example.com",
"name": "Test User",
},
),
MockResponse(
200,
[
{"login": "test-org-1"},
{"login": "test-org-2"},
],
),
]
response = github_token_client.get("/test", headers={"Authorization": "Bearer github_token_123"})
assert response.status_code == 200
assert response.json()["message"] == "Authentication successful"
# Verify the GitHub API was called correctly
assert mock_client.get.call_count == 1
calls = mock_client.get.call_args_list
assert calls[0][0][0] == "https://api.github.com/user"
# Check authorization header was passed
assert calls[0][1]["headers"]["Authorization"] == "Bearer github_token_123"
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client):
"""Test accessing protected endpoint with invalid GitHub token"""
# Mock the GitHub API to return 401 Unauthorized
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
# Mock failed user API response
mock_client.get.return_value = MockResponse(401, {"message": "Bad credentials"})
response = github_token_client.get("/test", headers={"Authorization": "Bearer invalid_token"})
assert response.status_code == 401
assert "Invalid GitHub token" in response.json()["error"]["message"]
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
def test_github_enterprise_support(mock_client_class):
"""Test GitHub Enterprise support with custom API base URL"""
app = FastAPI()
# Configure GitHub token auth with enterprise URL
auth_config = AuthenticationConfig(
provider_config=GitHubTokenAuthConfig(
type=AuthProviderType.GITHUB_TOKEN,
github_api_base_url="https://github.enterprise.com/api/v3",
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@app.get("/test")
def test_endpoint():
return {"message": "Authentication successful"}
client = TestClient(app)
# Mock the GitHub Enterprise API responses
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
# Mock successful user API response
mock_client.get.side_effect = [
MockResponse(
200,
{
"login": "enterprise_user",
"id": 99999,
"email": "user@enterprise.com",
},
),
MockResponse(
200,
[
{"login": "enterprise-org"},
],
),
]
response = client.get("/test", headers={"Authorization": "Bearer enterprise_token"})
assert response.status_code == 200
# Verify the correct GitHub Enterprise URLs were called
assert mock_client.get.call_count == 1
calls = mock_client.get.call_args_list
assert calls[0][0][0] == "https://github.enterprise.com/api/v3/user"
def test_github_token_auth_error_message_format(github_token_client):
"""Test that the error message for missing auth is properly formatted"""
response = github_token_client.get("/test")
assert response.status_code == 401
error_data = response.json()
assert "error" in error_data
assert "message" in error_data["error"]
assert "Authentication required" in error_data["error"]["message"]
assert "https://docs.github.com" in error_data["error"]["message"] # Contains link to GitHub docs