mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-29 03:14:19 +00:00
gh auth, commit
# What does this PR do? ## Test Plan # What does this PR do? ## Test Plan
This commit is contained in:
parent
114946ae88
commit
94282d3482
9 changed files with 962 additions and 43 deletions
|
@ -11,10 +11,13 @@ import pytest
|
|||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthProviderType,
|
||||
CustomAuthConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
)
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
from llama_stack.distribution.server.auth_providers import (
|
||||
AuthProviderType,
|
||||
get_attributes_from_claims,
|
||||
)
|
||||
|
||||
|
@ -60,8 +63,8 @@ def invalid_token():
|
|||
@pytest.fixture
|
||||
def http_app(mock_auth_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
auth_config = CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
@ -76,9 +79,9 @@ def http_app(mock_auth_endpoint):
|
|||
@pytest.fixture
|
||||
def k8s_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
auth_config = CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
config={"provider_type": "kubernetes", "api_server_url": "https://kubernetes.default.svc"},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -116,8 +119,8 @@ def mock_scope():
|
|||
@pytest.fixture
|
||||
def mock_http_middleware(mock_auth_endpoint):
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
auth_config = CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
@ -126,9 +129,9 @@ def mock_http_middleware(mock_auth_endpoint):
|
|||
@pytest.fixture
|
||||
def mock_k8s_middleware():
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
auth_config = CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
config={"provider_type": "kubernetes", "api_server_url": "https://kubernetes.default.svc"},
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
||||
|
@ -161,7 +164,8 @@ async def mock_post_exception(*args, **kwargs):
|
|||
def test_missing_auth_header(http_client):
|
||||
response = http_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "validated by mock-auth-service" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format(http_client):
|
||||
|
@ -261,8 +265,8 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
|||
@pytest.fixture
|
||||
def oauth2_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
auth_config = OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
|
@ -288,7 +292,8 @@ def oauth2_client(oauth2_app):
|
|||
def test_missing_auth_header_oauth2(oauth2_client):
|
||||
response = oauth2_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_oauth2(oauth2_client):
|
||||
|
@ -357,8 +362,8 @@ async def mock_auth_jwks_response(*args, **kwargs):
|
|||
@pytest.fixture
|
||||
def oauth2_app_with_jwks_token():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
auth_config = OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
|
@ -448,8 +453,8 @@ def mock_introspection_endpoint():
|
|||
@pytest.fixture
|
||||
def introspection_app(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
auth_config = OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
|
||||
|
@ -467,8 +472,8 @@ def introspection_app(mock_introspection_endpoint):
|
|||
@pytest.fixture
|
||||
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
auth_config = OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {
|
||||
|
@ -507,7 +512,8 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi
|
|||
def test_missing_auth_header_introspection(introspection_client):
|
||||
response = introspection_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_introspection(introspection_client):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue