diff --git a/src/llama_stack/core/server/server.py b/src/llama_stack/core/server/server.py
index 7fd0a6182..c5459bd9e 100644
--- a/src/llama_stack/core/server/server.py
+++ b/src/llama_stack/core/server/server.py
@@ -235,56 +235,36 @@ async def log_request_pre_validation(request: Request):
 def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
     @functools.wraps(func)
     async def route_handler(request: Request, **kwargs):
-        # Get auth attributes from the request scope
-        user = user_from_scope(request.scope)
-
         await log_request_pre_validation(request)
 
-        test_context_token = None
-        test_context_var = None
-        reset_test_context_fn = None
+        is_streaming = is_streaming_request(func.__name__, request, **kwargs)
 
-        # Use context manager with both provider data and auth attributes
-        with request_provider_data_context(request.headers, user):
-            if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
-                from llama_stack.core.testing_context import (
-                    TEST_CONTEXT,
-                    reset_test_context,
-                    sync_test_context_from_provider_data,
-                )
+        try:
+            if is_streaming:
+                # Preserve context vars across async generator boundaries
+                context_vars = [PROVIDER_DATA_VAR]
+                if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
+                    from llama_stack.core.testing_context import TEST_CONTEXT
 
-                test_context_token = sync_test_context_from_provider_data()
-                test_context_var = TEST_CONTEXT
-                reset_test_context_fn = reset_test_context
+                    context_vars.append(TEST_CONTEXT)
+                gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars)
+                return StreamingResponse(gen, media_type="text/event-stream")
+            else:
+                value = func(**kwargs)
+                result = await maybe_await(value)
+                if isinstance(result, PaginatedResponse) and result.url is None:
+                    result.url = route
 
-            is_streaming = is_streaming_request(func.__name__, request, **kwargs)
+                if method.upper() == "DELETE" and result is None:
+                    return Response(status_code=httpx.codes.NO_CONTENT)
 
-            try:
-                if is_streaming:
-                    context_vars = [PROVIDER_DATA_VAR]
-                    if test_context_var is not None:
-                        context_vars.append(test_context_var)
-                    gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars)
-                    return StreamingResponse(gen, media_type="text/event-stream")
-                else:
-                    value = func(**kwargs)
-                    result = await maybe_await(value)
-                    if isinstance(result, PaginatedResponse) and result.url is None:
-                        result.url = route
-
-                    if method.upper() == "DELETE" and result is None:
-                        return Response(status_code=httpx.codes.NO_CONTENT)
-
-                    return result
-            except Exception as e:
-                if logger.isEnabledFor(logging.INFO):
-                    logger.exception(f"Error executing endpoint {route=} {method=}")
-                else:
-                    logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
-                raise translate_exception(e) from e
-            finally:
-                if test_context_token is not None and reset_test_context_fn is not None:
-                    reset_test_context_fn(test_context_token)
+                return result
+        except Exception as e:
+            if logger.isEnabledFor(logging.INFO):
+                logger.exception(f"Error executing endpoint {route=} {method=}")
+            else:
+                logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
+            raise translate_exception(e) from e
 
     sig = inspect.signature(func)
 
@@ -356,6 +336,48 @@ class ClientVersionMiddleware:
         return await self.app(scope, receive, send)
 
 
+class ProviderDataMiddleware:
+    """Middleware to set up request context for all routes.
+
+    Sets up provider data context from X-LlamaStack-Provider-Data header
+    and auth attributes. Also handles test context propagation when
+    running in test mode for deterministic ID generation.
+    """
+
+    def __init__(self, app):
+        self.app = app
+
+    async def __call__(self, scope, receive, send):
+        if scope["type"] == "http":
+            # ASGI delivers header names and values as raw bytes. Decode them as
+            # latin-1, which maps every byte value, so that a header carrying
+            # non-UTF-8 bytes cannot fail the request with UnicodeDecodeError.
+            headers = {k.decode("latin-1"): v.decode("latin-1") for k, v in scope.get("headers", [])}
+            user = user_from_scope(scope)
+
+            with request_provider_data_context(headers, user):
+                test_context_token = None
+                reset_fn = None
+                if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
+                    from llama_stack.core.testing_context import (
+                        reset_test_context,
+                        sync_test_context_from_provider_data,
+                    )
+
+                    test_context_token = sync_test_context_from_provider_data()
+                    reset_fn = reset_test_context
+                try:
+                    return await self.app(scope, receive, send)
+                finally:
+                    # Compare against None explicitly so the reset never depends
+                    # on the truthiness of the token object.
+                    if test_context_token is not None and reset_fn is not None:
+                        reset_fn(test_context_token)
+
+        # Non-HTTP scopes are passed through without any request context.
+        return await self.app(scope, receive, send)
+
+
 def create_app() -> StackApp:
     """Create and configure the FastAPI application.
 
@@ -395,6 +417,8 @@ def create_app() -> StackApp:
     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)
 
+    app.add_middleware(ProviderDataMiddleware)
+
     impls = app.stack.impls
 
     if config.server.auth:
diff --git a/tests/unit/server/test_test_context_middleware.py b/tests/unit/server/test_test_context_middleware.py
new file mode 100644
index 000000000..4ad61d1ca
--- /dev/null
+++ b/tests/unit/server/test_test_context_middleware.py
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+
+import pytest
+from fastapi import APIRouter, FastAPI
+from starlette.testclient import TestClient
+
+from llama_stack.core.server.server import ProviderDataMiddleware
+from llama_stack.core.testing_context import get_test_context
+
+
+@pytest.fixture
+def app_with_middleware():
+    """Create a minimal FastAPI app with ProviderDataMiddleware."""
+    app = FastAPI()
+
+    router = APIRouter()
+
+    @router.get("/test-context")
+    def get_current_test_context():
+        return {"test_id": get_test_context()}
+
+    app.include_router(router)
+    app.add_middleware(ProviderDataMiddleware)
+
+    return app
+
+
+@pytest.fixture
+def test_mode_env(monkeypatch):
+    """Set environment variables required for test context extraction."""
+    monkeypatch.setenv("LLAMA_STACK_TEST_INFERENCE_MODE", "replay")
+    monkeypatch.setenv("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "server")
+
+
+def test_middleware_returns_none_without_header(app_with_middleware, test_mode_env):
+    """Without the provider data header, test context should be None."""
+    client = TestClient(app_with_middleware)
+    response = client.get("/test-context")
+
+    assert response.status_code == 200
+    assert response.json()["test_id"] is None
+
+
+def test_middleware_extracts_test_id_from_header(app_with_middleware, test_mode_env):
+    """With the provider data header containing __test_id, it should be extracted."""
+    client = TestClient(app_with_middleware)
+
+    provider_data = json.dumps({"__test_id": "test-abc-123"})
+    response = client.get(
+        "/test-context",
+        headers={"X-LlamaStack-Provider-Data": provider_data},
+    )
+
+    assert response.status_code == 200
+    assert response.json()["test_id"] == "test-abc-123"
+
+
+def test_middleware_handles_empty_provider_data(app_with_middleware, test_mode_env):
+    """Empty provider data should result in None test context."""
+    client = TestClient(app_with_middleware)
+
+    response = client.get(
+        "/test-context",
+        headers={"X-LlamaStack-Provider-Data": "{}"},
+    )
+
+    assert response.status_code == 200
+    assert response.json()["test_id"] is None
+
+
+def test_middleware_handles_invalid_json(app_with_middleware, test_mode_env):
+    """Invalid JSON in header should not crash, test context should be None."""
+    client = TestClient(app_with_middleware)
+
+    response = client.get(
+        "/test-context",
+        headers={"X-LlamaStack-Provider-Data": "not-valid-json"},
+    )
+
+    assert response.status_code == 200
+    assert response.json()["test_id"] is None
+
+
+def test_middleware_noop_without_test_mode(app_with_middleware, monkeypatch):
+    """Without test mode env vars, middleware should not extract test context."""
+    # Clear the variables through monkeypatch so that values already present in
+    # the process environment are restored once this test finishes.
+    monkeypatch.delenv("LLAMA_STACK_TEST_INFERENCE_MODE", raising=False)
+    monkeypatch.delenv("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", raising=False)
+
+    client = TestClient(app_with_middleware)
+
+    provider_data = json.dumps({"__test_id": "test-abc-123"})
+    response = client.get(
+        "/test-context",
+        headers={"X-LlamaStack-Provider-Data": provider_data},
+    )
+
+    assert response.status_code == 200
+    assert response.json()["test_id"] is None