feat(auth): support github tokens (#2509)

# What does this PR do?

This PR adds GitHub OAuth authentication support to Llama Stack,
allowing users to
  authenticate using their GitHub credentials (#2508) . 

1. support verifying github acesss tokens
2. support provider-specific auth error messages
3. opportunistic reorganized the auth configs for better ergonomics

## Test Plan
Added unit tests.

Also tested e2e manually:
```
server:
  port: 8321
  auth:
    provider_config:
      type: github_token
```
```
~/projects/llama-stack/llama_stack/ui
❯ curl -v http://localhost:8321/v1/models
* Host localhost:8321 was resolved.
* IPv6: ::1
* IPv4: 127.0.0.1
*   Trying [::1]:8321...
* Connected to localhost (::1) port 8321
> GET /v1/models HTTP/1.1
> Host: localhost:8321
> User-Agent: curl/8.7.1
> Accept: */*
>
* Request completely sent off
< HTTP/1.1 401 Unauthorized
< date: Fri, 27 Jun 2025 21:51:25 GMT
< server: uvicorn
< content-type: application/json
< x-trace-id: 5390c6c0654086c55d87c86d7cbf2f6a
< Transfer-Encoding: chunked
<
* Connection #0 to host localhost left intact
{"error": {"message": "Authentication required. Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"}}
~/projects/llama-stack/llama_stack/ui
❯ ./scripts/unit-tests.sh


~/projects/llama-stack/llama_stack/ui
❯ curl "http://localhost:8321/v1/models" \
-H "Authorization: Bearer <token_obtained_from_github>" \

{"data":[{"identifier":"accounts/fireworks/models/llama-guard-3-11b-vision","provider_resource_id":"accounts/fireworks/models/llama-guard-3-11b-vision","provider_id":"fireworks","type":"model","metadata":{},"model_type":"llm"},{"identifier":"accounts/fireworks/models/llama-guard-3-8b","provider_resource_id":"accounts/fireworks/models/llama-guard-3-8b","provider_id":"fireworks","type":"model","metadata":{},"model_type":"llm"},{"identifier":"accounts/fireworks/models/llama-v3p1-405b-instruct","provider_resource_id":"accounts/f
```

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
ehhuang 2025-07-08 11:02:36 -07:00 committed by GitHub
parent 83c89265e0
commit c8bac888af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 513 additions and 173 deletions

View file

@ -11,10 +11,16 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from llama_stack.distribution.datatypes import AuthenticationConfig
from llama_stack.distribution.datatypes import (
AuthenticationConfig,
AuthProviderType,
CustomAuthConfig,
OAuth2IntrospectionConfig,
OAuth2JWKSConfig,
OAuth2TokenAuthConfig,
)
from llama_stack.distribution.server.auth import AuthenticationMiddleware
from llama_stack.distribution.server.auth_providers import (
AuthProviderType,
get_attributes_from_claims,
)
@ -61,24 +67,11 @@ def invalid_token():
def http_app(mock_auth_endpoint):
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.CUSTOM,
config={"endpoint": mock_auth_endpoint},
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@app.get("/test")
def test_endpoint():
return {"message": "Authentication successful"}
return app
@pytest.fixture
def k8s_app():
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.KUBERNETES,
config={"api_server_url": "https://kubernetes.default.svc"},
provider_config=CustomAuthConfig(
type=AuthProviderType.CUSTOM,
endpoint=mock_auth_endpoint,
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@ -94,11 +87,6 @@ def http_client(http_app):
return TestClient(http_app)
@pytest.fixture
def k8s_client(k8s_app):
return TestClient(k8s_app)
@pytest.fixture
def mock_scope():
return {
@ -117,18 +105,11 @@ def mock_scope():
def mock_http_middleware(mock_auth_endpoint):
mock_app = AsyncMock()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.CUSTOM,
config={"endpoint": mock_auth_endpoint},
)
return AuthenticationMiddleware(mock_app, auth_config), mock_app
@pytest.fixture
def mock_k8s_middleware():
mock_app = AsyncMock()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.KUBERNETES,
config={"api_server_url": "https://kubernetes.default.svc"},
provider_config=CustomAuthConfig(
type=AuthProviderType.CUSTOM,
endpoint=mock_auth_endpoint,
),
access_policy=[],
)
return AuthenticationMiddleware(mock_app, auth_config), mock_app
@ -161,13 +142,14 @@ async def mock_post_exception(*args, **kwargs):
def test_missing_auth_header(http_client):
response = http_client.get("/test")
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
assert "Authentication required" in response.json()["error"]["message"]
assert "validated by mock-auth-service" in response.json()["error"]["message"]
def test_invalid_auth_header_format(http_client):
response = http_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
assert "Invalid Authorization header format" in response.json()["error"]["message"]
@patch("httpx.AsyncClient.post", new=mock_post_success)
@ -262,14 +244,14 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
def oauth2_app():
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": {
"uri": "http://mock-authz-service/token/introspect",
"key_recheck_period": "3600",
},
"audience": "llama-stack",
},
provider_config=OAuth2TokenAuthConfig(
type=AuthProviderType.OAUTH2_TOKEN,
jwks=OAuth2JWKSConfig(
uri="http://mock-authz-service/token/introspect",
),
audience="llama-stack",
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@ -288,13 +270,14 @@ def oauth2_client(oauth2_app):
def test_missing_auth_header_oauth2(oauth2_client):
response = oauth2_client.get("/test")
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
assert "Authentication required" in response.json()["error"]["message"]
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
def test_invalid_auth_header_format_oauth2(oauth2_client):
response = oauth2_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
assert "Invalid Authorization header format" in response.json()["error"]["message"]
async def mock_jwks_response(*args, **kwargs):
@ -358,15 +341,16 @@ async def mock_auth_jwks_response(*args, **kwargs):
def oauth2_app_with_jwks_token():
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": {
"uri": "http://mock-authz-service/token/introspect",
"key_recheck_period": "3600",
"token": "my-jwks-token",
},
"audience": "llama-stack",
},
provider_config=OAuth2TokenAuthConfig(
type=AuthProviderType.OAUTH2_TOKEN,
jwks=OAuth2JWKSConfig(
uri="http://mock-authz-service/token/introspect",
key_recheck_period=3600,
token="my-jwks-token",
),
audience="llama-stack",
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@ -449,11 +433,15 @@ def mock_introspection_endpoint():
def introspection_app(mock_introspection_endpoint):
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": None,
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
},
provider_config=OAuth2TokenAuthConfig(
type=AuthProviderType.OAUTH2_TOKEN,
introspection=OAuth2IntrospectionConfig(
url=mock_introspection_endpoint,
client_id="myclient",
client_secret="abcdefg",
),
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@ -468,22 +456,22 @@ def introspection_app(mock_introspection_endpoint):
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": None,
"introspection": {
"url": mock_introspection_endpoint,
"client_id": "myclient",
"client_secret": "abcdefg",
"send_secret_in_body": "true",
},
"claims_mapping": {
provider_config=OAuth2TokenAuthConfig(
type=AuthProviderType.OAUTH2_TOKEN,
introspection=OAuth2IntrospectionConfig(
url=mock_introspection_endpoint,
client_id="myclient",
client_secret="abcdefg",
send_secret_in_body=True,
),
claims_mapping={
"sub": "roles",
"scope": "roles",
"groups": "teams",
"aud": "namespaces",
},
},
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@ -507,13 +495,14 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi
def test_missing_auth_header_introspection(introspection_client):
response = introspection_client.get("/test")
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
assert "Authentication required" in response.json()["error"]["message"]
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
def test_invalid_auth_header_format_introspection(introspection_client):
response = introspection_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
assert "Invalid Authorization header format" in response.json()["error"]["message"]
async def mock_introspection_active(*args, **kwargs):

View file

@ -0,0 +1,200 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, GitHubTokenAuthConfig
from llama_stack.distribution.server.auth import AuthenticationMiddleware
class MockResponse:
def __init__(self, status_code, json_data):
self.status_code = status_code
self._json_data = json_data
def json(self):
return self._json_data
def raise_for_status(self):
if self.status_code != 200:
# Create a mock request for the HTTPStatusError
mock_request = httpx.Request("GET", "https://api.github.com/user")
raise httpx.HTTPStatusError(f"HTTP error: {self.status_code}", request=mock_request, response=self)
@pytest.fixture
def github_token_app():
app = FastAPI()
# Configure GitHub token auth
auth_config = AuthenticationConfig(
provider_config=GitHubTokenAuthConfig(
type=AuthProviderType.GITHUB_TOKEN,
github_api_base_url="https://api.github.com",
claims_mapping={
"login": "username",
"id": "user_id",
"organizations": "teams",
},
),
access_policy=[],
)
# Add auth middleware
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@app.get("/test")
def test_endpoint():
return {"message": "Authentication successful"}
return app
@pytest.fixture
def github_token_client(github_token_app):
return TestClient(github_token_app)
def test_authenticated_endpoint_without_token(github_token_client):
"""Test accessing protected endpoint without token"""
response = github_token_client.get("/test")
assert response.status_code == 401
assert "Authentication required" in response.json()["error"]["message"]
assert "GitHub access token" in response.json()["error"]["message"]
def test_authenticated_endpoint_with_invalid_bearer_format(github_token_client):
"""Test accessing protected endpoint with invalid bearer format"""
response = github_token_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
assert response.status_code == 401
assert "Invalid Authorization header format" in response.json()["error"]["message"]
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
def test_authenticated_endpoint_with_valid_github_token(mock_client_class, github_token_client):
"""Test accessing protected endpoint with valid GitHub token"""
# Mock the GitHub API responses
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
# Mock successful user API response
mock_client.get.side_effect = [
MockResponse(
200,
{
"login": "testuser",
"id": 12345,
"email": "test@example.com",
"name": "Test User",
},
),
MockResponse(
200,
[
{"login": "test-org-1"},
{"login": "test-org-2"},
],
),
]
response = github_token_client.get("/test", headers={"Authorization": "Bearer github_token_123"})
assert response.status_code == 200
assert response.json()["message"] == "Authentication successful"
# Verify the GitHub API was called correctly
assert mock_client.get.call_count == 1
calls = mock_client.get.call_args_list
assert calls[0][0][0] == "https://api.github.com/user"
# Check authorization header was passed
assert calls[0][1]["headers"]["Authorization"] == "Bearer github_token_123"
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client):
"""Test accessing protected endpoint with invalid GitHub token"""
# Mock the GitHub API to return 401 Unauthorized
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
# Mock failed user API response
mock_client.get.return_value = MockResponse(401, {"message": "Bad credentials"})
response = github_token_client.get("/test", headers={"Authorization": "Bearer invalid_token"})
assert response.status_code == 401
assert (
"GitHub token validation failed. Please check your token and try again." in response.json()["error"]["message"]
)
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
def test_github_enterprise_support(mock_client_class):
"""Test GitHub Enterprise support with custom API base URL"""
app = FastAPI()
# Configure GitHub token auth with enterprise URL
auth_config = AuthenticationConfig(
provider_config=GitHubTokenAuthConfig(
type=AuthProviderType.GITHUB_TOKEN,
github_api_base_url="https://github.enterprise.com/api/v3",
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@app.get("/test")
def test_endpoint():
return {"message": "Authentication successful"}
client = TestClient(app)
# Mock the GitHub Enterprise API responses
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
# Mock successful user API response
mock_client.get.side_effect = [
MockResponse(
200,
{
"login": "enterprise_user",
"id": 99999,
"email": "user@enterprise.com",
},
),
MockResponse(
200,
[
{"login": "enterprise-org"},
],
),
]
response = client.get("/test", headers={"Authorization": "Bearer enterprise_token"})
assert response.status_code == 200
# Verify the correct GitHub Enterprise URLs were called
assert mock_client.get.call_count == 1
calls = mock_client.get.call_args_list
assert calls[0][0][0] == "https://github.enterprise.com/api/v3/user"
def test_github_token_auth_error_message_format(github_token_client):
"""Test that the error message for missing auth is properly formatted"""
response = github_token_client.get("/test")
assert response.status_code == 401
error_data = response.json()
assert "error" in error_data
assert "message" in error_data["error"]
assert "Authentication required" in error_data["error"]["message"]
assert "https://docs.github.com" in error_data["error"]["message"] # Contains link to GitHub docs