fix(server): add middleware for provider data and test context (#4367)

# What does this PR do?
Consolidates provider data context handling into middleware, eliminating
duplication between FastAPI router routes and legacy @webmethod routes.

Closes #4366 

## Test Plan

Added the unit test suite `test_test_context_middleware`; in particular,
`test_middleware_extracts_test_id_from_header` validates the expected
behavior.
```
❯ ./scripts/unit-tests.sh tests/unit/
```

Also checked the middleware test context against the `files` FastAPI
router migration from
[pull/4339](https://github.com/llamastack/llama-stack/pull/4339) by
rebasing that branch onto this one and running its integration tests:
```
❯ git switch migrate-files-api
Switched to branch 'migrate-files-api'
❯ git rebase fix-test-ctx-middleware
Successfully rebased and updated refs/heads/migrate-files-api.
❯ ./scripts/integration-tests.sh --inference-mode replay --suite base --setup ollama --stack-config server:starter --subdirs files
```

Signed-off-by: Matthew F Leader <mleader@redhat.com>
This commit is contained in:
Matt Leader 2025-12-16 15:00:48 -05:00 committed by GitHub
parent 5abb7df41a
commit 722d9c53e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 168 additions and 44 deletions

View file

@ -235,35 +235,18 @@ async def log_request_pre_validation(request: Request):
def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@functools.wraps(func) @functools.wraps(func)
async def route_handler(request: Request, **kwargs): async def route_handler(request: Request, **kwargs):
# Get auth attributes from the request scope
user = user_from_scope(request.scope)
await log_request_pre_validation(request) await log_request_pre_validation(request)
test_context_token = None
test_context_var = None
reset_test_context_fn = None
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user):
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
from llama_stack.core.testing_context import (
TEST_CONTEXT,
reset_test_context,
sync_test_context_from_provider_data,
)
test_context_token = sync_test_context_from_provider_data()
test_context_var = TEST_CONTEXT
reset_test_context_fn = reset_test_context
is_streaming = is_streaming_request(func.__name__, request, **kwargs) is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try: try:
if is_streaming: if is_streaming:
# Preserve context vars across async generator boundaries
context_vars = [PROVIDER_DATA_VAR] context_vars = [PROVIDER_DATA_VAR]
if test_context_var is not None: if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
context_vars.append(test_context_var) from llama_stack.core.testing_context import TEST_CONTEXT
context_vars.append(TEST_CONTEXT)
gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars) gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars)
return StreamingResponse(gen, media_type="text/event-stream") return StreamingResponse(gen, media_type="text/event-stream")
else: else:
@ -282,9 +265,6 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
else: else:
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}") logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
raise translate_exception(e) from e raise translate_exception(e) from e
finally:
if test_context_token is not None and reset_test_context_fn is not None:
reset_test_context_fn(test_context_token)
sig = inspect.signature(func) sig = inspect.signature(func)
@ -356,6 +336,42 @@ class ClientVersionMiddleware:
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
class ProviderDataMiddleware:
    """ASGI middleware that sets up request context for all routes.

    For every HTTP request it enters ``request_provider_data_context`` with the
    request headers and the user taken from the ASGI scope. When the
    ``LLAMA_STACK_TEST_INFERENCE_MODE`` environment variable is set, it also
    syncs the test context from the provider data and resets it once the
    wrapped app has finished. Non-HTTP scopes are passed straight through.
    """

    def __init__(self, app):
        """Store the wrapped ASGI application."""
        self.app = app

    @staticmethod
    def _decode_header(raw: bytes) -> str:
        """Decode one raw ASGI header name or value without ever raising.

        UTF-8 is tried first so headers that decoded before decode identically
        now; bytes that are not valid UTF-8 fall back to latin-1, which maps
        every byte value and therefore cannot fail.
        """
        try:
            return raw.decode()
        except UnicodeDecodeError:
            return raw.decode("latin-1")

    async def __call__(self, scope, receive, send):
        """Run the wrapped app, inside the request context for HTTP scopes.

        Returns whatever the wrapped app returns.
        """
        # Only HTTP requests carry provider data; everything else is untouched.
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        decode = self._decode_header
        headers = {decode(k): decode(v) for k, v in scope.get("headers", [])}
        user = user_from_scope(scope)

        with request_provider_data_context(headers, user):
            test_context_token = None
            reset_fn = None
            if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
                # Imported lazily so the testing module is only loaded in test mode.
                from llama_stack.core.testing_context import (
                    reset_test_context,
                    sync_test_context_from_provider_data,
                )

                test_context_token = sync_test_context_from_provider_data()
                reset_fn = reset_test_context

            try:
                return await self.app(scope, receive, send)
            finally:
                # Explicit None checks: reset whenever a token was handed back.
                if test_context_token is not None and reset_fn is not None:
                    reset_fn(test_context_token)
def create_app() -> StackApp: def create_app() -> StackApp:
"""Create and configure the FastAPI application. """Create and configure the FastAPI application.
@ -395,6 +411,8 @@ def create_app() -> StackApp:
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware) app.add_middleware(ClientVersionMiddleware)
app.add_middleware(ProviderDataMiddleware)
impls = app.stack.impls impls = app.stack.impls
if config.server.auth: if config.server.auth:

View file

@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import pytest
from fastapi import APIRouter, FastAPI
from starlette.testclient import TestClient
from llama_stack.core.server.server import ProviderDataMiddleware
from llama_stack.core.testing_context import get_test_context
@pytest.fixture
def app_with_middleware():
    """Build a bare FastAPI app wrapped in ProviderDataMiddleware.

    The app exposes a single GET /test-context route that returns whatever
    get_test_context() reports while the request is being handled.
    """
    context_router = APIRouter()

    @context_router.get("/test-context")
    def get_current_test_context():
        current = get_test_context()
        return {"test_id": current}

    application = FastAPI()
    application.include_router(context_router)
    application.add_middleware(ProviderDataMiddleware)
    return application
@pytest.fixture
def test_mode_env(monkeypatch):
    """Set the environment variables the tests use to enable test mode."""
    test_env = {
        "LLAMA_STACK_TEST_INFERENCE_MODE": "replay",
        "LLAMA_STACK_TEST_STACK_CONFIG_TYPE": "server",
    }
    # monkeypatch undoes these automatically when the test finishes.
    for name, value in test_env.items():
        monkeypatch.setenv(name, value)
def test_middleware_returns_none_without_header(app_with_middleware, test_mode_env):
    """A request without the provider data header reports a None test_id."""
    response = TestClient(app_with_middleware).get("/test-context")

    assert response.status_code == 200
    payload = response.json()
    assert payload["test_id"] is None
def test_middleware_extracts_test_id_from_header(app_with_middleware, test_mode_env):
    """A __test_id supplied in the provider data header is reported back."""
    request_headers = {
        "X-LlamaStack-Provider-Data": json.dumps({"__test_id": "test-abc-123"}),
    }
    response = TestClient(app_with_middleware).get("/test-context", headers=request_headers)

    assert response.status_code == 200
    payload = response.json()
    assert payload["test_id"] == "test-abc-123"
def test_middleware_handles_empty_provider_data(app_with_middleware, test_mode_env):
    """An empty JSON object as provider data leaves the test_id as None."""
    request_headers = {"X-LlamaStack-Provider-Data": "{}"}
    response = TestClient(app_with_middleware).get("/test-context", headers=request_headers)

    assert response.status_code == 200
    payload = response.json()
    assert payload["test_id"] is None
def test_middleware_handles_invalid_json(app_with_middleware, test_mode_env):
    """A header that is not valid JSON still yields a 200 and a None test_id."""
    request_headers = {"X-LlamaStack-Provider-Data": "not-valid-json"}
    response = TestClient(app_with_middleware).get("/test-context", headers=request_headers)

    assert response.status_code == 200
    payload = response.json()
    assert payload["test_id"] is None
def test_middleware_noop_without_test_mode(app_with_middleware, monkeypatch):
    """Without the test mode env vars, a supplied __test_id is not extracted."""
    # Remove the variables through monkeypatch instead of os.environ.pop so any
    # pre-existing values are restored at teardown and cannot leak into (or be
    # lost for) other tests in the same process.
    monkeypatch.delenv("LLAMA_STACK_TEST_INFERENCE_MODE", raising=False)
    monkeypatch.delenv("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", raising=False)

    client = TestClient(app_with_middleware)
    provider_data = json.dumps({"__test_id": "test-abc-123"})
    response = client.get(
        "/test-context",
        headers={"X-LlamaStack-Provider-Data": provider_data},
    )

    assert response.status_code == 200
    assert response.json()["test_id"] is None