feat(tests): implement test isolation for inference recordings (#3681)

Uses test_id in request hashes and test-scoped subdirectories to prevent
cross-test contamination. Model list endpoints exclude test_id to enable
merging recordings from different servers.
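
Roughly, the hashing behaves like the condensed sketch below (the model name, URL, and hash function are illustrative placeholders, not lifted verbatim from the recorder):

```python
import hashlib
import json
from urllib.parse import urlparse


def request_hash(method: str, url: str, body: dict, test_id: str | None) -> str:
    """Condensed sketch of normalize_request(): fold the current test id into the
    hash for inference endpoints, but not for model-list endpoints."""
    path = urlparse(url).path
    normalized: dict = {"method": method.upper(), "endpoint": path, "body": body}
    if path not in ("/api/tags", "/v1/models"):
        normalized["test_id"] = test_id
    # sort_keys=True keeps the JSON (and therefore the hash) deterministic
    return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()


body = {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "hello"}]}
url = "http://localhost:11434/v1/chat/completions"

# The same request issued from two different tests maps to two distinct recordings.
assert request_hash("POST", url, body, "inference/test_a.py::test_x") != request_hash(
    "POST", url, body, "inference/test_b.py::test_y"
)

# Model-list endpoints ignore test_id, so their recordings can be shared and merged.
assert request_hash("GET", "http://localhost:11434/v1/models", {}, "a") == request_hash(
    "GET", "http://localhost:11434/v1/models", {}, "b"
)
```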

Additionally, this PR adds a `record-if-missing` mode, which replays existing
recordings and records only the requests that have no recording yet. We will use
it instead of `record`, which unconditionally re-records everything.
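
A minimal sketch of driving it from Python, assuming the import path shown and the context-manager usage that `inference_recording()`'s generator signature suggests; the same behaviour can also be selected with `LLAMA_STACK_TEST_INFERENCE_MODE=record-if-missing`:

```python
from pathlib import Path

# Import path assumed for illustration; adjust to wherever the recorder lives.
from llama_stack.testing.inference_recorder import inference_recording

with inference_recording(mode="record-if-missing", storage_dir=Path("tests/integration/common")):
    # Requests that already have a recording are replayed; anything new hits the
    # live provider once and is stored, so a full re-record is unnecessary.
    ...
```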

🤖 Co-authored with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
Ashwin Bharambe 2025-10-04 11:34:18 -07:00 committed by GitHub
parent f176196fba
commit 045a0c1d57
428 changed files with 85345 additions and 104330 deletions


@@ -22,10 +22,18 @@ from llama_stack.log import get_logger
logger = get_logger(__name__, category="testing")
# Global state for the recording system
# Note: Using module globals instead of ContextVars because the session-scoped
# client initialization happens in one async context, but tests run in different
# contexts, and we need the mode/storage to persist across all contexts.
_current_mode: str | None = None
_current_storage: ResponseStorage | None = None
_original_methods: dict[str, Any] = {}
# Test context uses ContextVar since it changes per-test and needs async isolation
from contextvars import ContextVar
_test_context: ContextVar[str | None] = ContextVar("_test_context", default=None)
from openai.types.completion_choice import CompletionChoice
# update the "finish_reason" field, since its type definition is wrong (it does not accept None)
@@ -33,22 +41,38 @@ CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "len
CompletionChoice.model_rebuild()
REPO_ROOT = Path(__file__).parent.parent.parent
DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/recordings"
DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/common"
class InferenceMode(StrEnum):
LIVE = "live"
RECORD = "record"
REPLAY = "replay"
RECORD_IF_MISSING = "record-if-missing"
def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
"""Create a normalized hash of the request for consistent matching."""
"""Create a normalized hash of the request for consistent matching.
Includes test_id from context to ensure test isolation - identical requests
from different tests will have different hashes.
Exception: Model list endpoints (/v1/models, /api/tags) exclude test_id since
they are infrastructure/shared and need to work across session setup and tests.
"""
# Extract just the endpoint path
from urllib.parse import urlparse
parsed = urlparse(url)
normalized = {"method": method.upper(), "endpoint": parsed.path, "body": body}
normalized: dict[str, Any] = {
"method": method.upper(),
"endpoint": parsed.path,
"body": body,
}
# Include test_id for isolation, except for shared infrastructure endpoints
if parsed.path not in ("/api/tags", "/v1/models"):
normalized["test_id"] = _test_context.get()
# Create hash - sort_keys=True ensures deterministic ordering
normalized_json = json.dumps(normalized, sort_keys=True)
@@ -67,7 +91,11 @@ def setup_inference_recording():
Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
Two environment variables are supported:
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'. Default is 'replay'.
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'record-if-missing'. Default is 'replay'.
- 'live': Make all requests live without recording
- 'record': Record all requests (overwrites existing recordings)
- 'replay': Use only recorded responses (fails if recording not found)
- 'record-if-missing': Use recorded responses when available, record new ones when not found
- LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'.
The recordings are stored as JSON files.
@@ -154,21 +182,43 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
class ResponseStorage:
"""Handles SQLite index + JSON file storage/retrieval for inference recordings."""
def __init__(self, test_dir: Path):
self.test_dir = test_dir
self.responses_dir = self.test_dir / "responses"
def __init__(self, base_dir: Path):
self.base_dir = base_dir
# Don't create responses_dir here - determine it per-test at runtime
self._ensure_directories()
def _get_test_dir(self) -> Path:
"""Get the recordings directory in the test file's parent directory.
For test at "tests/integration/inference/test_foo.py::test_bar",
returns "tests/integration/inference/recordings/".
"""
test_id = _test_context.get()
if test_id:
# Extract the directory path from the test nodeid
# e.g., "tests/integration/inference/test_basic.py::test_foo[params]"
# -> get "tests/integration/inference"
test_file = test_id.split("::")[0] # Remove test function part
test_dir = Path(test_file).parent # Get parent directory
# Put recordings in a "recordings" subdirectory of the test's parent dir
# e.g., "tests/integration/inference" -> "tests/integration/inference/recordings"
return test_dir / "recordings"
else:
# Fallback for non-test contexts
return self.base_dir / "recordings"
def _ensure_directories(self):
self.test_dir.mkdir(parents=True, exist_ok=True)
self.responses_dir.mkdir(exist_ok=True)
"""Ensure test-specific directories exist."""
test_dir = self._get_test_dir()
test_dir.mkdir(parents=True, exist_ok=True)
return test_dir
def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
"""Store a request/response pair."""
# Generate unique response filename
short_hash = request_hash[:12]
response_file = f"{short_hash}.json"
responses_dir = self._ensure_directories()
# Use FULL hash (not truncated)
response_file = f"{request_hash}.json"
# Serialize response body if needed
serialized_response = dict(response)
@@ -182,35 +232,71 @@ class ResponseStorage:
# Handle single response
serialized_response["body"] = _serialize_response(serialized_response["body"], request_hash)
# If this is an Ollama /api/tags recording, include models digest in filename to distinguish variants
# For model-list endpoints, include digest in filename to distinguish different model sets
endpoint = request.get("endpoint")
if endpoint in ("/api/tags", "/v1/models"):
digest = _model_identifiers_digest(endpoint, response)
response_file = f"models-{short_hash}-{digest}.json"
response_file = f"models-{request_hash}-{digest}.json"
response_path = self.responses_dir / response_file
response_path = responses_dir / response_file
# Save response to JSON file
# Save response to JSON file with metadata
with open(response_path, "w") as f:
json.dump({"request": request, "response": serialized_response}, f, indent=2)
json.dump(
{
"test_id": _test_context.get(), # Include for debugging
"request": request,
"response": serialized_response,
},
f,
indent=2,
)
f.write("\n")
f.flush()
def find_recording(self, request_hash: str) -> dict[str, Any] | None:
"""Find a recorded response by request hash."""
response_file = f"{request_hash[:12]}.json"
response_path = self.responses_dir / response_file
"""Find a recorded response by request hash.
if not response_path.exists():
return None
Uses fallback: first checks test-specific dir, then falls back to base recordings dir.
This handles cases where recordings happen during session setup (no test context) but
are requested during tests (with test context).
"""
response_file = f"{request_hash}.json"
return _recording_from_file(response_path)
# Try test-specific directory first
test_dir = self._get_test_dir()
response_path = test_dir / response_file
def _model_list_responses(self, short_hash: str) -> list[dict[str, Any]]:
if response_path.exists():
return _recording_from_file(response_path)
# Fallback to base recordings directory (for session-level recordings)
fallback_dir = self.base_dir / "recordings"
fallback_path = fallback_dir / response_file
if fallback_path.exists():
return _recording_from_file(fallback_path)
return None
def _model_list_responses(self, request_hash: str) -> list[dict[str, Any]]:
"""Find all model-list recordings with the given hash (different digests)."""
results: list[dict[str, Any]] = []
for path in self.responses_dir.glob(f"models-{short_hash}-*.json"):
data = _recording_from_file(path)
results.append(data)
# Check test-specific directory first
test_dir = self._get_test_dir()
if test_dir.exists():
for path in test_dir.glob(f"models-{request_hash}-*.json"):
data = _recording_from_file(path)
results.append(data)
# Also check fallback directory
fallback_dir = self.base_dir / "recordings"
if fallback_dir.exists():
for path in fallback_dir.glob(f"models-{request_hash}-*.json"):
data = _recording_from_file(path)
results.append(data)
return results
@@ -231,6 +317,8 @@ def _recording_from_file(response_path) -> dict[str, Any]:
def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
"""Generate a digest from model identifiers for distinguishing different model sets."""
def _extract_model_identifiers():
"""Extract a stable set of identifiers for model-list endpoints.
@@ -253,7 +341,14 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]]) -> dict[str, Any] | None:
"""Return a single, unioned recording for supported model-list endpoints."""
"""Return a single, unioned recording for supported model-list endpoints.
Merges multiple recordings with different model sets (from different servers) into
a single response containing all models.
"""
if not records:
return None
seen: dict[str, dict[str, Any]] = {}
for rec in records:
body = rec["response"]["body"]
@@ -282,7 +377,10 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
global _current_mode, _current_storage
if _current_mode == InferenceMode.LIVE or _current_storage is None:
mode = _current_mode
storage = _current_storage
if mode == InferenceMode.LIVE or storage is None:
if endpoint == "/v1/models":
return original_method(self, *args, **kwargs)
else:
@@ -313,13 +411,16 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
request_hash = normalize_request(method, url, headers, body)
if _current_mode == InferenceMode.REPLAY:
# Special handling for model-list endpoints: return union of all responses
# Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
recording = None
if mode == InferenceMode.REPLAY or mode == InferenceMode.RECORD_IF_MISSING:
# Special handling for model-list endpoints: merge all recordings with this hash
if endpoint in ("/api/tags", "/v1/models"):
records = _current_storage._model_list_responses(request_hash[:12])
records = storage._model_list_responses(request_hash)
recording = _combine_model_list_responses(endpoint, records)
else:
recording = _current_storage.find_recording(request_hash)
recording = storage.find_recording(request_hash)
if recording:
response_body = recording["response"]["body"]
@@ -332,7 +433,8 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
return replay_stream()
else:
return response_body
else:
elif mode == InferenceMode.REPLAY:
# REPLAY mode requires recording to exist
raise RuntimeError(
f"No recorded response found for request hash: {request_hash}\n"
f"Request: {method} {url} {body}\n"
@@ -340,7 +442,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)
elif _current_mode == InferenceMode.RECORD:
if mode == InferenceMode.RECORD or (mode == InferenceMode.RECORD_IF_MISSING and not recording):
if endpoint == "/v1/models":
response = original_method(self, *args, **kwargs)
else:
@@ -371,7 +473,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
# Store the recording immediately
response_data = {"body": chunks, "is_streaming": True}
_current_storage.store_recording(request_hash, request_data, response_data)
storage.store_recording(request_hash, request_data, response_data)
# Return a generator that replays the stored chunks
async def replay_recorded_stream():
@@ -381,11 +483,11 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
return replay_recorded_stream()
else:
response_data = {"body": response, "is_streaming": False}
_current_storage.store_recording(request_hash, request_data, response_data)
storage.store_recording(request_hash, request_data, response_data)
return response
else:
raise AssertionError(f"Invalid mode: {_current_mode}")
raise AssertionError(f"Invalid mode: {mode}")
def patch_inference_clients():
@@ -526,9 +628,9 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
try:
_current_mode = mode
if mode in ["record", "replay"]:
if mode in ["record", "replay", "record-if-missing"]:
if storage_dir is None:
raise ValueError("storage_dir is required for record and replay modes")
raise ValueError("storage_dir is required for record, replay, and record-if-missing modes")
_current_storage = ResponseStorage(Path(storage_dir))
patch_inference_clients()
@@ -536,7 +638,7 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
finally:
# Restore previous state
if mode in ["record", "replay"]:
if mode in ["record", "replay", "record-if-missing"]:
unpatch_inference_clients()
_current_mode = prev_mode
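
For reference, the on-disk layout these changes produce looks roughly like this, assuming the default `tests/integration/common` base directory (placeholders stand in for the full request hash and models digest):

```
tests/integration/inference/recordings/        # test-scoped: "recordings/" next to the test file's directory
    <full-request-hash>.json                    # {"test_id": ..., "request": ..., "response": ...}
    models-<full-request-hash>-<digest>.json    # model-list variants, unioned at replay time
tests/integration/common/recordings/            # base-dir fallback for session-level recordings
    <full-request-hash>.json
```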