mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-25 14:28:04 +00:00
gh auth
- Change auth config from provider_type + config dict to discriminated union types - Add GitHub token authentication provider - Improve auth error messages with provider-specific guidance - Extract auth datatypes to separate module - Update tests to use new auth config structure - Remove unused OAuth2LocalJWTConfig ## Test Plan - Unit tests pass for all auth providers - Integration tests verify auth flow works correctly - GitHub token auth tested with valid/invalid tokens Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
3c43a2f529
commit
c2d16c713e
7 changed files with 480 additions and 161 deletions
|
|
@ -11,10 +11,16 @@ import pytest
|
|||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationConfig,
|
||||
AuthProviderType,
|
||||
CustomAuthConfig,
|
||||
OAuth2IntrospectionConfig,
|
||||
OAuth2JWKSConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
)
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
from llama_stack.distribution.server.auth_providers import (
|
||||
AuthProviderType,
|
||||
get_attributes_from_claims,
|
||||
)
|
||||
|
||||
|
|
@ -61,24 +67,11 @@ def invalid_token():
|
|||
def http_app(mock_auth_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def k8s_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
provider_config=CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
endpoint=mock_auth_endpoint,
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
|
@ -94,11 +87,6 @@ def http_client(http_app):
|
|||
return TestClient(http_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def k8s_client(k8s_app):
|
||||
return TestClient(k8s_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_scope():
|
||||
return {
|
||||
|
|
@ -117,18 +105,11 @@ def mock_scope():
|
|||
def mock_http_middleware(mock_auth_endpoint):
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_k8s_middleware():
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
provider_config=CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
endpoint=mock_auth_endpoint,
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
||||
|
|
@ -161,7 +142,8 @@ async def mock_post_exception(*args, **kwargs):
|
|||
def test_missing_auth_header(http_client):
|
||||
response = http_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "validated by mock-auth-service" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format(http_client):
|
||||
|
|
@ -262,14 +244,14 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
|||
def oauth2_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
"key_recheck_period": "3600",
|
||||
},
|
||||
"audience": "llama-stack",
|
||||
},
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
jwks=OAuth2JWKSConfig(
|
||||
uri="http://mock-authz-service/token/introspect",
|
||||
),
|
||||
audience="llama-stack",
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
|
@ -288,7 +270,8 @@ def oauth2_client(oauth2_app):
|
|||
def test_missing_auth_header_oauth2(oauth2_client):
|
||||
response = oauth2_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_oauth2(oauth2_client):
|
||||
|
|
@ -358,15 +341,16 @@ async def mock_auth_jwks_response(*args, **kwargs):
|
|||
def oauth2_app_with_jwks_token():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
"key_recheck_period": "3600",
|
||||
"token": "my-jwks-token",
|
||||
},
|
||||
"audience": "llama-stack",
|
||||
},
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
jwks=OAuth2JWKSConfig(
|
||||
uri="http://mock-authz-service/token/introspect",
|
||||
key_recheck_period=3600,
|
||||
token="my-jwks-token",
|
||||
),
|
||||
audience="llama-stack",
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
|
@ -449,11 +433,15 @@ def mock_introspection_endpoint():
|
|||
def introspection_app(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
|
||||
},
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
introspection=OAuth2IntrospectionConfig(
|
||||
url=mock_introspection_endpoint,
|
||||
client_id="myclient",
|
||||
client_secret="abcdefg",
|
||||
),
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
|
@ -468,22 +456,22 @@ def introspection_app(mock_introspection_endpoint):
|
|||
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {
|
||||
"url": mock_introspection_endpoint,
|
||||
"client_id": "myclient",
|
||||
"client_secret": "abcdefg",
|
||||
"send_secret_in_body": "true",
|
||||
},
|
||||
"claims_mapping": {
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
introspection=OAuth2IntrospectionConfig(
|
||||
url=mock_introspection_endpoint,
|
||||
client_id="myclient",
|
||||
client_secret="abcdefg",
|
||||
send_secret_in_body=True,
|
||||
),
|
||||
claims_mapping={
|
||||
"sub": "roles",
|
||||
"scope": "roles",
|
||||
"groups": "teams",
|
||||
"aud": "namespaces",
|
||||
},
|
||||
},
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
|
@ -507,7 +495,8 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi
|
|||
def test_missing_auth_header_introspection(introspection_client):
|
||||
response = introspection_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_introspection(introspection_client):
|
||||
|
|
|
|||
195
tests/unit/server/test_auth_github.py
Normal file
195
tests/unit/server/test_auth_github.py
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, GitHubTokenAuthConfig
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status_code, json_data):
|
||||
self.status_code = status_code
|
||||
self._json_data = json_data
|
||||
|
||||
def json(self):
|
||||
return self._json_data
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code != 200:
|
||||
raise Exception(f"HTTP error: {self.status_code}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_token_app():
|
||||
app = FastAPI()
|
||||
|
||||
# Configure GitHub token auth
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_config=GitHubTokenAuthConfig(
|
||||
type=AuthProviderType.GITHUB_TOKEN,
|
||||
github_api_base_url="https://api.github.com",
|
||||
claims_mapping={
|
||||
"login": "username",
|
||||
"id": "user_id",
|
||||
"organizations": "teams",
|
||||
},
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
|
||||
# Add auth middleware
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_token_client(github_token_app):
|
||||
return TestClient(github_token_app)
|
||||
|
||||
|
||||
def test_authenticated_endpoint_without_token(github_token_client):
|
||||
"""Test accessing protected endpoint without token"""
|
||||
response = github_token_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "GitHub access token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_authenticated_endpoint_with_invalid_bearer_format(github_token_client):
|
||||
"""Test accessing protected endpoint with invalid bearer format"""
|
||||
response = github_token_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
|
||||
def test_authenticated_endpoint_with_valid_github_token(mock_client_class, github_token_client):
|
||||
"""Test accessing protected endpoint with valid GitHub token"""
|
||||
# Mock the GitHub API responses
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
# Mock successful user API response
|
||||
mock_client.get.side_effect = [
|
||||
MockResponse(
|
||||
200,
|
||||
{
|
||||
"login": "testuser",
|
||||
"id": 12345,
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
),
|
||||
MockResponse(
|
||||
200,
|
||||
[
|
||||
{"login": "test-org-1"},
|
||||
{"login": "test-org-2"},
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
response = github_token_client.get("/test", headers={"Authorization": "Bearer github_token_123"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "Authentication successful"
|
||||
|
||||
# Verify the GitHub API was called correctly
|
||||
assert mock_client.get.call_count == 1
|
||||
calls = mock_client.get.call_args_list
|
||||
assert calls[0][0][0] == "https://api.github.com/user"
|
||||
|
||||
# Check authorization header was passed
|
||||
assert calls[0][1]["headers"]["Authorization"] == "Bearer github_token_123"
|
||||
|
||||
|
||||
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
|
||||
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client):
|
||||
"""Test accessing protected endpoint with invalid GitHub token"""
|
||||
# Mock the GitHub API to return 401 Unauthorized
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
# Mock failed user API response
|
||||
mock_client.get.return_value = MockResponse(401, {"message": "Bad credentials"})
|
||||
|
||||
response = github_token_client.get("/test", headers={"Authorization": "Bearer invalid_token"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid GitHub token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
|
||||
def test_github_enterprise_support(mock_client_class):
|
||||
"""Test GitHub Enterprise support with custom API base URL"""
|
||||
app = FastAPI()
|
||||
|
||||
# Configure GitHub token auth with enterprise URL
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_config=GitHubTokenAuthConfig(
|
||||
type=AuthProviderType.GITHUB_TOKEN,
|
||||
github_api_base_url="https://github.enterprise.com/api/v3",
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Mock the GitHub Enterprise API responses
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
# Mock successful user API response
|
||||
mock_client.get.side_effect = [
|
||||
MockResponse(
|
||||
200,
|
||||
{
|
||||
"login": "enterprise_user",
|
||||
"id": 99999,
|
||||
"email": "user@enterprise.com",
|
||||
},
|
||||
),
|
||||
MockResponse(
|
||||
200,
|
||||
[
|
||||
{"login": "enterprise-org"},
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
response = client.get("/test", headers={"Authorization": "Bearer enterprise_token"})
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify the correct GitHub Enterprise URLs were called
|
||||
assert mock_client.get.call_count == 1
|
||||
calls = mock_client.get.call_args_list
|
||||
assert calls[0][0][0] == "https://github.enterprise.com/api/v3/user"
|
||||
|
||||
|
||||
def test_github_token_auth_error_message_format(github_token_client):
|
||||
"""Test that the error message for missing auth is properly formatted"""
|
||||
response = github_token_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
|
||||
error_data = response.json()
|
||||
assert "error" in error_data
|
||||
assert "message" in error_data["error"]
|
||||
assert "Authentication required" in error_data["error"]["message"]
|
||||
assert "https://docs.github.com" in error_data["error"]["message"] # Contains link to GitHub docs
|
||||
Loading…
Add table
Add a link
Reference in a new issue