feat(tests): introduce inference record/replay to increase test reliability (#2941)

Implements a comprehensive recording and replay system for inference API
calls that eliminates dependency on online inference providers during
testing. The system treats inference as deterministic by recording real
API responses and replaying them in subsequent test runs. Applies to
OpenAI clients (which should cover many inference requests) as well as
Ollama AsyncClient.

For storing, we use a hybrid system: Sqlite for fast lookups and JSON
files for easy greppability / debuggability.

As expected, tests become much much faster (more than 3x in just
inference testing.)

```bash
LLAMA_STACK_TEST_INFERENCE_MODE=record LLAMA_STACK_TEST_RECORDING_DIR=<...> \
  uv run pytest -s -v tests/integration/inference \
  --stack-config=starter \
  -k "not( builtin_tool or safety_with_image or code_interpreter or test_rag )" \
  --text-model="ollama/llama3.2:3b-instruct-fp16" \
  --embedding-model=sentence-transformers/all-MiniLM-L6-v2
```

```bash
LLAMA_STACK_TEST_INFERENCE_MODE=replay LLAMA_STACK_TEST_RECORDING_DIR=<...> \
  uv run pytest -s -v tests/integration/inference \
  --stack-config=starter \
  -k "not( builtin_tool or safety_with_image or code_interpreter or test_rag )" \
  --text-model="ollama/llama3.2:3b-instruct-fp16" \
  --embedding-model=sentence-transformers/all-MiniLM-L6-v2
```

- `LLAMA_STACK_TEST_INFERENCE_MODE`: `live` (default), `record`, or
`replay`
- `LLAMA_STACK_TEST_RECORDING_DIR`: Storage location (must be specified
for record or replay modes)
This commit is contained in:
Ashwin Bharambe 2025-07-29 12:41:31 -07:00 committed by GitHub
parent abf1d6a703
commit 08b4a1deb3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 9880 additions and 2 deletions

View file

@ -79,11 +79,9 @@ class InferenceRouter(Inference):
async def initialize(self) -> None:
logger.debug("InferenceRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("InferenceRouter.shutdown")
pass
async def register_model(
self,

View file

@ -94,6 +94,7 @@ RESOURCES = [
REGISTRY_REFRESH_INTERVAL_SECONDS = 300
REGISTRY_REFRESH_TASK = None
TEST_RECORDING_CONTEXT = None
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
@ -307,6 +308,15 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
async def construct_stack(
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> dict[Api, Any]:
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording
global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_inference_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
policy = run_config.server.auth.access_policy if run_config.server.auth else []
impls = await resolve_impls(
@ -352,6 +362,13 @@ async def shutdown_stack(impls: dict[Api, Any]):
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
global TEST_RECORDING_CONTEXT
if TEST_RECORDING_CONTEXT:
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
REGISTRY_REFRESH_TASK.cancel()