mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-13 16:46:09 +00:00
feat(auth): support github tokens (#2509)
# What does this PR do? This PR adds GitHub OAuth authentication support to Llama Stack, allowing users to authenticate using their GitHub credentials (#2508) . 1. support verifying github acesss tokens 2. support provider-specific auth error messages 3. opportunistic reorganized the auth configs for better ergonomics ## Test Plan Added unit tests. Also tested e2e manually: ``` server: port: 8321 auth: provider_config: type: github_token ``` ``` ~/projects/llama-stack/llama_stack/ui ❯ curl -v http://localhost:8321/v1/models * Host localhost:8321 was resolved. * IPv6: ::1 * IPv4: 127.0.0.1 * Trying [::1]:8321... * Connected to localhost (::1) port 8321 > GET /v1/models HTTP/1.1 > Host: localhost:8321 > User-Agent: curl/8.7.1 > Accept: */* > * Request completely sent off < HTTP/1.1 401 Unauthorized < date: Fri, 27 Jun 2025 21:51:25 GMT < server: uvicorn < content-type: application/json < x-trace-id: 5390c6c0654086c55d87c86d7cbf2f6a < Transfer-Encoding: chunked < * Connection #0 to host localhost left intact {"error": {"message": "Authentication required. Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"}} ~/projects/llama-stack/llama_stack/ui ❯ ./scripts/unit-tests.sh ~/projects/llama-stack/llama_stack/ui ❯ curl "http://localhost:8321/v1/models" \ -H "Authorization: Bearer <token_obtained_from_github>" \ {"data":[{"identifier":"accounts/fireworks/models/llama-guard-3-11b-vision","provider_resource_id":"accounts/fireworks/models/llama-guard-3-11b-vision","provider_id":"fireworks","type":"model","metadata":{},"model_type":"llm"},{"identifier":"accounts/fireworks/models/llama-guard-3-8b","provider_resource_id":"accounts/fireworks/models/llama-guard-3-8b","provider_id":"fireworks","type":"model","metadata":{},"model_type":"llm"},{"identifier":"accounts/fireworks/models/llama-v3p1-405b-instruct","provider_resource_id":"accounts/f ``` --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
83c89265e0
commit
c8bac888af
8 changed files with 513 additions and 173 deletions
9
.github/workflows/integration-auth-tests.yml
vendored
9
.github/workflows/integration-auth-tests.yml
vendored
|
@ -73,9 +73,12 @@ jobs:
|
|||
server:
|
||||
port: 8321
|
||||
EOF
|
||||
yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.provider_config.type = "${{ matrix.auth-provider }}"' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.provider_config.tls_cafile = "${{ env.KUBERNETES_CA_CERT_PATH }}"' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.provider_config.issuer = "${{ env.KUBERNETES_ISSUER }}"' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.provider_config.audience = "${{ env.KUBERNETES_AUDIENCE }}"' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.provider_config.jwks.uri = "${{ env.KUBERNETES_API_SERVER_URL }}"' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.provider_config.jwks.token = "${{ env.TOKEN }}"' -i $run_dir/run.yaml
|
||||
cat $run_dir/run.yaml
|
||||
|
||||
nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
|
||||
|
|
|
@ -56,8 +56,8 @@ shields: []
|
|||
server:
|
||||
port: 8321
|
||||
auth:
|
||||
provider_type: "oauth2_token"
|
||||
config:
|
||||
provider_config:
|
||||
type: "oauth2_token"
|
||||
jwks:
|
||||
uri: "https://my-token-issuing-svc.com/jwks"
|
||||
```
|
||||
|
@ -226,6 +226,8 @@ server:
|
|||
|
||||
### Authentication Configuration
|
||||
|
||||
> **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.
|
||||
|
||||
The `auth` section configures authentication for the server. When configured, all API requests must include a valid Bearer token in the Authorization header:
|
||||
|
||||
```
|
||||
|
@ -240,8 +242,8 @@ The server can be configured to use service account tokens for authorization, va
|
|||
```yaml
|
||||
server:
|
||||
auth:
|
||||
provider_type: "oauth2_token"
|
||||
config:
|
||||
provider_config:
|
||||
type: "oauth2_token"
|
||||
jwks:
|
||||
uri: "https://kubernetes.default.svc:8443/openid/v1/jwks"
|
||||
token: "${env.TOKEN:+}"
|
||||
|
@ -325,13 +327,25 @@ You can easily validate a request by running:
|
|||
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
|
||||
```
|
||||
|
||||
#### GitHub Token Provider
|
||||
Validates GitHub personal access tokens or OAuth tokens directly:
|
||||
```yaml
|
||||
server:
|
||||
auth:
|
||||
provider_config:
|
||||
type: "github_token"
|
||||
github_api_base_url: "https://api.github.com" # Or GitHub Enterprise URL
|
||||
```
|
||||
|
||||
The provider fetches user information from GitHub and maps it to access attributes based on the `claims_mapping` configuration.
|
||||
|
||||
#### Custom Provider
|
||||
Validates tokens against a custom authentication endpoint:
|
||||
```yaml
|
||||
server:
|
||||
auth:
|
||||
provider_type: "custom"
|
||||
config:
|
||||
provider_config:
|
||||
type: "custom"
|
||||
endpoint: "https://auth.example.com/validate" # URL of the auth endpoint
|
||||
```
|
||||
|
||||
|
@ -416,8 +430,8 @@ clients.
|
|||
server:
|
||||
port: 8321
|
||||
auth:
|
||||
provider_type: custom
|
||||
config:
|
||||
provider_config:
|
||||
type: custom
|
||||
endpoint: https://auth.example.com/validate
|
||||
quota:
|
||||
kvstore:
|
||||
|
|
|
@ -6,9 +6,9 @@
|
|||
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
from typing import Annotated, Any, Literal, Self
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
|
@ -161,23 +161,113 @@ class LoggingConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class OAuth2JWKSConfig(BaseModel):
|
||||
# The JWKS URI for collecting public keys
|
||||
uri: str
|
||||
token: str | None = Field(default=None, description="token to authorise access to jwks")
|
||||
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
|
||||
|
||||
|
||||
class OAuth2IntrospectionConfig(BaseModel):
|
||||
url: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
send_secret_in_body: bool = False
|
||||
|
||||
|
||||
class AuthProviderType(StrEnum):
|
||||
"""Supported authentication provider types."""
|
||||
|
||||
OAUTH2_TOKEN = "oauth2_token"
|
||||
GITHUB_TOKEN = "github_token"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class OAuth2TokenAuthConfig(BaseModel):
|
||||
"""Configuration for OAuth2 token authentication."""
|
||||
|
||||
type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
|
||||
audience: str = Field(default="llama-stack")
|
||||
verify_tls: bool = Field(default=True)
|
||||
tls_cafile: Path | None = Field(default=None)
|
||||
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"sub": "roles",
|
||||
"username": "roles",
|
||||
"groups": "teams",
|
||||
"team": "teams",
|
||||
"project": "projects",
|
||||
"tenant": "namespaces",
|
||||
"namespace": "namespaces",
|
||||
},
|
||||
)
|
||||
jwks: OAuth2JWKSConfig | None = Field(default=None, description="JWKS configuration")
|
||||
introspection: OAuth2IntrospectionConfig | None = Field(
|
||||
default=None, description="OAuth2 introspection configuration"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@field_validator("claims_mapping")
|
||||
def validate_claims_mapping(cls, v):
|
||||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mode(self) -> Self:
|
||||
if not self.jwks and not self.introspection:
|
||||
raise ValueError("One of jwks or introspection must be configured")
|
||||
if self.jwks and self.introspection:
|
||||
raise ValueError("At present only one of jwks or introspection should be configured")
|
||||
return self
|
||||
|
||||
|
||||
class CustomAuthConfig(BaseModel):
|
||||
"""Configuration for custom authentication."""
|
||||
|
||||
type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
|
||||
endpoint: str = Field(
|
||||
...,
|
||||
description="Custom authentication endpoint URL",
|
||||
)
|
||||
|
||||
|
||||
class GitHubTokenAuthConfig(BaseModel):
|
||||
"""Configuration for GitHub token authentication."""
|
||||
|
||||
type: Literal[AuthProviderType.GITHUB_TOKEN] = AuthProviderType.GITHUB_TOKEN
|
||||
github_api_base_url: str = Field(
|
||||
default="https://api.github.com",
|
||||
description="Base URL for GitHub API (use https://api.github.com for public GitHub)",
|
||||
)
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"login": "roles",
|
||||
"organizations": "teams",
|
||||
},
|
||||
description="Mapping from GitHub user fields to access attributes",
|
||||
)
|
||||
|
||||
|
||||
AuthProviderConfig = Annotated[
|
||||
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class AuthenticationConfig(BaseModel):
|
||||
provider_type: AuthProviderType = Field(
|
||||
"""Top-level authentication configuration."""
|
||||
|
||||
provider_config: AuthProviderConfig = Field(
|
||||
...,
|
||||
description="Type of authentication provider",
|
||||
description="Authentication provider configuration",
|
||||
)
|
||||
config: dict[str, Any] = Field(
|
||||
...,
|
||||
description="Provider-specific configuration",
|
||||
access_policy: list[AccessRule] = Field(
|
||||
default=[],
|
||||
description="Rules for determining access to resources",
|
||||
)
|
||||
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
|
||||
|
||||
|
||||
class AuthenticationRequiredError(Exception):
|
||||
|
|
|
@ -87,8 +87,12 @@ class AuthenticationMiddleware:
|
|||
headers = dict(scope.get("headers", []))
|
||||
auth_header = headers.get(b"authorization", b"").decode()
|
||||
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return await self._send_auth_error(send, "Missing or invalid Authorization header")
|
||||
if not auth_header:
|
||||
error_msg = self.auth_provider.get_auth_error_message(scope)
|
||||
return await self._send_auth_error(send, error_msg)
|
||||
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return await self._send_auth_error(send, "Invalid Authorization header format")
|
||||
|
||||
token = auth_header.split("Bearer ", 1)[1]
|
||||
|
||||
|
|
|
@ -8,15 +8,19 @@ import ssl
|
|||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Lock
|
||||
from pathlib import Path
|
||||
from typing import Self
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import httpx
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationConfig,
|
||||
CustomAuthConfig,
|
||||
GitHubTokenAuthConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
User,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
@ -38,9 +42,7 @@ class AuthRequestContext(BaseModel):
|
|||
|
||||
headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
|
||||
params: dict[str, list[str]] = Field(
|
||||
description="Query parameters from the original request, parsed as dictionary of lists"
|
||||
)
|
||||
params: dict[str, list[str]] = Field(default_factory=dict, description="Query parameters from the original request")
|
||||
|
||||
|
||||
class AuthRequest(BaseModel):
|
||||
|
@ -62,6 +64,10 @@ class AuthProvider(ABC):
|
|||
"""Clean up any resources."""
|
||||
pass
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return provider-specific authentication error message."""
|
||||
return "Authentication required"
|
||||
|
||||
|
||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
|
||||
attributes: dict[str, list[str]] = {}
|
||||
|
@ -81,56 +87,6 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
|
|||
return attributes
|
||||
|
||||
|
||||
class OAuth2JWKSConfig(BaseModel):
|
||||
# The JWKS URI for collecting public keys
|
||||
uri: str
|
||||
token: str | None = Field(default=None, description="token to authorise access to jwks")
|
||||
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
|
||||
|
||||
|
||||
class OAuth2IntrospectionConfig(BaseModel):
|
||||
url: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
send_secret_in_body: bool = False
|
||||
|
||||
|
||||
class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||
audience: str = "llama-stack"
|
||||
verify_tls: bool = True
|
||||
tls_cafile: Path | None = None
|
||||
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"sub": "roles",
|
||||
"username": "roles",
|
||||
"groups": "teams",
|
||||
"team": "teams",
|
||||
"project": "projects",
|
||||
"tenant": "namespaces",
|
||||
"namespace": "namespaces",
|
||||
},
|
||||
)
|
||||
jwks: OAuth2JWKSConfig | None
|
||||
introspection: OAuth2IntrospectionConfig | None = None
|
||||
|
||||
@classmethod
|
||||
@field_validator("claims_mapping")
|
||||
def validate_claims_mapping(cls, v):
|
||||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mode(self) -> Self:
|
||||
if not self.jwks and not self.introspection:
|
||||
raise ValueError("One of jwks or introspection must be configured")
|
||||
if self.jwks and self.introspection:
|
||||
raise ValueError("At present only one of jwks or introspection should be configured")
|
||||
return self
|
||||
|
||||
|
||||
class OAuth2TokenAuthProvider(AuthProvider):
|
||||
"""
|
||||
JWT token authentication provider that validates a JWT token and extracts access attributes.
|
||||
|
@ -138,7 +94,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
This should be the standard authentication provider for most use cases.
|
||||
"""
|
||||
|
||||
def __init__(self, config: OAuth2TokenAuthProviderConfig):
|
||||
def __init__(self, config: OAuth2TokenAuthConfig):
|
||||
self.config = config
|
||||
self._jwks_at: float = 0.0
|
||||
self._jwks: dict[str, str] = {}
|
||||
|
@ -170,7 +126,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
issuer=self.config.issuer,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Invalid JWT token: {token}") from exc
|
||||
raise ValueError("Invalid JWT token") from exc
|
||||
|
||||
# There are other standard claims, the most relevant of which is `scope`.
|
||||
# We should incorporate these into the access attributes.
|
||||
|
@ -232,6 +188,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
async def close(self):
|
||||
pass
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return OAuth2-specific authentication error message."""
|
||||
if self.config.issuer:
|
||||
return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
|
||||
elif self.config.introspection:
|
||||
# Extract domain from introspection URL for a cleaner message
|
||||
domain = urlparse(self.config.introspection.url).netloc
|
||||
return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
|
||||
else:
|
||||
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
|
||||
|
||||
async def _refresh_jwks(self) -> None:
|
||||
"""
|
||||
Refresh the JWKS cache.
|
||||
|
@ -264,14 +231,10 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
self._jwks_at = time.time()
|
||||
|
||||
|
||||
class CustomAuthProviderConfig(BaseModel):
|
||||
endpoint: str
|
||||
|
||||
|
||||
class CustomAuthProvider(AuthProvider):
|
||||
"""Custom authentication provider that uses an external endpoint."""
|
||||
|
||||
def __init__(self, config: CustomAuthProviderConfig):
|
||||
def __init__(self, config: CustomAuthConfig):
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
|
@ -317,7 +280,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
try:
|
||||
response_data = response.json()
|
||||
auth_response = AuthResponse(**response_data)
|
||||
return User(auth_response.principal, auth_response.attributes)
|
||||
return User(principal=auth_response.principal, attributes=auth_response.attributes)
|
||||
except Exception as e:
|
||||
logger.exception("Error parsing authentication response")
|
||||
raise ValueError("Invalid authentication response format") from e
|
||||
|
@ -338,15 +301,88 @@ class CustomAuthProvider(AuthProvider):
|
|||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return custom auth provider-specific authentication error message."""
|
||||
domain = urlparse(self.config.endpoint).netloc
|
||||
if domain:
|
||||
return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
|
||||
else:
|
||||
return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
|
||||
|
||||
|
||||
class GitHubTokenAuthProvider(AuthProvider):
|
||||
"""
|
||||
GitHub token authentication provider that validates GitHub access tokens directly.
|
||||
|
||||
This provider accepts GitHub personal access tokens or OAuth tokens and verifies
|
||||
them against the GitHub API to get user information.
|
||||
"""
|
||||
|
||||
def __init__(self, config: GitHubTokenAuthConfig):
|
||||
self.config = config
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a GitHub token by calling the GitHub API.
|
||||
|
||||
This validates tokens issued by GitHub (personal access tokens or OAuth tokens).
|
||||
"""
|
||||
try:
|
||||
user_info = await _get_github_user_info(token, self.config.github_api_base_url)
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.warning(f"GitHub token validation failed: {e}")
|
||||
raise ValueError("GitHub token validation failed. Please check your token and try again.") from e
|
||||
|
||||
principal = user_info["user"]["login"]
|
||||
|
||||
github_data = {
|
||||
"login": user_info["user"]["login"],
|
||||
"id": str(user_info["user"]["id"]),
|
||||
"organizations": user_info.get("organizations", []),
|
||||
}
|
||||
|
||||
access_attributes = get_attributes_from_claims(github_data, self.config.claims_mapping)
|
||||
|
||||
return User(
|
||||
principal=principal,
|
||||
attributes=access_attributes,
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Clean up any resources."""
|
||||
pass
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return GitHub-specific authentication error message."""
|
||||
return "Authentication required. Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"
|
||||
|
||||
|
||||
async def _get_github_user_info(access_token: str, github_api_base_url: str) -> dict:
|
||||
"""Fetch user info and organizations from GitHub API."""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"User-Agent": "llama-stack",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
user_response = await client.get(f"{github_api_base_url}/user", headers=headers, timeout=10.0)
|
||||
user_response.raise_for_status()
|
||||
user_data = user_response.json()
|
||||
|
||||
return {
|
||||
"user": user_data,
|
||||
}
|
||||
|
||||
|
||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||
"""Factory function to create the appropriate auth provider."""
|
||||
provider_type = config.provider_type.lower()
|
||||
provider_config = config.provider_config
|
||||
|
||||
if provider_type == "custom":
|
||||
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
||||
elif provider_type == "oauth2_token":
|
||||
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
||||
if isinstance(provider_config, CustomAuthConfig):
|
||||
return CustomAuthProvider(provider_config)
|
||||
elif isinstance(provider_config, OAuth2TokenAuthConfig):
|
||||
return OAuth2TokenAuthProvider(provider_config)
|
||||
elif isinstance(provider_config, GitHubTokenAuthConfig):
|
||||
return GitHubTokenAuthProvider(provider_config)
|
||||
else:
|
||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
||||
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
|
||||
|
|
|
@ -33,7 +33,11 @@ from pydantic import BaseModel, ValidationError
|
|||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationRequiredError,
|
||||
LoggingConfig,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
|
@ -217,7 +221,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
principal = request.scope.get("principal", "")
|
||||
user = User(principal, user_attributes)
|
||||
user = User(principal=principal, attributes=user_attributes)
|
||||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
|
@ -455,7 +459,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
|
||||
# Add authentication middleware if configured
|
||||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
||||
else:
|
||||
if config.server.quota:
|
||||
|
|
|
@ -11,10 +11,16 @@ import pytest
|
|||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationConfig,
|
||||
AuthProviderType,
|
||||
CustomAuthConfig,
|
||||
OAuth2IntrospectionConfig,
|
||||
OAuth2JWKSConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
)
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
from llama_stack.distribution.server.auth_providers import (
|
||||
AuthProviderType,
|
||||
get_attributes_from_claims,
|
||||
)
|
||||
|
||||
|
@ -61,24 +67,11 @@ def invalid_token():
|
|||
def http_app(mock_auth_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def k8s_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
provider_config=CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
endpoint=mock_auth_endpoint,
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -94,11 +87,6 @@ def http_client(http_app):
|
|||
return TestClient(http_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def k8s_client(k8s_app):
|
||||
return TestClient(k8s_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_scope():
|
||||
return {
|
||||
|
@ -117,18 +105,11 @@ def mock_scope():
|
|||
def mock_http_middleware(mock_auth_endpoint):
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_k8s_middleware():
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
provider_config=CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
endpoint=mock_auth_endpoint,
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
||||
|
@ -161,13 +142,14 @@ async def mock_post_exception(*args, **kwargs):
|
|||
def test_missing_auth_header(http_client):
|
||||
response = http_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "validated by mock-auth-service" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format(http_client):
|
||||
response = http_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success)
|
||||
|
@ -262,14 +244,14 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
|||
def oauth2_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
"key_recheck_period": "3600",
|
||||
},
|
||||
"audience": "llama-stack",
|
||||
},
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
jwks=OAuth2JWKSConfig(
|
||||
uri="http://mock-authz-service/token/introspect",
|
||||
),
|
||||
audience="llama-stack",
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -288,13 +270,14 @@ def oauth2_client(oauth2_app):
|
|||
def test_missing_auth_header_oauth2(oauth2_client):
|
||||
response = oauth2_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_oauth2(oauth2_client):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_jwks_response(*args, **kwargs):
|
||||
|
@ -358,15 +341,16 @@ async def mock_auth_jwks_response(*args, **kwargs):
|
|||
def oauth2_app_with_jwks_token():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
"key_recheck_period": "3600",
|
||||
"token": "my-jwks-token",
|
||||
},
|
||||
"audience": "llama-stack",
|
||||
},
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
jwks=OAuth2JWKSConfig(
|
||||
uri="http://mock-authz-service/token/introspect",
|
||||
key_recheck_period=3600,
|
||||
token="my-jwks-token",
|
||||
),
|
||||
audience="llama-stack",
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -449,11 +433,15 @@ def mock_introspection_endpoint():
|
|||
def introspection_app(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
|
||||
},
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
introspection=OAuth2IntrospectionConfig(
|
||||
url=mock_introspection_endpoint,
|
||||
client_id="myclient",
|
||||
client_secret="abcdefg",
|
||||
),
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -468,22 +456,22 @@ def introspection_app(mock_introspection_endpoint):
|
|||
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {
|
||||
"url": mock_introspection_endpoint,
|
||||
"client_id": "myclient",
|
||||
"client_secret": "abcdefg",
|
||||
"send_secret_in_body": "true",
|
||||
},
|
||||
"claims_mapping": {
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
introspection=OAuth2IntrospectionConfig(
|
||||
url=mock_introspection_endpoint,
|
||||
client_id="myclient",
|
||||
client_secret="abcdefg",
|
||||
send_secret_in_body=True,
|
||||
),
|
||||
claims_mapping={
|
||||
"sub": "roles",
|
||||
"scope": "roles",
|
||||
"groups": "teams",
|
||||
"aud": "namespaces",
|
||||
},
|
||||
},
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -507,13 +495,14 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi
|
|||
def test_missing_auth_header_introspection(introspection_client):
|
||||
response = introspection_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_introspection(introspection_client):
|
||||
response = introspection_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_introspection_active(*args, **kwargs):
|
||||
|
|
200
tests/unit/server/test_auth_github.py
Normal file
200
tests/unit/server/test_auth_github.py
Normal file
|
@ -0,0 +1,200 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, GitHubTokenAuthConfig
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status_code, json_data):
|
||||
self.status_code = status_code
|
||||
self._json_data = json_data
|
||||
|
||||
def json(self):
|
||||
return self._json_data
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code != 200:
|
||||
# Create a mock request for the HTTPStatusError
|
||||
mock_request = httpx.Request("GET", "https://api.github.com/user")
|
||||
raise httpx.HTTPStatusError(f"HTTP error: {self.status_code}", request=mock_request, response=self)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_token_app():
|
||||
app = FastAPI()
|
||||
|
||||
# Configure GitHub token auth
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_config=GitHubTokenAuthConfig(
|
||||
type=AuthProviderType.GITHUB_TOKEN,
|
||||
github_api_base_url="https://api.github.com",
|
||||
claims_mapping={
|
||||
"login": "username",
|
||||
"id": "user_id",
|
||||
"organizations": "teams",
|
||||
},
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
|
||||
# Add auth middleware
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_token_client(github_token_app):
|
||||
return TestClient(github_token_app)
|
||||
|
||||
|
||||
def test_authenticated_endpoint_without_token(github_token_client):
|
||||
"""Test accessing protected endpoint without token"""
|
||||
response = github_token_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "GitHub access token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_authenticated_endpoint_with_invalid_bearer_format(github_token_client):
|
||||
"""Test accessing protected endpoint with invalid bearer format"""
|
||||
response = github_token_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
|
||||
def test_authenticated_endpoint_with_valid_github_token(mock_client_class, github_token_client):
|
||||
"""Test accessing protected endpoint with valid GitHub token"""
|
||||
# Mock the GitHub API responses
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
# Mock successful user API response
|
||||
mock_client.get.side_effect = [
|
||||
MockResponse(
|
||||
200,
|
||||
{
|
||||
"login": "testuser",
|
||||
"id": 12345,
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
),
|
||||
MockResponse(
|
||||
200,
|
||||
[
|
||||
{"login": "test-org-1"},
|
||||
{"login": "test-org-2"},
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
response = github_token_client.get("/test", headers={"Authorization": "Bearer github_token_123"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "Authentication successful"
|
||||
|
||||
# Verify the GitHub API was called correctly
|
||||
assert mock_client.get.call_count == 1
|
||||
calls = mock_client.get.call_args_list
|
||||
assert calls[0][0][0] == "https://api.github.com/user"
|
||||
|
||||
# Check authorization header was passed
|
||||
assert calls[0][1]["headers"]["Authorization"] == "Bearer github_token_123"
|
||||
|
||||
|
||||
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
|
||||
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client):
|
||||
"""Test accessing protected endpoint with invalid GitHub token"""
|
||||
# Mock the GitHub API to return 401 Unauthorized
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
# Mock failed user API response
|
||||
mock_client.get.return_value = MockResponse(401, {"message": "Bad credentials"})
|
||||
|
||||
response = github_token_client.get("/test", headers={"Authorization": "Bearer invalid_token"})
|
||||
assert response.status_code == 401
|
||||
assert (
|
||||
"GitHub token validation failed. Please check your token and try again." in response.json()["error"]["message"]
|
||||
)
|
||||
|
||||
|
||||
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
|
||||
def test_github_enterprise_support(mock_client_class):
|
||||
"""Test GitHub Enterprise support with custom API base URL"""
|
||||
app = FastAPI()
|
||||
|
||||
# Configure GitHub token auth with enterprise URL
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_config=GitHubTokenAuthConfig(
|
||||
type=AuthProviderType.GITHUB_TOKEN,
|
||||
github_api_base_url="https://github.enterprise.com/api/v3",
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Mock the GitHub Enterprise API responses
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
# Mock successful user API response
|
||||
mock_client.get.side_effect = [
|
||||
MockResponse(
|
||||
200,
|
||||
{
|
||||
"login": "enterprise_user",
|
||||
"id": 99999,
|
||||
"email": "user@enterprise.com",
|
||||
},
|
||||
),
|
||||
MockResponse(
|
||||
200,
|
||||
[
|
||||
{"login": "enterprise-org"},
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
response = client.get("/test", headers={"Authorization": "Bearer enterprise_token"})
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify the correct GitHub Enterprise URLs were called
|
||||
assert mock_client.get.call_count == 1
|
||||
calls = mock_client.get.call_args_list
|
||||
assert calls[0][0][0] == "https://github.enterprise.com/api/v3/user"
|
||||
|
||||
|
||||
def test_github_token_auth_error_message_format(github_token_client):
|
||||
"""Test that the error message for missing auth is properly formatted"""
|
||||
response = github_token_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
|
||||
error_data = response.json()
|
||||
assert "error" in error_data
|
||||
assert "message" in error_data["error"]
|
||||
assert "Authentication required" in error_data["error"]["message"]
|
||||
assert "https://docs.github.com" in error_data["error"]["message"] # Contains link to GitHub docs
|
Loading…
Add table
Add a link
Reference in a new issue