feat(tests): introduce inference record/replay to increase test reliability

Implements a comprehensive recording and replay system for inference API calls that eliminates dependency on online inference providers during testing. The system treats inference as deterministic by recording real API responses and replaying them in subsequent test runs. Applies to OpenAI clients (which should cover many inference requests) as well as OpenAI AsyncClient.

For storing, we use a hybrid system: Sqlite for fast lookups and JSON
files for easy greppability / debuggability.

As expected, tests become much much faster (more than 3x in just
inference testing.)

```bash
LLAMA_STACK_INFERENCE_MODE=record uv run pytest -s -v tests/integration/inference \
  --stack-config=starter \
  -k "not( builtin_tool or safety_with_image or code_interpreter or test_rag )" \
  --text-model="ollama/llama3.2:3b-instruct-fp16" \
  --embedding-model=sentence-transformers/all-MiniLM-L6-v2
```

```bash
LLAMA_STACK_INFERENCE_MODE=replay LLAMA_STACK_TEST_ID=<test_id> uv run pytest -s -v tests/integration/inference \
  --stack-config=starter \
  -k "not( builtin_tool or safety_with_image or code_interpreter or test_rag )" \
  --text-model="ollama/llama3.2:3b-instruct-fp16" \
  --embedding-model=sentence-transformers/all-MiniLM-L6-v2
```

- `LLAMA_STACK_INFERENCE_MODE`: `live` (default), `record`, or `replay`
- `LLAMA_STACK_TEST_ID`: Specify recording directory (auto-generated if not set)
- `LLAMA_STACK_RECORDING_DIR`: Storage location (default: ~/.llama/recordings/)
This commit is contained in:
Ashwin Bharambe 2025-07-28 17:03:39 -07:00
parent fee365b71e
commit e59c13f2b8
4 changed files with 997 additions and 58 deletions

View file

@ -185,70 +185,95 @@ def llama_stack_client(request, provider_data):
if not config:
raise ValueError("You must specify either --stack-config or LLAMA_STACK_CONFIG")
# Handle server:<config_name> format or server:<config_name>:<port>
if config.startswith("server:"):
parts = config.split(":")
config_name = parts[1]
port = int(parts[2]) if len(parts) > 2 else int(os.environ.get("LLAMA_STACK_PORT", DEFAULT_PORT))
base_url = f"http://localhost:{port}"
# Set up inference recording if enabled
inference_mode = os.environ.get("LLAMA_STACK_INFERENCE_MODE", "live").lower()
recording_context = None
# Check if port is available
if is_port_available(port):
print(f"Starting llama stack server with config '{config_name}' on port {port}...")
if inference_mode in ["record", "replay"]:
from llama_stack.testing.inference_recorder import setup_inference_recording
# Start server
server_process = start_llama_stack_server(config_name)
recording_context = setup_inference_recording()
recording_context.__enter__()
print(f"Inference recording enabled: mode={inference_mode}")
# Wait for server to be ready
if not wait_for_server_ready(base_url, timeout=120, process=server_process):
print("Server failed to start within timeout")
server_process.terminate()
raise RuntimeError(
f"Server failed to start within timeout. Check that config '{config_name}' exists and is valid. "
f"See server.log for details."
)
print(f"Server is ready at {base_url}")
# Store process for potential cleanup (pytest will handle termination at session end)
request.session._llama_stack_server_process = server_process
else:
print(f"Port {port} is already in use, assuming server is already running...")
return LlamaStackClient(
base_url=base_url,
provider_data=provider_data,
timeout=int(os.environ.get("LLAMA_STACK_CLIENT_TIMEOUT", "30")),
)
# check if this looks like a URL using proper URL parsing
try:
parsed_url = urlparse(config)
if parsed_url.scheme and parsed_url.netloc:
return LlamaStackClient(
base_url=config,
# Handle server:<config_name> format or server:<config_name>:<port>
if config.startswith("server:"):
parts = config.split(":")
config_name = parts[1]
port = int(parts[2]) if len(parts) > 2 else int(os.environ.get("LLAMA_STACK_PORT", DEFAULT_PORT))
base_url = f"http://localhost:{port}"
# Check if port is available
if is_port_available(port):
print(f"Starting llama stack server with config '{config_name}' on port {port}...")
# Start server
server_process = start_llama_stack_server(config_name)
# Wait for server to be ready
if not wait_for_server_ready(base_url, timeout=120, process=server_process):
print("Server failed to start within timeout")
server_process.terminate()
raise RuntimeError(
f"Server failed to start within timeout. Check that config '{config_name}' exists and is valid. "
f"See server.log for details."
)
print(f"Server is ready at {base_url}")
# Store process for potential cleanup (pytest will handle termination at session end)
request.session._llama_stack_server_process = server_process
else:
print(f"Port {port} is already in use, assuming server is already running...")
client = LlamaStackClient(
base_url=base_url,
provider_data=provider_data,
timeout=int(os.environ.get("LLAMA_STACK_CLIENT_TIMEOUT", "30")),
)
else:
# check if this looks like a URL using proper URL parsing
try:
parsed_url = urlparse(config)
if parsed_url.scheme and parsed_url.netloc:
client = LlamaStackClient(
base_url=config,
provider_data=provider_data,
)
else:
raise ValueError("Not a URL")
except Exception as e:
# If URL parsing fails, treat as library config
if "=" in config:
run_config = run_config_from_adhoc_config_spec(config)
run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
with open(run_config_file.name, "w") as f:
yaml.dump(run_config.model_dump(), f)
config = run_config_file.name
client = LlamaStackAsLibraryClient(
config,
provider_data=provider_data,
skip_logger_removal=True,
)
if not client.initialize():
raise RuntimeError("Initialization failed") from e
# Store recording context for cleanup
if recording_context:
request.session._inference_recording_context = recording_context
return client
except Exception:
# If URL parsing fails, treat as non-URL config
pass
if "=" in config:
run_config = run_config_from_adhoc_config_spec(config)
run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
with open(run_config_file.name, "w") as f:
yaml.dump(run_config.model_dump(), f)
config = run_config_file.name
client = LlamaStackAsLibraryClient(
config,
provider_data=provider_data,
skip_logger_removal=True,
)
if not client.initialize():
raise RuntimeError("Initialization failed")
return client
# Clean up recording context on error
if recording_context:
try:
recording_context.__exit__(None, None, None)
except Exception as cleanup_error:
print(f"Warning: Error cleaning up recording context: {cleanup_error}")
raise
@pytest.fixture(scope="session")
@ -264,9 +289,20 @@ def compat_client(request):
@pytest.fixture(scope="session", autouse=True)
def cleanup_server_process(request):
"""Cleanup server process at the end of the test session."""
"""Cleanup server process and inference recording at the end of the test session."""
yield # Run tests
# Clean up inference recording context
if hasattr(request.session, "_inference_recording_context"):
recording_context = request.session._inference_recording_context
if recording_context:
try:
print("Cleaning up inference recording context...")
recording_context.__exit__(None, None, None)
except Exception as e:
print(f"Error during inference recording cleanup: {e}")
# Clean up server process
if hasattr(request.session, "_llama_stack_server_process"):
server_process = request.session._llama_stack_server_process
if server_process: