llama-stack-mirror/tests/unit/server/test_test_context_middleware.py
Matt Leader 722d9c53e7
fix(server): add middleware for provider data and test context (#4367)
# What does this PR do?
Consolidates provider data context handling into middleware, eliminating
duplication between FastAPI router routes and legacy @webmethod routes.

Closes #4366 

## Test Plan

Added unit test suite `test_test_context_middleware`, specifically
`test_middleware_extracts_test_id_from_header` to validate the expected
behavior.
```
❯ ./scripts/unit-tests.sh tests/unit/
```

Also checked that the middleware test context integrates with the `files`
FastAPI router migration from
[pull/4339](https://github.com/llamastack/llama-stack/pull/4339) by rebasing
that branch onto this one and running the `files` integration tests:
```
❯ git switch migrate-files-api
Switched to branch 'migrate-files-api'
❯ git rebase fix-test-ctx-middleware
Successfully rebased and updated refs/heads/migrate-files-api.
❯ ./scripts/integration-tests.sh --inference-mode replay --suite base --setup ollama --stack-config server:starter --subdirs files
```

Signed-off-by: Matthew F Leader <mleader@redhat.com>
2025-12-16 15:00:48 -05:00

106 lines
3.3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import pytest
from fastapi import APIRouter, FastAPI
from starlette.testclient import TestClient
from llama_stack.core.server.server import ProviderDataMiddleware
from llama_stack.core.testing_context import get_test_context
@pytest.fixture
def app_with_middleware():
    """Build a bare FastAPI application wrapped in ProviderDataMiddleware.

    The application exposes a single ``GET /test-context`` route. Its handler
    returns ``{"test_id": get_test_context()}``, so a test can observe, over
    HTTP, whatever test context is visible while the request is being handled.

    Returns:
        The configured ``FastAPI`` instance (not yet wrapped in a client).
    """
    application = FastAPI()
    context_router = APIRouter()

    @context_router.get("/test-context")
    def get_current_test_context():
        # Report the context as seen from inside the request.
        current = get_test_context()
        return {"test_id": current}

    application.include_router(context_router)
    # Middleware under test; installed with its default arguments.
    application.add_middleware(ProviderDataMiddleware)
    return application
@pytest.fixture
def test_mode_env(monkeypatch):
    """Export the environment variables the tests treat as "test mode".

    Sets ``LLAMA_STACK_TEST_INFERENCE_MODE=replay`` and
    ``LLAMA_STACK_TEST_STACK_CONFIG_TYPE=server`` through ``monkeypatch``, so
    both are reverted automatically when the requesting test finishes.
    """
    test_mode_settings = {
        "LLAMA_STACK_TEST_INFERENCE_MODE": "replay",
        "LLAMA_STACK_TEST_STACK_CONFIG_TYPE": "server",
    }
    for variable, setting in test_mode_settings.items():
        monkeypatch.setenv(variable, setting)
def test_middleware_returns_none_without_header(app_with_middleware, test_mode_env):
    """A request sent without the provider data header reports a ``None`` test id."""
    reply = TestClient(app_with_middleware).get("/test-context")

    assert reply.status_code == 200
    body = reply.json()
    assert body["test_id"] is None
def test_middleware_extracts_test_id_from_header(app_with_middleware, test_mode_env):
    """A ``__test_id`` entry in the provider data header is surfaced as the test id."""
    expected_test_id = "test-abc-123"
    # The header value is a JSON object serialized to a string.
    request_headers = {
        "X-LlamaStack-Provider-Data": json.dumps({"__test_id": expected_test_id}),
    }

    reply = TestClient(app_with_middleware).get("/test-context", headers=request_headers)

    assert reply.status_code == 200
    body = reply.json()
    assert body["test_id"] == expected_test_id
def test_middleware_handles_empty_provider_data(app_with_middleware, test_mode_env):
    """An empty JSON object in the provider data header yields a ``None`` test id."""
    request_headers = {"X-LlamaStack-Provider-Data": "{}"}

    reply = TestClient(app_with_middleware).get("/test-context", headers=request_headers)

    assert reply.status_code == 200
    body = reply.json()
    assert body["test_id"] is None
def test_middleware_handles_invalid_json(app_with_middleware, test_mode_env):
    """A header value that is not valid JSON still gets a 200 with a ``None`` test id."""
    request_headers = {"X-LlamaStack-Provider-Data": "not-valid-json"}

    reply = TestClient(app_with_middleware).get("/test-context", headers=request_headers)

    # The request must complete normally rather than fail on the bad header.
    assert reply.status_code == 200
    body = reply.json()
    assert body["test_id"] is None
def test_middleware_noop_without_test_mode(app_with_middleware, monkeypatch):
    """Without the test mode env vars, a ``__test_id`` header must be ignored.

    The two variables are removed through ``monkeypatch`` rather than by
    mutating ``os.environ`` directly, so any values the surrounding test run
    exported are restored at teardown and later tests are unaffected.
    """
    # Ensure env vars are not set; raising=False tolerates them being absent.
    for variable in (
        "LLAMA_STACK_TEST_INFERENCE_MODE",
        "LLAMA_STACK_TEST_STACK_CONFIG_TYPE",
    ):
        monkeypatch.delenv(variable, raising=False)

    client = TestClient(app_with_middleware)
    provider_data = json.dumps({"__test_id": "test-abc-123"})
    response = client.get(
        "/test-context",
        headers={"X-LlamaStack-Provider-Data": provider_data},
    )

    assert response.status_code == 200
    assert response.json()["test_id"] is None