class AuthenticationConfig(BaseModel):
    # Settings for the optional server authentication hook.
    # `endpoint` is required (no default): the URL that authentication
    # tokens are validated against.
    endpoint: str = Field(..., description="Endpoint URL to validate authentication tokens")
class AuthenticationMiddleware:
    """ASGI middleware that validates bearer tokens against an external endpoint.

    For every ``http`` scope the ``Authorization`` header must carry a
    ``Bearer`` token. The token, together with the request path, headers and
    parsed query parameters, is POSTed as JSON to ``auth_endpoint``. The
    request is forwarded to the wrapped app only when that endpoint answers
    with status 200; every other outcome yields a 401 JSON error response.

    Scopes of any other type are handed to the wrapped app without any
    authentication check.
    """

    def __init__(self, app, auth_endpoint):
        """Store the wrapped ASGI ``app`` and the validation ``auth_endpoint`` URL."""
        self.app = app
        self.auth_endpoint = auth_endpoint

    async def __call__(self, scope, receive, send):
        """Authenticate an HTTP request, then delegate to the wrapped app.

        Args:
            scope: ASGI connection scope.
            receive: ASGI receive callable, passed through untouched.
            send: ASGI send callable, also used to emit 401 responses.
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        # Header names/values arrive as raw client-controlled bytes. Decode
        # leniently so a stray non-UTF-8 byte cannot crash the middleware
        # with UnicodeDecodeError; valid UTF-8 input decodes exactly as before.
        request_headers = {
            name.decode("utf-8", errors="replace"): value.decode("utf-8", errors="replace")
            for name, value in scope.get("headers", [])
        }

        # The auth scheme is case-insensitive in HTTP; the token is everything
        # after the first space and must not be empty.
        scheme, _, api_key = request_headers.get("authorization", "").partition(" ")
        if scheme.lower() != "bearer" or not api_key:
            return await self._send_auth_error(send, "Missing or invalid Authorization header")

        query_string = scope.get("query_string", b"").decode("utf-8", errors="replace")
        auth_data = {
            "api_key": api_key,
            "request": {
                "path": scope.get("path", ""),
                "headers": request_headers,
                "params": parse_qs(query_string),
            },
        }

        # Only the outbound call is guarded: a failure while *sending* a 401
        # must not be mistaken for an auth-service error and answered twice.
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(self.auth_endpoint, json=auth_data)
        except Exception:
            logger.exception("Error during authentication")
            return await self._send_auth_error(send, "Authentication service error")

        if response.status_code != 200:
            logger.warning(f"Authentication failed: {response.status_code}")
            return await self._send_auth_error(send, "Authentication failed")

        return await self.app(scope, receive, send)

    async def _send_auth_error(self, send, message):
        """Emit a complete 401 response whose JSON body carries ``message``."""
        await send(
            {
                "type": "http.response.start",
                "status": 401,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        error_msg = json.dumps({"error": {"message": message}}).encode()
        await send({"type": "http.response.body", "body": error_msg})
a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -52,6 +52,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( start_trace, ) +from .auth import AuthenticationMiddleware from .endpoints import get_all_api_endpoints REPO_ROOT = Path(__file__).parent.parent.parent.parent @@ -351,6 +352,11 @@ def main(): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): app.add_middleware(ClientVersionMiddleware) + # Add authentication middleware if configured + if config.server.auth and config.server.auth.endpoint: + logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}") + app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint) + try: impls = asyncio.run(construct_stack(config)) except InvalidProviderError as e: diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py new file mode 100644 index 000000000..70f08dbd6 --- /dev/null +++ b/tests/unit/server/test_auth.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
@pytest.fixture
def mock_auth_endpoint():
    """Auth endpoint URL handed to the middleware under test."""
    return "http://mock-auth-service/validate"


@pytest.fixture
def valid_api_key():
    """Token used in requests that the mocked auth service accepts."""
    return "valid_api_key_12345"


@pytest.fixture
def invalid_api_key():
    """Token used in requests that the mocked auth service rejects."""
    return "invalid_api_key_67890"


@pytest.fixture
def app(mock_auth_endpoint):
    """FastAPI app with one ``/test`` route wrapped by the auth middleware."""
    app = FastAPI()
    app.add_middleware(AuthenticationMiddleware, auth_endpoint=mock_auth_endpoint)

    @app.get("/test")
    def test_endpoint():
        return {"message": "Authentication successful"}

    return app


@pytest.fixture
def client(app):
    """Synchronous test client bound to the protected app."""
    return TestClient(app)


async def mock_post_success(*args, **kwargs):
    """Stand-in for ``httpx.AsyncClient.post`` that reports status 200."""
    mock_response = AsyncMock()
    mock_response.status_code = 200
    return mock_response


async def mock_post_failure(*args, **kwargs):
    """Stand-in for ``httpx.AsyncClient.post`` that reports status 401."""
    mock_response = AsyncMock()
    mock_response.status_code = 401
    return mock_response


async def mock_post_exception(*args, **kwargs):
    """Stand-in for ``httpx.AsyncClient.post`` that always raises."""
    raise Exception("Connection error")


def test_missing_auth_header(client):
    """A request without an Authorization header is answered with 401."""
    response = client.get("/test")
    assert response.status_code == 401
    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]


def test_invalid_auth_header_format(client):
    """A non-Bearer Authorization header is answered with 401."""
    response = client.get("/test", headers={"Authorization": "InvalidFormat token123"})
    assert response.status_code == 401
    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]


@patch("httpx.AsyncClient.post", new=mock_post_success)
def test_valid_authentication(client, valid_api_key):
    """When the auth service returns 200 the request reaches the route."""
    response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
    assert response.status_code == 200
    assert response.json() == {"message": "Authentication successful"}


@patch("httpx.AsyncClient.post", new=mock_post_failure)
def test_invalid_authentication(client, invalid_api_key):
    """When the auth service returns non-200 the request is rejected with 401."""
    response = client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
    assert response.status_code == 401
    assert "Authentication failed" in response.json()["error"]["message"]


@patch("httpx.AsyncClient.post", new=mock_post_exception)
def test_auth_service_error(client, valid_api_key):
    """When calling the auth service raises, the request is rejected with 401."""
    response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
    assert response.status_code == 401
    assert "Authentication service error" in response.json()["error"]["message"]


def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
    """The middleware posts the token, path, headers and query params to the auth endpoint."""
    with patch("httpx.AsyncClient.post") as mock_post:
        mock_response = AsyncMock()
        mock_response.status_code = 200
        mock_post.return_value = mock_response

        # Two query parameters separated by a literal '&', so that both
        # `param1` and `param2` appear in the parsed params below.
        client.get(
            "/test?param1=value1&param2=value2",
            headers={
                "Authorization": f"Bearer {valid_api_key}",
                "User-Agent": "TestClient",
                "Content-Type": "application/json",
            },
        )

        # Check that the auth endpoint was called with the correct payload
        call_args = mock_post.call_args
        assert call_args is not None

        url, kwargs = call_args.args[0], call_args.kwargs
        assert url == mock_auth_endpoint

        payload = kwargs["json"]
        assert payload["api_key"] == valid_api_key
        assert payload["request"]["path"] == "/test"
        assert "authorization" in payload["request"]["headers"]
        assert "param1" in payload["request"]["params"]
        assert "param2" in payload["request"]["params"]