fix(tests): ensure test isolation in server mode

Propagate test IDs from client to server via HTTP headers to maintain
proper test isolation when running with server-based stack configs. Without
this, recorded/replayed inference requests in server mode would leak across
tests.

Changes:
- Patch client _prepare_request to inject test ID into provider data header
- Sync test context from provider data on server side before storage operations
- Set LLAMA_STACK_TEST_STACK_CONFIG_TYPE env var based on stack config
- Configure console width for cleaner log output in CI
- Add SQLITE_STORE_DIR temp directory for test data isolation
Ashwin Bharambe, 2025-10-08 11:00:49 -07:00
commit d5296a35f6 (parent bba9957edd)
4 changed files with 219 additions and 106 deletions
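The transport is simple enough to sketch end to end: the client folds the test ID into the X-LlamaStack-Provider-Data header, and the server reads it back out before touching recordings. A minimal sketch (the URL and test ID are made-up; the header name and the __test_id key come from the diff below):

import json

import httpx

HEADER = "X-LlamaStack-Provider-Data"

def inject_test_id(request: httpx.Request, test_id: str) -> None:
    # Merge the test ID into any existing provider-data header rather than clobbering it
    existing = request.headers.get(HEADER)
    provider_data = json.loads(existing) if existing else {}
    provider_data["__test_id"] = test_id
    request.headers[HEADER] = json.dumps(provider_data)

# Client side: attach the current test's ID before the request goes out
request = httpx.Request("POST", "http://localhost:8321/v1/chat/completions")
inject_test_id(request, "tests/integration/inference/test_foo.py::test_bar")

# Server side: recover it from the parsed provider data before storage operations
provider_data = json.loads(request.headers[HEADER])
assert provider_data["__test_id"] == "tests/integration/inference/test_foo.py::test_bar"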


@@ -128,7 +128,10 @@ def strip_rich_markup(text):
 class CustomRichHandler(RichHandler):
     def __init__(self, *args, **kwargs):
-        kwargs["console"] = Console()
+        # Set a reasonable default width for console output, especially when redirected to files
+        console_width = int(os.environ.get("LLAMA_STACK_LOG_WIDTH", "120"))
+        # Don't force terminal codes to avoid ANSI escape codes in log files
+        kwargs["console"] = Console(width=console_width)
         super().__init__(*args, **kwargs)

     def emit(self, record):

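Why pinning the width matters: rich sizes its Console from the attached terminal, and when output is redirected to a file it falls back to a narrow default (typically 80 columns), wrapping long log lines. A small sketch of the difference, using only rich:

import io
import os

from rich.console import Console

# When output is redirected, the Console is not attached to a terminal and
# falls back to a default width, so long log lines get wrapped.
redirected = Console(file=io.StringIO())
print(redirected.width)  # typically 80 when no terminal is attached

# Pinning the width, as the handler above does, keeps wrapping stable
# no matter where the logs end up.
width = int(os.environ.get("LLAMA_STACK_LOG_WIDTH", "120"))
pinned = Console(file=io.StringIO(), width=width)
print(pinned.width)  # 120 unless LLAMA_STACK_LOG_WIDTH overrides it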

@@ -15,7 +15,7 @@ from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast

-from openai import NOT_GIVEN
+from openai import NOT_GIVEN, OpenAI

 from llama_stack.log import get_logger
@@ -79,6 +79,96 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
     return hashlib.sha256(normalized_json.encode()).hexdigest()


+def _sync_test_context_from_provider_data():
+    """In server mode, sync test ID from provider_data to _test_context.
+
+    This ensures that storage operations (which read from _test_context) work correctly
+    in server mode, where the test ID arrives via the provider_data HTTP header.
+
+    Returns a token to reset _test_context, or None if no sync was needed.
+    """
+    stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
+    if stack_config_type != "server":
+        return None
+
+    try:
+        from llama_stack.core.request_headers import PROVIDER_DATA_VAR
+
+        provider_data = PROVIDER_DATA_VAR.get()
+        if provider_data and "__test_id" in provider_data:
+            test_id = provider_data["__test_id"]
+            return _test_context.set(test_id)
+    except ImportError:
+        pass
+
+    return None
+
+
+def patch_httpx_for_test_id():
+    """Patch client _prepare_request methods to inject the test ID into the provider-data header.
+
+    This is needed for server mode, where the test ID must be transported from
+    client to server via HTTP headers. In library_client mode this patch is a no-op,
+    since everything runs in the same process.
+
+    We use the _prepare_request hook that Stainless clients provide for mutating
+    requests after construction but before sending.
+    """
+    from llama_stack_client import LlamaStackClient
+
+    if "llama_stack_client_prepare_request" in _original_methods:
+        return
+
+    _original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request
+    _original_methods["openai_prepare_request"] = OpenAI._prepare_request
+
+    def patched_prepare_request(self, request):
+        # Call original first (it's a sync method that returns None);
+        # determine which original to call based on client type
+        if "llama_stack_client" in self.__class__.__module__:
+            _original_methods["llama_stack_client_prepare_request"](self, request)
+        else:
+            _original_methods["openai_prepare_request"](self, request)

+        # Only inject the test ID in server mode
+        stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
+        test_id = _test_context.get()
+        if stack_config_type == "server" and test_id:
+            provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
+            if provider_data_header:
+                provider_data = json.loads(provider_data_header)
+            else:
+                provider_data = {}
+            provider_data["__test_id"] = test_id
+            request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
+
+        return None
+
+    LlamaStackClient._prepare_request = patched_prepare_request
+    OpenAI._prepare_request = patched_prepare_request
+
+
+# currently, unpatch is never called
+def unpatch_httpx_for_test_id():
+    """Remove client _prepare_request patches for test ID injection."""
+    if "llama_stack_client_prepare_request" not in _original_methods:
+        return
+
+    from llama_stack_client import LlamaStackClient
+
+    LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"]
+    del _original_methods["llama_stack_client_prepare_request"]
+
+    # Also restore the OpenAI client if it was patched
+    if "openai_prepare_request" in _original_methods:
+        OpenAI._prepare_request = _original_methods["openai_prepare_request"]
+        del _original_methods["openai_prepare_request"]
+
+
 def get_inference_mode() -> InferenceMode:
     return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
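The token mentioned in the docstring is a standard contextvars reset token: ContextVar.set() returns a Token, and ContextVar.reset(token) restores whatever value was current before the set. A self-contained sketch of the set/reset pattern the recorder relies on:

from contextvars import ContextVar

_test_context: ContextVar[str | None] = ContextVar("_test_context", default=None)

def scoped_set(value: str) -> None:
    token = _test_context.set(value)
    try:
        assert _test_context.get() == value
    finally:
        # Restores the previous value (here, the default None)
        _test_context.reset(token)

scoped_set("tests/integration/test_foo.py::test_bar")
assert _test_context.get() is None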
@@ -244,7 +334,7 @@ class ResponseStorage:
         with open(response_path, "w") as f:
             json.dump(
                 {
-                    "test_id": _test_context.get(),  # Include for debugging
+                    "test_id": _test_context.get(),
                     "request": request,
                     "response": serialized_response,
                 },
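For reference, the record written here is a small JSON document. Only the three top-level keys are visible in the hunk above, so the values in this sketch are placeholders:

import json

# Sketch of the on-disk record; the key names come from the diff,
# the values are illustrative only.
record = {
    "test_id": "tests/integration/inference/test_foo.py::test_bar",
    "request": {"method": "POST", "url": "http://localhost:8321/v1/chat/completions"},
    "response": {"body": "..."},
}
print(json.dumps(record, indent=2))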
@@ -386,6 +476,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     else:
         return await original_method(self, *args, **kwargs)

+    # In server mode, sync test ID from provider_data to _test_context for storage operations
+    test_context_token = _sync_test_context_from_provider_data()
+
+    try:
         # Get base URL based on client type
         if client_type == "openai":
             base_url = str(self._client.base_url)
@@ -488,6 +582,9 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         else:
             raise AssertionError(f"Invalid mode: {mode}")
+    finally:
+        if test_context_token:
+            _test_context.reset(test_context_token)


 def patch_inference_clients():

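Taken together, the pieces in this file are wired up roughly like so; a sketch assuming the env var and contextvar are populated the way the conftest change below populates them:

import os

from llama_stack.testing.inference_recorder import patch_httpx_for_test_id

# Server mode is what makes the patch inject headers; in library_client mode
# the patched _prepare_request is effectively a no-op.
os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "server"
patch_httpx_for_test_id()  # idempotent: a second call returns early

# Any LlamaStackClient / OpenAI request made while _test_context holds a test
# ID now carries it in the X-LlamaStack-Provider-Data header.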

@@ -124,12 +124,6 @@ echo ""
 echo "Checking llama packages"
 uv pip list | grep llama

-# Check storage and memory before tests
-echo "=== System Resources Before Tests ==="
-free -h 2>/dev/null || echo "free command not available"
-df -h
-echo ""
-
 # Set environment variables
 export LLAMA_STACK_CLIENT_TIMEOUT=300
@@ -144,6 +138,17 @@ echo "=== Applying Setup Environment Variables ==="

 # the server needs this
 export LLAMA_STACK_TEST_INFERENCE_MODE="$INFERENCE_MODE"
+export SQLITE_STORE_DIR=$(mktemp -d)
+echo "Setting SQLITE_STORE_DIR: $SQLITE_STORE_DIR"
+
+# Determine stack config type for api_recorder test isolation
+if [[ "$STACK_CONFIG" == server:* ]]; then
+    export LLAMA_STACK_TEST_STACK_CONFIG_TYPE="server"
+    echo "Setting stack config type: server"
+else
+    export LLAMA_STACK_TEST_STACK_CONFIG_TYPE="library_client"
+    echo "Setting stack config type: library_client"
+fi

 SETUP_ENV=$(PYTHONPATH=$THIS_DIR/.. python "$THIS_DIR/get_setup_env.py" --suite "$TEST_SUITE" --setup "$TEST_SETUP" --format bash)
 echo "Setting up environment variables:"
@@ -186,6 +191,8 @@ if [[ "$STACK_CONFIG" == *"server:"* ]]; then
         echo "Llama Stack Server is already running, skipping start"
     else
         echo "=== Starting Llama Stack Server ==="
+        # Set a reasonable log width for better readability in server.log
+        export LLAMA_STACK_LOG_WIDTH=120
         nohup llama stack run ci-tests --image-type venv > server.log 2>&1 &

         echo "Waiting for Llama Stack Server to start..."
@@ -277,11 +284,5 @@ else
     exit 1
 fi

-# Check storage and memory after tests
-echo ""
-echo "=== System Resources After Tests ==="
-free -h 2>/dev/null || echo "free command not available"
-df -h
-echo ""

 echo "=== Integration Tests Complete ==="


@@ -35,6 +35,18 @@ def pytest_sessionstart(session):
     if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
         os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"

+    stack_config = session.config.getoption("--stack-config", default=None)
+    if stack_config and stack_config.startswith("server:"):
+        os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "server"
+        logger.info(f"Test stack config type: server (stack_config={stack_config})")
+    else:
+        os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "library_client"
+        logger.info(f"Test stack config type: library_client (stack_config={stack_config})")
+
+    from llama_stack.testing.inference_recorder import patch_httpx_for_test_id
+
+    patch_httpx_for_test_id()
+

 @pytest.fixture(autouse=True)
 def _track_test_context(request):
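The diff ends before the body of _track_test_context, but its job is evidently to point _test_context at the running test. A hypothetical sketch (the use of request.node.nodeid is an assumption, not the actual implementation):

import pytest

from llama_stack.testing.inference_recorder import _test_context

@pytest.fixture(autouse=True)
def _track_test_context(request):
    # Hypothetical: bind the contextvar to this test's ID for the test's duration
    token = _test_context.set(request.node.nodeid)
    yield
    _test_context.reset(token)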