Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-22 22:49:43 +00:00)
feat(tests): introduce inference record/replay to increase test reliability
Implements a comprehensive recording and replay system for inference API calls that eliminates the dependency on online inference providers during testing. The system treats inference as deterministic by recording real API responses and replaying them in subsequent test runs. The patching applies to the OpenAI client (which should cover many inference requests) as well as the OpenAI AsyncClient. For storage, we use a hybrid system: SQLite for fast lookups and JSON files for easy greppability / debuggability. As expected, tests become much faster (more than 3x for inference testing alone).

Record mode:

```bash
LLAMA_STACK_INFERENCE_MODE=record uv run pytest -s -v tests/integration/inference \
  --stack-config=starter \
  -k "not( builtin_tool or safety_with_image or code_interpreter or test_rag )" \
  --text-model="ollama/llama3.2:3b-instruct-fp16" \
  --embedding-model=sentence-transformers/all-MiniLM-L6-v2
```

Replay mode:

```bash
LLAMA_STACK_INFERENCE_MODE=replay LLAMA_STACK_TEST_ID=<test_id> uv run pytest -s -v tests/integration/inference \
  --stack-config=starter \
  -k "not( builtin_tool or safety_with_image or code_interpreter or test_rag )" \
  --text-model="ollama/llama3.2:3b-instruct-fp16" \
  --embedding-model=sentence-transformers/all-MiniLM-L6-v2
```

Environment variables:

- `LLAMA_STACK_INFERENCE_MODE`: `live` (default), `record`, or `replay`
- `LLAMA_STACK_TEST_ID`: specify the recording directory (auto-generated if not set)
- `LLAMA_STACK_RECORDING_DIR`: storage location (default: `~/.llama/recordings/`)
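To make the hybrid-storage idea concrete, here is a minimal sketch of a record/replay cache keyed on `LLAMA_STACK_INFERENCE_MODE`. It is not the code from this commit: names such as `RecordingStore`, `record_or_replay`, `index.sqlite`, and the hashing scheme are illustrative assumptions; only the environment variables, the default `~/.llama/recordings/` directory, and the SQLite-plus-JSON layout come from the message above.

```python
import hashlib
import json
import os
import sqlite3
from pathlib import Path


class RecordingStore:
    """Hybrid store: SQLite index for fast lookups, one JSON file per response for grepping."""

    def __init__(self, base_dir: str | None = None):
        self.dir = Path(base_dir or os.environ.get("LLAMA_STACK_RECORDING_DIR", "~/.llama/recordings")).expanduser()
        self.dir.mkdir(parents=True, exist_ok=True)
        self.db = sqlite3.connect(str(self.dir / "index.sqlite"))
        self.db.execute("CREATE TABLE IF NOT EXISTS recordings (key TEXT PRIMARY KEY, path TEXT)")

    def _key(self, endpoint: str, body: dict) -> str:
        # Canonicalize the request so identical calls always map to the same recording.
        payload = json.dumps({"endpoint": endpoint, "body": body}, sort_keys=True)
        return hashlib.sha256(payload.encode()).hexdigest()

    def save(self, endpoint: str, body: dict, response: dict) -> None:
        key = self._key(endpoint, body)
        path = self.dir / f"{key}.json"
        path.write_text(json.dumps({"request": body, "response": response}, indent=2))
        self.db.execute("INSERT OR REPLACE INTO recordings VALUES (?, ?)", (key, str(path)))
        self.db.commit()

    def load(self, endpoint: str, body: dict) -> dict | None:
        row = self.db.execute("SELECT path FROM recordings WHERE key = ?", (self._key(endpoint, body),)).fetchone()
        return None if row is None else json.loads(Path(row[0]).read_text())["response"]


def record_or_replay(store: RecordingStore, endpoint: str, body: dict, call_live):
    """Dispatch one inference call according to LLAMA_STACK_INFERENCE_MODE."""
    mode = os.environ.get("LLAMA_STACK_INFERENCE_MODE", "live").lower()
    if mode == "replay":
        cached = store.load(endpoint, body)
        if cached is None:
            raise RuntimeError(f"No recording for {endpoint}; re-run the test in record mode")
        return cached
    response = call_live()  # hit the real provider
    if mode == "record":
        store.save(endpoint, body, response)
    return response
```

Failing fast on a missing recording in replay mode is what keeps the suite deterministic and independent of live providers; record mode still hits the provider and captures the response as a side effect.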
This commit is contained in: parent fee365b71e, commit e59c13f2b8.
4 changed files with 997 additions and 58 deletions
@@ -185,70 +185,95 @@ def llama_stack_client(request, provider_data):
if not config:
raise ValueError("You must specify either --stack-config or LLAMA_STACK_CONFIG")

# Handle server:<config_name> format or server:<config_name>:<port>
if config.startswith("server:"):
parts = config.split(":")
config_name = parts[1]
port = int(parts[2]) if len(parts) > 2 else int(os.environ.get("LLAMA_STACK_PORT", DEFAULT_PORT))
base_url = f"http://localhost:{port}"
# Set up inference recording if enabled
inference_mode = os.environ.get("LLAMA_STACK_INFERENCE_MODE", "live").lower()
recording_context = None

# Check if port is available
if is_port_available(port):
print(f"Starting llama stack server with config '{config_name}' on port {port}...")
if inference_mode in ["record", "replay"]:
from llama_stack.testing.inference_recorder import setup_inference_recording

# Start server
server_process = start_llama_stack_server(config_name)
recording_context = setup_inference_recording()
recording_context.__enter__()
print(f"Inference recording enabled: mode={inference_mode}")

# Wait for server to be ready
if not wait_for_server_ready(base_url, timeout=120, process=server_process):
print("Server failed to start within timeout")
server_process.terminate()
raise RuntimeError(
f"Server failed to start within timeout. Check that config '{config_name}' exists and is valid. "
f"See server.log for details."
)

print(f"Server is ready at {base_url}")

# Store process for potential cleanup (pytest will handle termination at session end)
request.session._llama_stack_server_process = server_process
else:
print(f"Port {port} is already in use, assuming server is already running...")

return LlamaStackClient(
base_url=base_url,
provider_data=provider_data,
timeout=int(os.environ.get("LLAMA_STACK_CLIENT_TIMEOUT", "30")),
)

# check if this looks like a URL using proper URL parsing
try:
parsed_url = urlparse(config)
if parsed_url.scheme and parsed_url.netloc:
return LlamaStackClient(
base_url=config,
# Handle server:<config_name> format or server:<config_name>:<port>
if config.startswith("server:"):
parts = config.split(":")
config_name = parts[1]
port = int(parts[2]) if len(parts) > 2 else int(os.environ.get("LLAMA_STACK_PORT", DEFAULT_PORT))
base_url = f"http://localhost:{port}"

# Check if port is available
if is_port_available(port):
print(f"Starting llama stack server with config '{config_name}' on port {port}...")

# Start server
server_process = start_llama_stack_server(config_name)

# Wait for server to be ready
if not wait_for_server_ready(base_url, timeout=120, process=server_process):
print("Server failed to start within timeout")
server_process.terminate()
raise RuntimeError(
f"Server failed to start within timeout. Check that config '{config_name}' exists and is valid. "
f"See server.log for details."
)

print(f"Server is ready at {base_url}")

# Store process for potential cleanup (pytest will handle termination at session end)
request.session._llama_stack_server_process = server_process
else:
print(f"Port {port} is already in use, assuming server is already running...")

client = LlamaStackClient(
base_url=base_url,
provider_data=provider_data,
timeout=int(os.environ.get("LLAMA_STACK_CLIENT_TIMEOUT", "30")),
)
else:
# check if this looks like a URL using proper URL parsing
try:
parsed_url = urlparse(config)
if parsed_url.scheme and parsed_url.netloc:
client = LlamaStackClient(
base_url=config,
provider_data=provider_data,
)
else:
raise ValueError("Not a URL")
except Exception as e:
# If URL parsing fails, treat as library config
if "=" in config:
run_config = run_config_from_adhoc_config_spec(config)
run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
with open(run_config_file.name, "w") as f:
yaml.dump(run_config.model_dump(), f)
config = run_config_file.name

client = LlamaStackAsLibraryClient(
config,
provider_data=provider_data,
skip_logger_removal=True,
)
if not client.initialize():
raise RuntimeError("Initialization failed") from e

# Store recording context for cleanup
if recording_context:
request.session._inference_recording_context = recording_context

return client

except Exception:
# If URL parsing fails, treat as non-URL config
pass

if "=" in config:
run_config = run_config_from_adhoc_config_spec(config)
run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
with open(run_config_file.name, "w") as f:
yaml.dump(run_config.model_dump(), f)
config = run_config_file.name

client = LlamaStackAsLibraryClient(
config,
provider_data=provider_data,
skip_logger_removal=True,
)
if not client.initialize():
raise RuntimeError("Initialization failed")

return client
# Clean up recording context on error
if recording_context:
try:
recording_context.__exit__(None, None, None)
except Exception as cleanup_error:
print(f"Warning: Error cleaning up recording context: {cleanup_error}")
raise


@pytest.fixture(scope="session")
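One detail worth calling out from the hunk above: the fixture enters the recording context manually (`recording_context.__enter__()`) and stashes it on `request.session`, because a `with` block would close the recorder as soon as `llama_stack_client()` returned; teardown happens later in the session-scoped cleanup fixture. The snippet below is a self-contained illustration of that enter-now/exit-later pattern using a stand-in recorder, not code from this commit.

```python
from contextlib import contextmanager


@contextmanager
def fake_recorder():
    # Stand-in for setup_inference_recording(); only illustrates the lifecycle.
    print("recording: on")
    try:
        yield
    finally:
        print("recording: off")


# Fixture setup: enter the context and keep the handle around ...
ctx = fake_recorder()
ctx.__enter__()

# ... the whole pytest session runs here ...

# Session teardown (the autouse cleanup fixture): close it explicitly.
ctx.__exit__(None, None, None)
```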
@@ -264,9 +289,20 @@ def compat_client(request):

@pytest.fixture(scope="session", autouse=True)
def cleanup_server_process(request):
"""Cleanup server process at the end of the test session."""
"""Cleanup server process and inference recording at the end of the test session."""
yield  # Run tests

# Clean up inference recording context
if hasattr(request.session, "_inference_recording_context"):
recording_context = request.session._inference_recording_context
if recording_context:
try:
print("Cleaning up inference recording context...")
recording_context.__exit__(None, None, None)
except Exception as e:
print(f"Error during inference recording cleanup: {e}")

# Clean up server process
if hasattr(request.session, "_llama_stack_server_process"):
server_process = request.session._llama_stack_server_process
if server_process:
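With these fixtures in place, record/replay stays invisible to the tests themselves. A hypothetical integration test might look like the sketch below; the `llama_stack_client` fixture is the one from this diff, while the `text_model_id` fixture name and the exact `chat_completion` call shape are assumptions based on the llama-stack client API, not code from this commit.

```python
# Hypothetical test: whether this hits a live provider, records responses, or
# replays them is decided entirely by LLAMA_STACK_INFERENCE_MODE, not by the test.
def test_chat_completion_smoke(llama_stack_client, text_model_id):
    response = llama_stack_client.inference.chat_completion(
        model_id=text_model_id,
        messages=[{"role": "user", "content": "Name a planet in our solar system."}],
        stream=False,
    )
    assert response.completion_message.content
```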