diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index 4139d09ca..7822e4216 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -73,9 +73,12 @@ jobs: server: port: 8321 EOF - yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml - yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml - yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml + yq eval '.server.auth.provider_config.type = "${{ matrix.auth-provider }}"' -i $run_dir/run.yaml + yq eval '.server.auth.provider_config.tls_cafile = "${{ env.KUBERNETES_CA_CERT_PATH }}"' -i $run_dir/run.yaml + yq eval '.server.auth.provider_config.issuer = "${{ env.KUBERNETES_ISSUER }}"' -i $run_dir/run.yaml + yq eval '.server.auth.provider_config.audience = "${{ env.KUBERNETES_AUDIENCE }}"' -i $run_dir/run.yaml + yq eval '.server.auth.provider_config.jwks.uri = "${{ env.KUBERNETES_API_SERVER_URL }}"' -i $run_dir/run.yaml + yq eval '.server.auth.provider_config.jwks.token = "${{ env.TOKEN }}"' -i $run_dir/run.yaml cat $run_dir/run.yaml nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 & diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 1bba6677e..4709cb8c6 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -56,8 +56,8 @@ shields: [] server: port: 8321 auth: - provider_type: "oauth2_token" - config: + provider_config: + type: "oauth2_token" jwks: uri: "https://my-token-issuing-svc.com/jwks" ``` @@ -226,6 +226,8 @@ server: ### Authentication Configuration +> **Breaking Change (v0.2.14)**: The authentication 
configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly. + The `auth` section configures authentication for the server. When configured, all API requests must include a valid Bearer token in the Authorization header: ``` @@ -240,8 +242,8 @@ The server can be configured to use service account tokens for authorization, va ```yaml server: auth: - provider_type: "oauth2_token" - config: + provider_config: + type: "oauth2_token" jwks: uri: "https://kubernetes.default.svc:8443/openid/v1/jwks" token: "${env.TOKEN:+}" @@ -325,13 +327,25 @@ You can easily validate a request by running: curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers ``` +#### GitHub Token Provider +Validates GitHub personal access tokens or OAuth tokens directly: +```yaml +server: + auth: + provider_config: + type: "github_token" + github_api_base_url: "https://api.github.com" # Or GitHub Enterprise URL +``` + +The provider fetches user information from GitHub and maps it to access attributes based on the `claims_mapping` configuration. + #### Custom Provider Validates tokens against a custom authentication endpoint: ```yaml server: auth: - provider_type: "custom" - config: + provider_config: + type: "custom" endpoint: "https://auth.example.com/validate" # URL of the auth endpoint ``` @@ -416,8 +430,8 @@ clients. 
server: port: 8321 auth: - provider_type: custom - config: + provider_config: + type: custom endpoint: https://auth.example.com/validate quota: kvstore: diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 5e48ac0ad..ead1331f3 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -6,9 +6,9 @@ from enum import StrEnum from pathlib import Path -from typing import Annotated, Any +from typing import Annotated, Any, Literal, Self -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput from llama_stack.apis.datasetio import DatasetIO @@ -161,23 +161,113 @@ class LoggingConfig(BaseModel): ) +class OAuth2JWKSConfig(BaseModel): + # The JWKS URI for collecting public keys + uri: str + token: str | None = Field(default=None, description="token to authorise access to jwks") + key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates") + + +class OAuth2IntrospectionConfig(BaseModel): + url: str + client_id: str + client_secret: str + send_secret_in_body: bool = False + + class AuthProviderType(StrEnum): """Supported authentication provider types.""" OAUTH2_TOKEN = "oauth2_token" + GITHUB_TOKEN = "github_token" CUSTOM = "custom" +class OAuth2TokenAuthConfig(BaseModel): + """Configuration for OAuth2 token authentication.""" + + type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN + audience: str = Field(default="llama-stack") + verify_tls: bool = Field(default=True) + tls_cafile: Path | None = Field(default=None) + issuer: str | None = Field(default=None, description="The OIDC issuer URL.") + claims_mapping: dict[str, str] = Field( + default_factory=lambda: { + "sub": "roles", + "username": "roles", + "groups": "teams", + "team": "teams", + "project": "projects", + "tenant": 
"namespaces", + "namespace": "namespaces", + }, + ) + jwks: OAuth2JWKSConfig | None = Field(default=None, description="JWKS configuration") + introspection: OAuth2IntrospectionConfig | None = Field( + default=None, description="OAuth2 introspection configuration" + ) + + @field_validator("claims_mapping") + @classmethod + def validate_claims_mapping(cls, v): + for key, value in v.items(): + if not value: + raise ValueError(f"claims_mapping value cannot be empty: {key}") + return v + + @model_validator(mode="after") + def validate_mode(self) -> Self: + if not self.jwks and not self.introspection: + raise ValueError("One of jwks or introspection must be configured") + if self.jwks and self.introspection: + raise ValueError("At present only one of jwks or introspection should be configured") + return self + + +class CustomAuthConfig(BaseModel): + """Configuration for custom authentication.""" + + type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM + endpoint: str = Field( + ..., + description="Custom authentication endpoint URL", + ) + + +class GitHubTokenAuthConfig(BaseModel): + """Configuration for GitHub token authentication.""" + + type: Literal[AuthProviderType.GITHUB_TOKEN] = AuthProviderType.GITHUB_TOKEN + github_api_base_url: str = Field( + default="https://api.github.com", + description="Base URL for GitHub API (use https://api.github.com for public GitHub)", + ) + claims_mapping: dict[str, str] = Field( + default_factory=lambda: { + "login": "roles", + "organizations": "teams", + }, + description="Mapping from GitHub user fields to access attributes", + ) + + +AuthProviderConfig = Annotated[ + OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig, + Field(discriminator="type"), +] + + class AuthenticationConfig(BaseModel): - provider_type: AuthProviderType = Field( + """Top-level authentication configuration.""" + + provider_config: AuthProviderConfig = Field( ..., - description="Type of authentication provider", + 
description="Authentication provider configuration", ) - config: dict[str, Any] = Field( - ..., - description="Provider-specific configuration", + access_policy: list[AccessRule] = Field( + default=[], + description="Rules for determining access to resources", ) - access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources") class AuthenticationRequiredError(Exception): diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index 81b1ffd37..fadbf7b49 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -87,8 +87,12 @@ class AuthenticationMiddleware: headers = dict(scope.get("headers", [])) auth_header = headers.get(b"authorization", b"").decode() - if not auth_header or not auth_header.startswith("Bearer "): - return await self._send_auth_error(send, "Missing or invalid Authorization header") + if not auth_header: + error_msg = self.auth_provider.get_auth_error_message(scope) + return await self._send_auth_error(send, error_msg) + + if not auth_header.startswith("Bearer "): + return await self._send_auth_error(send, "Invalid Authorization header format") token = auth_header.split("Bearer ", 1)[1] diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 173434652..9b0e182f5 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -8,15 +8,19 @@ import ssl import time from abc import ABC, abstractmethod from asyncio import Lock -from pathlib import Path -from typing import Self -from urllib.parse import parse_qs +from urllib.parse import parse_qs, urlparse import httpx from jose import jwt -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, Field -from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User +from 
llama_stack.distribution.datatypes import ( + AuthenticationConfig, + CustomAuthConfig, + GitHubTokenAuthConfig, + OAuth2TokenAuthConfig, + User, +) from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") @@ -38,9 +42,7 @@ class AuthRequestContext(BaseModel): headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)") - params: dict[str, list[str]] = Field( - description="Query parameters from the original request, parsed as dictionary of lists" - ) + params: dict[str, list[str]] = Field(default_factory=dict, description="Query parameters from the original request") class AuthRequest(BaseModel): @@ -62,6 +64,10 @@ class AuthProvider(ABC): """Clean up any resources.""" pass + def get_auth_error_message(self, scope: dict | None = None) -> str: + """Return provider-specific authentication error message.""" + return "Authentication required" + def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]: attributes: dict[str, list[str]] = {} @@ -81,56 +87,6 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) return attributes -class OAuth2JWKSConfig(BaseModel): - # The JWKS URI for collecting public keys - uri: str - token: str | None = Field(default=None, description="token to authorise access to jwks") - key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates") - - -class OAuth2IntrospectionConfig(BaseModel): - url: str - client_id: str - client_secret: str - send_secret_in_body: bool = False - - -class OAuth2TokenAuthProviderConfig(BaseModel): - audience: str = "llama-stack" - verify_tls: bool = True - tls_cafile: Path | None = None - issuer: str | None = Field(default=None, description="The OIDC issuer URL.") - claims_mapping: dict[str, str] = Field( - default_factory=lambda: { - "sub": "roles", - "username": "roles", - "groups": "teams", - "team": 
"teams", - "project": "projects", - "tenant": "namespaces", - "namespace": "namespaces", - }, - ) - jwks: OAuth2JWKSConfig | None - introspection: OAuth2IntrospectionConfig | None = None - - @classmethod - @field_validator("claims_mapping") - def validate_claims_mapping(cls, v): - for key, value in v.items(): - if not value: - raise ValueError(f"claims_mapping value cannot be empty: {key}") - return v - - @model_validator(mode="after") - def validate_mode(self) -> Self: - if not self.jwks and not self.introspection: - raise ValueError("One of jwks or introspection must be configured") - if self.jwks and self.introspection: - raise ValueError("At present only one of jwks or introspection should be configured") - return self - - class OAuth2TokenAuthProvider(AuthProvider): """ JWT token authentication provider that validates a JWT token and extracts access attributes. @@ -138,7 +94,7 @@ class OAuth2TokenAuthProvider(AuthProvider): This should be the standard authentication provider for most use cases. """ - def __init__(self, config: OAuth2TokenAuthProviderConfig): + def __init__(self, config: OAuth2TokenAuthConfig): self.config = config self._jwks_at: float = 0.0 self._jwks: dict[str, str] = {} @@ -170,7 +126,7 @@ class OAuth2TokenAuthProvider(AuthProvider): issuer=self.config.issuer, ) except Exception as exc: - raise ValueError(f"Invalid JWT token: {token}") from exc + raise ValueError("Invalid JWT token") from exc # There are other standard claims, the most relevant of which is `scope`. # We should incorporate these into the access attributes. @@ -232,6 +188,17 @@ class OAuth2TokenAuthProvider(AuthProvider): async def close(self): pass + def get_auth_error_message(self, scope: dict | None = None) -> str: + """Return OAuth2-specific authentication error message.""" + if self.config.issuer: + return f"Authentication required. 
Please provide a valid OAuth2 Bearer token from {self.config.issuer}" + elif self.config.introspection: + # Extract domain from introspection URL for a cleaner message + domain = urlparse(self.config.introspection.url).netloc + return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}" + else: + return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header" + async def _refresh_jwks(self) -> None: """ Refresh the JWKS cache. @@ -264,14 +231,10 @@ class OAuth2TokenAuthProvider(AuthProvider): self._jwks_at = time.time() -class CustomAuthProviderConfig(BaseModel): - endpoint: str - - class CustomAuthProvider(AuthProvider): """Custom authentication provider that uses an external endpoint.""" - def __init__(self, config: CustomAuthProviderConfig): + def __init__(self, config: CustomAuthConfig): self.config = config self._client = None @@ -317,7 +280,7 @@ class CustomAuthProvider(AuthProvider): try: response_data = response.json() auth_response = AuthResponse(**response_data) - return User(auth_response.principal, auth_response.attributes) + return User(principal=auth_response.principal, attributes=auth_response.attributes) except Exception as e: logger.exception("Error parsing authentication response") raise ValueError("Invalid authentication response format") from e @@ -338,15 +301,88 @@ class CustomAuthProvider(AuthProvider): await self._client.aclose() self._client = None + def get_auth_error_message(self, scope: dict | None = None) -> str: + """Return custom auth provider-specific authentication error message.""" + domain = urlparse(self.config.endpoint).netloc + if domain: + return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})" + else: + return "Authentication required. 
Please provide your API key as a Bearer token in the Authorization header" + + +class GitHubTokenAuthProvider(AuthProvider): + """ + GitHub token authentication provider that validates GitHub access tokens directly. + + This provider accepts GitHub personal access tokens or OAuth tokens and verifies + them against the GitHub API to get user information. + """ + + def __init__(self, config: GitHubTokenAuthConfig): + self.config = config + + async def validate_token(self, token: str, scope: dict | None = None) -> User: + """Validate a GitHub token by calling the GitHub API. + + This validates tokens issued by GitHub (personal access tokens or OAuth tokens). + """ + try: + user_info = await _get_github_user_info(token, self.config.github_api_base_url) + except httpx.HTTPStatusError as e: + logger.warning(f"GitHub token validation failed: {e}") + raise ValueError("GitHub token validation failed. Please check your token and try again.") from e + + principal = user_info["user"]["login"] + + github_data = { + "login": user_info["user"]["login"], + "id": str(user_info["user"]["id"]), + "organizations": user_info.get("organizations", []), + } + + access_attributes = get_attributes_from_claims(github_data, self.config.claims_mapping) + + return User( + principal=principal, + attributes=access_attributes, + ) + + async def close(self): + """Clean up any resources.""" + pass + + def get_auth_error_message(self, scope: dict | None = None) -> str: + """Return GitHub-specific authentication error message.""" + return "Authentication required. 
Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)" + + +async def _get_github_user_info(access_token: str, github_api_base_url: str) -> dict: + """Fetch the authenticated user's profile from the GitHub API.""" + headers = { + "Authorization": f"Bearer {access_token}", + "Accept": "application/vnd.github.v3+json", + "User-Agent": "llama-stack", + } + + async with httpx.AsyncClient() as client: + user_response = await client.get(f"{github_api_base_url.rstrip('/')}/user", headers=headers, timeout=10.0) + user_response.raise_for_status() + user_data = user_response.json() + + return { + "user": user_data, + } + def create_auth_provider(config: AuthenticationConfig) -> AuthProvider: """Factory function to create the appropriate auth provider.""" - provider_type = config.provider_type.lower() + provider_config = config.provider_config - if provider_type == "custom": - return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config)) - elif provider_type == "oauth2_token": - return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config)) + if isinstance(provider_config, CustomAuthConfig): + return CustomAuthProvider(provider_config) + elif isinstance(provider_config, OAuth2TokenAuthConfig): + return OAuth2TokenAuthProvider(provider_config) + elif isinstance(provider_config, GitHubTokenAuthConfig): + return GitHubTokenAuthProvider(provider_config) else: - supported_providers = ", ".join([t.value for t in AuthProviderType]) - raise ValueError(f"Unsupported auth provider type: {provider_type}. 
Supported types are: {supported_providers}") + raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}") diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index c19354794..a7e860a36 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -33,7 +33,11 @@ from pydantic import BaseModel, ValidationError from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.distribution.access_control.access_control import AccessDeniedError -from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig +from llama_stack.distribution.datatypes import ( + AuthenticationRequiredError, + LoggingConfig, + StackRunConfig, +) from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context from llama_stack.distribution.resolver import InvalidProviderError @@ -217,7 +221,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: # Get auth attributes from the request scope user_attributes = request.scope.get("user_attributes", {}) principal = request.scope.get("principal", "") - user = User(principal, user_attributes) + user = User(principal=principal, attributes=user_attributes) await log_request_pre_validation(request) @@ -455,7 +459,7 @@ def main(args: argparse.Namespace | None = None): # Add authentication middleware if configured if config.server.auth: - logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}") + logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}") app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth) else: if config.server.quota: diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 
4410048c5..39d6af1c8 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -11,10 +11,16 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from llama_stack.distribution.datatypes import AuthenticationConfig +from llama_stack.distribution.datatypes import ( + AuthenticationConfig, + AuthProviderType, + CustomAuthConfig, + OAuth2IntrospectionConfig, + OAuth2JWKSConfig, + OAuth2TokenAuthConfig, +) from llama_stack.distribution.server.auth import AuthenticationMiddleware from llama_stack.distribution.server.auth_providers import ( - AuthProviderType, get_attributes_from_claims, ) @@ -61,24 +67,11 @@ def invalid_token(): def http_app(mock_auth_endpoint): app = FastAPI() auth_config = AuthenticationConfig( - provider_type=AuthProviderType.CUSTOM, - config={"endpoint": mock_auth_endpoint}, - ) - app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) - - @app.get("/test") - def test_endpoint(): - return {"message": "Authentication successful"} - - return app - - -@pytest.fixture -def k8s_app(): - app = FastAPI() - auth_config = AuthenticationConfig( - provider_type=AuthProviderType.KUBERNETES, - config={"api_server_url": "https://kubernetes.default.svc"}, + provider_config=CustomAuthConfig( + type=AuthProviderType.CUSTOM, + endpoint=mock_auth_endpoint, + ), + access_policy=[], ) app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) @@ -94,11 +87,6 @@ def http_client(http_app): return TestClient(http_app) -@pytest.fixture -def k8s_client(k8s_app): - return TestClient(k8s_app) - - @pytest.fixture def mock_scope(): return { @@ -117,18 +105,11 @@ def mock_scope(): def mock_http_middleware(mock_auth_endpoint): mock_app = AsyncMock() auth_config = AuthenticationConfig( - provider_type=AuthProviderType.CUSTOM, - config={"endpoint": mock_auth_endpoint}, - ) - return AuthenticationMiddleware(mock_app, auth_config), mock_app - - -@pytest.fixture -def mock_k8s_middleware(): - mock_app = 
AsyncMock() - auth_config = AuthenticationConfig( - provider_type=AuthProviderType.KUBERNETES, - config={"api_server_url": "https://kubernetes.default.svc"}, + provider_config=CustomAuthConfig( + type=AuthProviderType.CUSTOM, + endpoint=mock_auth_endpoint, + ), + access_policy=[], ) return AuthenticationMiddleware(mock_app, auth_config), mock_app @@ -161,13 +142,14 @@ async def mock_post_exception(*args, **kwargs): def test_missing_auth_header(http_client): response = http_client.get("/test") assert response.status_code == 401 - assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + assert "Authentication required" in response.json()["error"]["message"] + assert "validated by mock-auth-service" in response.json()["error"]["message"] def test_invalid_auth_header_format(http_client): response = http_client.get("/test", headers={"Authorization": "InvalidFormat token123"}) assert response.status_code == 401 - assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + assert "Invalid Authorization header format" in response.json()["error"]["message"] @patch("httpx.AsyncClient.post", new=mock_post_success) @@ -262,14 +244,14 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock def oauth2_app(): app = FastAPI() auth_config = AuthenticationConfig( - provider_type=AuthProviderType.OAUTH2_TOKEN, - config={ - "jwks": { - "uri": "http://mock-authz-service/token/introspect", - "key_recheck_period": "3600", - }, - "audience": "llama-stack", - }, + provider_config=OAuth2TokenAuthConfig( + type=AuthProviderType.OAUTH2_TOKEN, + jwks=OAuth2JWKSConfig( + uri="http://mock-authz-service/token/introspect", + ), + audience="llama-stack", + ), + access_policy=[], ) app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) @@ -288,13 +270,14 @@ def oauth2_client(oauth2_app): def test_missing_auth_header_oauth2(oauth2_client): response = oauth2_client.get("/test") assert 
response.status_code == 401 - assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + assert "Authentication required" in response.json()["error"]["message"] + assert "OAuth2 Bearer token" in response.json()["error"]["message"] def test_invalid_auth_header_format_oauth2(oauth2_client): response = oauth2_client.get("/test", headers={"Authorization": "InvalidFormat token123"}) assert response.status_code == 401 - assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + assert "Invalid Authorization header format" in response.json()["error"]["message"] async def mock_jwks_response(*args, **kwargs): @@ -358,15 +341,16 @@ async def mock_auth_jwks_response(*args, **kwargs): def oauth2_app_with_jwks_token(): app = FastAPI() auth_config = AuthenticationConfig( - provider_type=AuthProviderType.OAUTH2_TOKEN, - config={ - "jwks": { - "uri": "http://mock-authz-service/token/introspect", - "key_recheck_period": "3600", - "token": "my-jwks-token", - }, - "audience": "llama-stack", - }, + provider_config=OAuth2TokenAuthConfig( + type=AuthProviderType.OAUTH2_TOKEN, + jwks=OAuth2JWKSConfig( + uri="http://mock-authz-service/token/introspect", + key_recheck_period=3600, + token="my-jwks-token", + ), + audience="llama-stack", + ), + access_policy=[], ) app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) @@ -449,11 +433,15 @@ def mock_introspection_endpoint(): def introspection_app(mock_introspection_endpoint): app = FastAPI() auth_config = AuthenticationConfig( - provider_type=AuthProviderType.OAUTH2_TOKEN, - config={ - "jwks": None, - "introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"}, - }, + provider_config=OAuth2TokenAuthConfig( + type=AuthProviderType.OAUTH2_TOKEN, + introspection=OAuth2IntrospectionConfig( + url=mock_introspection_endpoint, + client_id="myclient", + client_secret="abcdefg", + ), + ), + access_policy=[], ) 
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) @@ -468,22 +456,22 @@ def introspection_app(mock_introspection_endpoint): def introspection_app_with_custom_mapping(mock_introspection_endpoint): app = FastAPI() auth_config = AuthenticationConfig( - provider_type=AuthProviderType.OAUTH2_TOKEN, - config={ - "jwks": None, - "introspection": { - "url": mock_introspection_endpoint, - "client_id": "myclient", - "client_secret": "abcdefg", - "send_secret_in_body": "true", - }, - "claims_mapping": { + provider_config=OAuth2TokenAuthConfig( + type=AuthProviderType.OAUTH2_TOKEN, + introspection=OAuth2IntrospectionConfig( + url=mock_introspection_endpoint, + client_id="myclient", + client_secret="abcdefg", + send_secret_in_body=True, + ), + claims_mapping={ "sub": "roles", "scope": "roles", "groups": "teams", "aud": "namespaces", }, - }, + ), + access_policy=[], ) app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) @@ -507,13 +495,14 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi def test_missing_auth_header_introspection(introspection_client): response = introspection_client.get("/test") assert response.status_code == 401 - assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + assert "Authentication required" in response.json()["error"]["message"] + assert "OAuth2 Bearer token" in response.json()["error"]["message"] def test_invalid_auth_header_format_introspection(introspection_client): response = introspection_client.get("/test", headers={"Authorization": "InvalidFormat token123"}) assert response.status_code == 401 - assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + assert "Invalid Authorization header format" in response.json()["error"]["message"] async def mock_introspection_active(*args, **kwargs): diff --git a/tests/unit/server/test_auth_github.py b/tests/unit/server/test_auth_github.py new file mode 100644 index 
000000000..24e60f60f --- /dev/null +++ b/tests/unit/server/test_auth_github.py @@ -0,0 +1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, patch + +import httpx +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, GitHubTokenAuthConfig +from llama_stack.distribution.server.auth import AuthenticationMiddleware + + +class MockResponse: + def __init__(self, status_code, json_data): + self.status_code = status_code + self._json_data = json_data + + def json(self): + return self._json_data + + def raise_for_status(self): + if self.status_code != 200: + # Create a mock request for the HTTPStatusError + mock_request = httpx.Request("GET", "https://api.github.com/user") + raise httpx.HTTPStatusError(f"HTTP error: {self.status_code}", request=mock_request, response=self) + + +@pytest.fixture +def github_token_app(): + app = FastAPI() + + # Configure GitHub token auth + auth_config = AuthenticationConfig( + provider_config=GitHubTokenAuthConfig( + type=AuthProviderType.GITHUB_TOKEN, + github_api_base_url="https://api.github.com", + claims_mapping={ + "login": "username", + "id": "user_id", + "organizations": "teams", + }, + ), + access_policy=[], + ) + + # Add auth middleware + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + + @app.get("/test") + def test_endpoint(): + return {"message": "Authentication successful"} + + return app + + +@pytest.fixture +def github_token_client(github_token_app): + return TestClient(github_token_app) + + +def test_authenticated_endpoint_without_token(github_token_client): + """Test accessing protected endpoint without token""" + response = github_token_client.get("/test") + assert 
response.status_code == 401 + assert "Authentication required" in response.json()["error"]["message"] + assert "GitHub access token" in response.json()["error"]["message"] + + +def test_authenticated_endpoint_with_invalid_bearer_format(github_token_client): + """Test accessing protected endpoint with invalid bearer format""" + response = github_token_client.get("/test", headers={"Authorization": "InvalidFormat token123"}) + assert response.status_code == 401 + assert "Invalid Authorization header format" in response.json()["error"]["message"] + + +@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient") +def test_authenticated_endpoint_with_valid_github_token(mock_client_class, github_token_client): + """Test accessing protected endpoint with valid GitHub token""" + # Mock the GitHub API responses + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + # Mock successful user API response + mock_client.get.side_effect = [ + MockResponse( + 200, + { + "login": "testuser", + "id": 12345, + "email": "test@example.com", + "name": "Test User", + }, + ), + MockResponse( + 200, + [ + {"login": "test-org-1"}, + {"login": "test-org-2"}, + ], + ), + ] + + response = github_token_client.get("/test", headers={"Authorization": "Bearer github_token_123"}) + assert response.status_code == 200 + assert response.json()["message"] == "Authentication successful" + + # Verify the GitHub API was called correctly + assert mock_client.get.call_count == 1 + calls = mock_client.get.call_args_list + assert calls[0][0][0] == "https://api.github.com/user" + + # Check authorization header was passed + assert calls[0][1]["headers"]["Authorization"] == "Bearer github_token_123" + + +@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient") +def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client): + """Test accessing protected endpoint with invalid GitHub token""" + # Mock 
the GitHub API to return 401 Unauthorized + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + # Mock failed user API response + mock_client.get.return_value = MockResponse(401, {"message": "Bad credentials"}) + + response = github_token_client.get("/test", headers={"Authorization": "Bearer invalid_token"}) + assert response.status_code == 401 + assert ( + "GitHub token validation failed. Please check your token and try again." in response.json()["error"]["message"] + ) + + +@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient") +def test_github_enterprise_support(mock_client_class): + """Test GitHub Enterprise support with custom API base URL""" + app = FastAPI() + + # Configure GitHub token auth with enterprise URL + auth_config = AuthenticationConfig( + provider_config=GitHubTokenAuthConfig( + type=AuthProviderType.GITHUB_TOKEN, + github_api_base_url="https://github.enterprise.com/api/v3", + ), + access_policy=[], + ) + + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + + @app.get("/test") + def test_endpoint(): + return {"message": "Authentication successful"} + + client = TestClient(app) + + # Mock the GitHub Enterprise API responses + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + # Mock successful user API response + mock_client.get.side_effect = [ + MockResponse( + 200, + { + "login": "enterprise_user", + "id": 99999, + "email": "user@enterprise.com", + }, + ), + MockResponse( + 200, + [ + {"login": "enterprise-org"}, + ], + ), + ] + + response = client.get("/test", headers={"Authorization": "Bearer enterprise_token"}) + assert response.status_code == 200 + + # Verify the correct GitHub Enterprise URLs were called + assert mock_client.get.call_count == 1 + calls = mock_client.get.call_args_list + assert calls[0][0][0] == "https://github.enterprise.com/api/v3/user" + + +def 
test_github_token_auth_error_message_format(github_token_client): + """Test that the error message for missing auth is properly formatted""" + response = github_token_client.get("/test") + assert response.status_code == 401 + + error_data = response.json() + assert "error" in error_data + assert "message" in error_data["error"] + assert "Authentication required" in error_data["error"]["message"] + assert "https://docs.github.com" in error_data["error"]["message"] # Contains link to GitHub docs