mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-15 01:26:10 +00:00
feat(auth): support github tokens (#2509)
# What does this PR do? This PR adds GitHub OAuth authentication support to Llama Stack, allowing users to authenticate using their GitHub credentials (#2508) . 1. support verifying github acesss tokens 2. support provider-specific auth error messages 3. opportunistic reorganized the auth configs for better ergonomics ## Test Plan Added unit tests. Also tested e2e manually: ``` server: port: 8321 auth: provider_config: type: github_token ``` ``` ~/projects/llama-stack/llama_stack/ui ❯ curl -v http://localhost:8321/v1/models * Host localhost:8321 was resolved. * IPv6: ::1 * IPv4: 127.0.0.1 * Trying [::1]:8321... * Connected to localhost (::1) port 8321 > GET /v1/models HTTP/1.1 > Host: localhost:8321 > User-Agent: curl/8.7.1 > Accept: */* > * Request completely sent off < HTTP/1.1 401 Unauthorized < date: Fri, 27 Jun 2025 21:51:25 GMT < server: uvicorn < content-type: application/json < x-trace-id: 5390c6c0654086c55d87c86d7cbf2f6a < Transfer-Encoding: chunked < * Connection #0 to host localhost left intact {"error": {"message": "Authentication required. Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"}} ~/projects/llama-stack/llama_stack/ui ❯ ./scripts/unit-tests.sh ~/projects/llama-stack/llama_stack/ui ❯ curl "http://localhost:8321/v1/models" \ -H "Authorization: Bearer <token_obtained_from_github>" \ {"data":[{"identifier":"accounts/fireworks/models/llama-guard-3-11b-vision","provider_resource_id":"accounts/fireworks/models/llama-guard-3-11b-vision","provider_id":"fireworks","type":"model","metadata":{},"model_type":"llm"},{"identifier":"accounts/fireworks/models/llama-guard-3-8b","provider_resource_id":"accounts/fireworks/models/llama-guard-3-8b","provider_id":"fireworks","type":"model","metadata":{},"model_type":"llm"},{"identifier":"accounts/fireworks/models/llama-v3p1-405b-instruct","provider_resource_id":"accounts/f ``` --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
83c89265e0
commit
c8bac888af
8 changed files with 513 additions and 173 deletions
|
@ -11,10 +11,16 @@ import pytest
|
|||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationConfig,
|
||||
AuthProviderType,
|
||||
CustomAuthConfig,
|
||||
OAuth2IntrospectionConfig,
|
||||
OAuth2JWKSConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
)
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
from llama_stack.distribution.server.auth_providers import (
|
||||
AuthProviderType,
|
||||
get_attributes_from_claims,
|
||||
)
|
||||
|
||||
|
@ -61,24 +67,11 @@ def invalid_token():
|
|||
def http_app(mock_auth_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def k8s_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
provider_config=CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
endpoint=mock_auth_endpoint,
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -94,11 +87,6 @@ def http_client(http_app):
|
|||
return TestClient(http_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def k8s_client(k8s_app):
|
||||
return TestClient(k8s_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_scope():
|
||||
return {
|
||||
|
@ -117,18 +105,11 @@ def mock_scope():
|
|||
def mock_http_middleware(mock_auth_endpoint):
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_k8s_middleware():
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
provider_config=CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
endpoint=mock_auth_endpoint,
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
||||
|
@ -161,13 +142,14 @@ async def mock_post_exception(*args, **kwargs):
|
|||
def test_missing_auth_header(http_client):
|
||||
response = http_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "validated by mock-auth-service" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format(http_client):
|
||||
response = http_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success)
|
||||
|
@ -262,14 +244,14 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
|||
def oauth2_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
"key_recheck_period": "3600",
|
||||
},
|
||||
"audience": "llama-stack",
|
||||
},
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
jwks=OAuth2JWKSConfig(
|
||||
uri="http://mock-authz-service/token/introspect",
|
||||
),
|
||||
audience="llama-stack",
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -288,13 +270,14 @@ def oauth2_client(oauth2_app):
|
|||
def test_missing_auth_header_oauth2(oauth2_client):
|
||||
response = oauth2_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_oauth2(oauth2_client):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_jwks_response(*args, **kwargs):
|
||||
|
@ -358,15 +341,16 @@ async def mock_auth_jwks_response(*args, **kwargs):
|
|||
def oauth2_app_with_jwks_token():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
"key_recheck_period": "3600",
|
||||
"token": "my-jwks-token",
|
||||
},
|
||||
"audience": "llama-stack",
|
||||
},
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
jwks=OAuth2JWKSConfig(
|
||||
uri="http://mock-authz-service/token/introspect",
|
||||
key_recheck_period=3600,
|
||||
token="my-jwks-token",
|
||||
),
|
||||
audience="llama-stack",
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -449,11 +433,15 @@ def mock_introspection_endpoint():
|
|||
def introspection_app(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
|
||||
},
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
introspection=OAuth2IntrospectionConfig(
|
||||
url=mock_introspection_endpoint,
|
||||
client_id="myclient",
|
||||
client_secret="abcdefg",
|
||||
),
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -468,22 +456,22 @@ def introspection_app(mock_introspection_endpoint):
|
|||
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {
|
||||
"url": mock_introspection_endpoint,
|
||||
"client_id": "myclient",
|
||||
"client_secret": "abcdefg",
|
||||
"send_secret_in_body": "true",
|
||||
},
|
||||
"claims_mapping": {
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
introspection=OAuth2IntrospectionConfig(
|
||||
url=mock_introspection_endpoint,
|
||||
client_id="myclient",
|
||||
client_secret="abcdefg",
|
||||
send_secret_in_body=True,
|
||||
),
|
||||
claims_mapping={
|
||||
"sub": "roles",
|
||||
"scope": "roles",
|
||||
"groups": "teams",
|
||||
"aud": "namespaces",
|
||||
},
|
||||
},
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -507,13 +495,14 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi
|
|||
def test_missing_auth_header_introspection(introspection_client):
|
||||
response = introspection_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_introspection(introspection_client):
|
||||
response = introspection_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_introspection_active(*args, **kwargs):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue