Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-06 04:34:57 +00:00
feat(tests): implement test isolation for inference recordings (#3681)
Uses test_id in request hashes and test-scoped subdirectories to prevent cross-test contamination. Model-list endpoints exclude test_id so that recordings from different servers can be merged. Additionally, this PR adds a `record-if-missing` mode, which replays existing recordings and records only requests that have no recording yet; we will use it instead of `record`, which re-records everything.

🤖 Co-authored with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
Parent: f176196fba
Commit: 045a0c1d57

428 changed files with 85345 additions and 104330 deletions
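The net behavior of the four modes, once a recording lookup has been attempted, is summarized by the sketch below. This is an illustration of the flow described above, not the module's API; `dispatch`, `call_live`, and `store` are invented names.

```python
# Illustrative sketch of the mode semantics in this PR (invented names).
def dispatch(mode: str, recording, call_live, store):
    if mode == "live":
        return call_live()  # never touches storage
    if recording is not None and mode in ("replay", "record-if-missing"):
        return recording  # replay the stored response
    if mode == "replay":
        raise RuntimeError("no recorded response found")
    # "record" always re-records; "record-if-missing" records only on a miss
    response = call_live()
    store(response)
    return response
```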
@@ -22,10 +22,18 @@ from llama_stack.log import get_logger
 logger = get_logger(__name__, category="testing")
 
 # Global state for the recording system
+# Note: Using module globals instead of ContextVars because the session-scoped
+# client initialization happens in one async context, but tests run in different
+# contexts, and we need the mode/storage to persist across all contexts.
 _current_mode: str | None = None
 _current_storage: ResponseStorage | None = None
 _original_methods: dict[str, Any] = {}
 
+# Test context uses ContextVar since it changes per-test and needs async isolation
+from contextvars import ContextVar
+
+_test_context: ContextVar[str | None] = ContextVar("_test_context", default=None)
+
 from openai.types.completion_choice import CompletionChoice
 
 # update the "finish_reason" field, since its type definition is wrong (no None is accepted)
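The comment in this hunk draws a line between two kinds of state: module globals for the session-wide mode and storage, and a `ContextVar` for the per-test id. A minimal, runnable demo (not from the PR) of why a `ContextVar` gives each async task its own view of the test id:

```python
import asyncio
from contextvars import ContextVar

test_ctx: ContextVar[str | None] = ContextVar("test_ctx", default=None)

async def run_test(test_id: str) -> str | None:
    test_ctx.set(test_id)   # visible only in this task's context copy
    await asyncio.sleep(0)  # yield so the two "tests" interleave
    return test_ctx.get()   # still this task's own value

async def main() -> None:
    # Each task created by gather() runs in a copy of the current context,
    # so concurrent tests cannot clobber each other's test id.
    assert await asyncio.gather(run_test("t1"), run_test("t2")) == ["t1", "t2"]

asyncio.run(main())
```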
@@ -33,22 +41,38 @@ CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "len
 CompletionChoice.model_rebuild()
 
 REPO_ROOT = Path(__file__).parent.parent.parent
-DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/recordings"
+DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/common"
 
 
 class InferenceMode(StrEnum):
     LIVE = "live"
     RECORD = "record"
     REPLAY = "replay"
+    RECORD_IF_MISSING = "record-if-missing"
 
 
 def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
-    """Create a normalized hash of the request for consistent matching."""
+    """Create a normalized hash of the request for consistent matching.
+
+    Includes test_id from context to ensure test isolation - identical requests
+    from different tests will have different hashes.
+
+    Exception: Model list endpoints (/v1/models, /api/tags) exclude test_id since
+    they are infrastructure/shared and need to work across session setup and tests.
+    """
     # Extract just the endpoint path
     from urllib.parse import urlparse
 
     parsed = urlparse(url)
-    normalized = {"method": method.upper(), "endpoint": parsed.path, "body": body}
+    normalized: dict[str, Any] = {
+        "method": method.upper(),
+        "endpoint": parsed.path,
+        "body": body,
+    }
+
+    # Include test_id for isolation, except for shared infrastructure endpoints
+    if parsed.path not in ("/api/tags", "/v1/models"):
+        normalized["test_id"] = _test_context.get()
 
     # Create hash - sort_keys=True ensures deterministic ordering
     normalized_json = json.dumps(normalized, sort_keys=True)
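The isolation property can be reproduced standalone. The sketch below mirrors the normalization in this hunk; the hash function applied to `normalized_json` sits outside the hunk, so the use of sha256 here is an assumption:

```python
import hashlib
import json

def toy_hash(endpoint: str, body: dict, test_id: str | None) -> str:
    normalized: dict = {"method": "POST", "endpoint": endpoint, "body": body}
    if endpoint not in ("/api/tags", "/v1/models"):
        normalized["test_id"] = test_id
    # sha256 is an assumption; the real hash function is outside this hunk
    return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()

body = {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "hi"}]}
# Identical requests from different tests now hash differently...
assert toy_hash("/v1/chat/completions", body, "t.py::a") != toy_hash("/v1/chat/completions", body, "t.py::b")
# ...while model-list lookups stay shared across session setup and tests.
assert toy_hash("/v1/models", {}, "t.py::a") == toy_hash("/v1/models", {}, "t.py::b")
```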
@@ -67,7 +91,11 @@ def setup_inference_recording():
     Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
 
     Two environment variables are supported:
-    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'. Default is 'replay'.
+    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'record-if-missing'. Default is 'replay'.
+      - 'live': Make all requests live without recording
+      - 'record': Record all requests (overwrites existing recordings)
+      - 'replay': Use only recorded responses (fails if recording not found)
+      - 'record-if-missing': Use recorded responses when available, record new ones when not found
     - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'.
 
     The recordings are stored as JSON files.
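For example, a helper script could drive an integration run in the new mode through the two documented environment variables (the test path here is illustrative):

```python
import os
import subprocess

env = os.environ.copy()
env["LLAMA_STACK_TEST_INFERENCE_MODE"] = "record-if-missing"
env["LLAMA_STACK_TEST_RECORDING_DIR"] = "tests/integration/recordings"

# Requests with an existing recording are replayed; the rest are recorded.
subprocess.run(["pytest", "tests/integration/inference"], env=env, check=False)
```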
@@ -154,21 +182,43 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
 class ResponseStorage:
     """Handles SQLite index + JSON file storage/retrieval for inference recordings."""
 
-    def __init__(self, test_dir: Path):
-        self.test_dir = test_dir
-        self.responses_dir = self.test_dir / "responses"
+    def __init__(self, base_dir: Path):
+        self.base_dir = base_dir
+        # Don't create responses_dir here - determine it per-test at runtime
 
-        self._ensure_directories()
+    def _get_test_dir(self) -> Path:
+        """Get the recordings directory in the test file's parent directory.
+
+        For test at "tests/integration/inference/test_foo.py::test_bar",
+        returns "tests/integration/inference/recordings/".
+        """
+        test_id = _test_context.get()
+        if test_id:
+            # Extract the directory path from the test nodeid
+            # e.g., "tests/integration/inference/test_basic.py::test_foo[params]"
+            # -> get "tests/integration/inference"
+            test_file = test_id.split("::")[0]  # Remove test function part
+            test_dir = Path(test_file).parent  # Get parent directory
+
+            # Put recordings in a "recordings" subdirectory of the test's parent dir
+            # e.g., "tests/integration/inference" -> "tests/integration/inference/recordings"
+            return test_dir / "recordings"
+        else:
+            # Fallback for non-test contexts
+            return self.base_dir / "recordings"
 
     def _ensure_directories(self):
-        self.test_dir.mkdir(parents=True, exist_ok=True)
-        self.responses_dir.mkdir(exist_ok=True)
+        """Ensure test-specific directories exist."""
+        test_dir = self._get_test_dir()
+        test_dir.mkdir(parents=True, exist_ok=True)
+        return test_dir
 
     def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
         """Store a request/response pair."""
-        # Generate unique response filename
-        short_hash = request_hash[:12]
-        response_file = f"{short_hash}.json"
+        responses_dir = self._ensure_directories()
+
+        # Use FULL hash (not truncated)
+        response_file = f"{request_hash}.json"
 
         # Serialize response body if needed
         serialized_response = dict(response)
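The nodeid-to-directory mapping implemented by `_get_test_dir` can be checked in isolation:

```python
from pathlib import Path

test_id = "tests/integration/inference/test_basic.py::test_foo[params]"
test_file = test_id.split("::")[0]           # "tests/integration/inference/test_basic.py"
recordings = Path(test_file).parent / "recordings"
print(recordings)                            # tests/integration/inference/recordings
```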
@@ -182,35 +232,71 @@ class ResponseStorage:
             # Handle single response
             serialized_response["body"] = _serialize_response(serialized_response["body"], request_hash)
 
-        # If this is an Ollama /api/tags recording, include models digest in filename to distinguish variants
+        # For model-list endpoints, include digest in filename to distinguish different model sets
         endpoint = request.get("endpoint")
         if endpoint in ("/api/tags", "/v1/models"):
             digest = _model_identifiers_digest(endpoint, response)
-            response_file = f"models-{short_hash}-{digest}.json"
+            response_file = f"models-{request_hash}-{digest}.json"
 
-        response_path = self.responses_dir / response_file
+        response_path = responses_dir / response_file
 
-        # Save response to JSON file
+        # Save response to JSON file with metadata
         with open(response_path, "w") as f:
-            json.dump({"request": request, "response": serialized_response}, f, indent=2)
+            json.dump(
+                {
+                    "test_id": _test_context.get(),  # Include for debugging
+                    "request": request,
+                    "response": serialized_response,
+                },
+                f,
+                indent=2,
+            )
             f.write("\n")
             f.flush()
 
     def find_recording(self, request_hash: str) -> dict[str, Any] | None:
-        """Find a recorded response by request hash."""
-        response_file = f"{request_hash[:12]}.json"
-        response_path = self.responses_dir / response_file
+        """Find a recorded response by request hash.
 
-        if not response_path.exists():
-            return None
+        Uses fallback: first checks test-specific dir, then falls back to base recordings dir.
+        This handles cases where recordings happen during session setup (no test context) but
+        are requested during tests (with test context).
+        """
+        response_file = f"{request_hash}.json"
 
-        return _recording_from_file(response_path)
+        # Try test-specific directory first
+        test_dir = self._get_test_dir()
+        response_path = test_dir / response_file
 
-    def _model_list_responses(self, short_hash: str) -> list[dict[str, Any]]:
+        if response_path.exists():
+            return _recording_from_file(response_path)
+
+        # Fallback to base recordings directory (for session-level recordings)
+        fallback_dir = self.base_dir / "recordings"
+        fallback_path = fallback_dir / response_file
+
+        if fallback_path.exists():
+            return _recording_from_file(fallback_path)
+
+        return None
+
+    def _model_list_responses(self, request_hash: str) -> list[dict[str, Any]]:
         """Find all model-list recordings with the given hash (different digests)."""
         results: list[dict[str, Any]] = []
-        for path in self.responses_dir.glob(f"models-{short_hash}-*.json"):
-            data = _recording_from_file(path)
-            results.append(data)
+
+        # Check test-specific directory first
+        test_dir = self._get_test_dir()
+        if test_dir.exists():
+            for path in test_dir.glob(f"models-{request_hash}-*.json"):
+                data = _recording_from_file(path)
+                results.append(data)
+
+        # Also check fallback directory
+        fallback_dir = self.base_dir / "recordings"
+        if fallback_dir.exists():
+            for path in fallback_dir.glob(f"models-{request_hash}-*.json"):
+                data = _recording_from_file(path)
+                results.append(data)
 
         return results
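So for a given hash, `find_recording` probes exactly two locations in a fixed order. A small helper expressing that order (the helper name is invented for illustration):

```python
from pathlib import Path

def candidate_paths(test_dir: Path, base_dir: Path, request_hash: str) -> list[Path]:
    # Order matters: the per-test directory wins; the shared base directory
    # only serves recordings made during session setup, when no test context is set.
    name = f"{request_hash}.json"
    return [test_dir / name, base_dir / "recordings" / name]
```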
@@ -231,6 +317,8 @@ def _recording_from_file(response_path) -> dict[str, Any]:
 
 
 def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
+    """Generate a digest from model identifiers for distinguishing different model sets."""
+
     def _extract_model_identifiers():
         """Extract a stable set of identifiers for model-list endpoints.
 
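The body of `_model_identifiers_digest` is unchanged by this hunk and not shown. Purely as an assumption for illustration, a stable short digest over a sorted identifier set could look like this; the real algorithm and digest length may differ:

```python
import hashlib
import json

def identifiers_digest(identifiers: list[str]) -> str:
    # Hypothetical: sort for stability, then truncate a hex digest.
    return hashlib.sha256(json.dumps(sorted(identifiers)).encode()).hexdigest()[:8]
```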
@@ -253,7 +341,14 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
 
 
 def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]]) -> dict[str, Any] | None:
-    """Return a single, unioned recording for supported model-list endpoints."""
+    """Return a single, unioned recording for supported model-list endpoints.
+
+    Merges multiple recordings with different model sets (from different servers) into
+    a single response containing all models.
+    """
+    if not records:
+        return None
+
     seen: dict[str, dict[str, Any]] = {}
     for rec in records:
         body = rec["response"]["body"]
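The union is keyed on a per-model identifier, with the first occurrence winning. A self-contained sketch of that merge; treating the body as a list of dicts keyed by "id" is an assumption for OpenAI-style /v1/models payloads:

```python
from typing import Any

def union_models(records: list[dict[str, Any]]) -> list[dict[str, Any]]:
    # Merge model lists from several recordings; duplicate identifiers are kept
    # once, so recordings taken against different servers combine cleanly.
    seen: dict[str, dict[str, Any]] = {}
    for rec in records:
        for model in rec["response"]["body"]:
            seen.setdefault(model["id"], model)
    return list(seen.values())
```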
@@ -282,7 +377,10 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
 async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
     global _current_mode, _current_storage
 
-    if _current_mode == InferenceMode.LIVE or _current_storage is None:
+    mode = _current_mode
+    storage = _current_storage
+
+    if mode == InferenceMode.LIVE or storage is None:
         if endpoint == "/v1/models":
             return original_method(self, *args, **kwargs)
         else:
@@ -313,13 +411,16 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
 
     request_hash = normalize_request(method, url, headers, body)
 
-    if _current_mode == InferenceMode.REPLAY:
-        # Special handling for model-list endpoints: return union of all responses
+    # Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
+    recording = None
+    if mode == InferenceMode.REPLAY or mode == InferenceMode.RECORD_IF_MISSING:
+        # Special handling for model-list endpoints: merge all recordings with this hash
         if endpoint in ("/api/tags", "/v1/models"):
-            records = _current_storage._model_list_responses(request_hash[:12])
+            records = storage._model_list_responses(request_hash)
             recording = _combine_model_list_responses(endpoint, records)
         else:
-            recording = _current_storage.find_recording(request_hash)
+            recording = storage.find_recording(request_hash)
+
         if recording:
             response_body = recording["response"]["body"]
 
@@ -332,7 +433,8 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
                 return replay_stream()
             else:
                 return response_body
-        else:
+        elif mode == InferenceMode.REPLAY:
+            # REPLAY mode requires recording to exist
             raise RuntimeError(
                 f"No recorded response found for request hash: {request_hash}\n"
                 f"Request: {method} {url} {body}\n"
@@ -340,7 +442,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
                 f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
             )
 
-    elif _current_mode == InferenceMode.RECORD:
+    if mode == InferenceMode.RECORD or (mode == InferenceMode.RECORD_IF_MISSING and not recording):
         if endpoint == "/v1/models":
             response = original_method(self, *args, **kwargs)
         else:
@@ -371,7 +473,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
 
             # Store the recording immediately
             response_data = {"body": chunks, "is_streaming": True}
-            _current_storage.store_recording(request_hash, request_data, response_data)
+            storage.store_recording(request_hash, request_data, response_data)
 
             # Return a generator that replays the stored chunks
             async def replay_recorded_stream():
@@ -381,11 +483,11 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
             return replay_recorded_stream()
         else:
             response_data = {"body": response, "is_streaming": False}
-            _current_storage.store_recording(request_hash, request_data, response_data)
+            storage.store_recording(request_hash, request_data, response_data)
             return response
 
     else:
-        raise AssertionError(f"Invalid mode: {_current_mode}")
+        raise AssertionError(f"Invalid mode: {mode}")
 
 
 def patch_inference_clients():
@@ -526,9 +628,9 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
     try:
         _current_mode = mode
 
-        if mode in ["record", "replay"]:
+        if mode in ["record", "replay", "record-if-missing"]:
            if storage_dir is None:
-                raise ValueError("storage_dir is required for record and replay modes")
+                raise ValueError("storage_dir is required for record, replay, and record-if-missing modes")
            _current_storage = ResponseStorage(Path(storage_dir))
            patch_inference_clients()
 
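Programmatic use goes through the `inference_recording` context manager, whose signature appears in the hunk header above; with the new mode, `storage_dir` is required just as for `record` and `replay`. A usage sketch, assuming the module lives at `llama_stack/testing/inference_recorder.py`:

```python
from pathlib import Path

from llama_stack.testing.inference_recorder import inference_recording

with inference_recording(mode="record-if-missing", storage_dir=Path("tests/integration/recordings")):
    # OpenAI/Ollama client calls made here replay existing recordings and
    # record any request that has no recording yet.
    ...
```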
@@ -536,7 +638,7 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
 
     finally:
         # Restore previous state
-        if mode in ["record", "replay"]:
+        if mode in ["record", "replay", "record-if-missing"]:
             unpatch_inference_clients()
 
         _current_mode = prev_mode