feat(tests): introduce inference record/replay to increase test reliability (#2941)

Implements a comprehensive recording and replay system for inference API
calls that eliminates dependency on online inference providers during
testing. The system treats inference as deterministic by recording real
API responses and replaying them in subsequent test runs. Applies to
OpenAI clients (which should cover many inference requests) as well as
Ollama AsyncClient.

For storing, we use a hybrid system: Sqlite for fast lookups and JSON
files for easy greppability / debuggability.

As expected, tests become much much faster (more than 3x in just
inference testing.)

```bash
LLAMA_STACK_TEST_INFERENCE_MODE=record LLAMA_STACK_TEST_RECORDING_DIR=<...> \
  uv run pytest -s -v tests/integration/inference \
  --stack-config=starter \
  -k "not( builtin_tool or safety_with_image or code_interpreter or test_rag )" \
  --text-model="ollama/llama3.2:3b-instruct-fp16" \
  --embedding-model=sentence-transformers/all-MiniLM-L6-v2
```

```bash
LLAMA_STACK_TEST_INFERENCE_MODE=replay LLAMA_STACK_TEST_RECORDING_DIR=<...> \
  uv run pytest -s -v tests/integration/inference \
  --stack-config=starter \
  -k "not( builtin_tool or safety_with_image or code_interpreter or test_rag )" \
  --text-model="ollama/llama3.2:3b-instruct-fp16" \
  --embedding-model=sentence-transformers/all-MiniLM-L6-v2
```

- `LLAMA_STACK_TEST_INFERENCE_MODE`: `live` (default), `record`, or
`replay`
- `LLAMA_STACK_TEST_RECORDING_DIR`: Storage location (must be specified
for record or replay modes)
This commit is contained in:
Ashwin Bharambe 2025-07-29 12:41:31 -07:00 committed by GitHub
parent abf1d6a703
commit 08b4a1deb3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 9880 additions and 2 deletions

View file

@ -0,0 +1,291 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sqlite3
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
from openai import AsyncOpenAI
# Import the real Pydantic response types instead of using Mocks
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.testing.inference_recorder import (
InferenceMode,
ResponseStorage,
inference_recording,
normalize_request,
)
@pytest.fixture
def temp_storage_dir():
"""Create a temporary directory for test recordings."""
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)
@pytest.fixture
def real_openai_chat_response():
"""Real OpenAI chat completion response using proper Pydantic objects."""
return OpenAIChatCompletion(
id="chatcmpl-test123",
choices=[
OpenAIChoice(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant", content="Hello! I'm doing well, thank you for asking."
),
finish_reason="stop",
)
],
created=1234567890,
model="llama3.2:3b",
)
@pytest.fixture
def real_embeddings_response():
"""Real OpenAI embeddings response using proper Pydantic objects."""
return OpenAIEmbeddingsResponse(
object="list",
data=[
OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
],
model="nomic-embed-text",
usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
)
class TestInferenceRecording:
"""Test the inference recording system."""
def test_request_normalization(self):
"""Test that request normalization produces consistent hashes."""
# Test basic normalization
hash1 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)
# Same request should produce same hash
hash2 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)
assert hash1 == hash2
# Different content should produce different hash
hash3 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{
"model": "llama3.2:3b",
"messages": [{"role": "user", "content": "Different message"}],
"temperature": 0.7,
},
)
assert hash1 != hash3
def test_request_normalization_edge_cases(self):
"""Test request normalization is precise about request content."""
# Test that different whitespace produces different hashes (no normalization)
hash1 = normalize_request(
"POST",
"http://test/v1/chat/completions",
{},
{"messages": [{"role": "user", "content": "Hello world\n\n"}]},
)
hash2 = normalize_request(
"POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
)
assert hash1 != hash2 # Different whitespace should produce different hashes
# Test that different float precision produces different hashes (no rounding)
hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
assert hash3 != hash4 # Different precision should produce different hashes
def test_response_storage(self, temp_storage_dir):
"""Test the ResponseStorage class."""
temp_storage_dir = temp_storage_dir / "test_response_storage"
storage = ResponseStorage(temp_storage_dir)
# Test directory creation
assert storage.test_dir.exists()
assert storage.responses_dir.exists()
assert storage.db_path.exists()
# Test storing and retrieving a recording
request_hash = "test_hash_123"
request_data = {
"method": "POST",
"url": "http://localhost:11434/v1/chat/completions",
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b",
}
response_data = {"body": {"content": "test response"}, "is_streaming": False}
storage.store_recording(request_hash, request_data, response_data)
# Verify SQLite record
with sqlite3.connect(storage.db_path) as conn:
result = conn.execute("SELECT * FROM recordings WHERE request_hash = ?", (request_hash,)).fetchone()
assert result is not None
assert result[0] == request_hash # request_hash
assert result[2] == "/v1/chat/completions" # endpoint
assert result[3] == "llama3.2:3b" # model
# Verify file storage and retrieval
retrieved = storage.find_recording(request_hash)
assert retrieved is not None
assert retrieved["request"]["model"] == "llama3.2:3b"
assert retrieved["response"]["body"]["content"] == "test response"
async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that recording mode captures and stores responses."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
temp_storage_dir = temp_storage_dir / "test_recording_mode"
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Verify the response was returned correctly
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
# Verify recording was stored
storage = ResponseStorage(temp_storage_dir)
with sqlite3.connect(storage.db_path) as conn:
recordings = conn.execute("SELECT COUNT(*) FROM recordings").fetchone()[0]
assert recordings == 1
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that replay mode returns stored responses without making real calls."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
temp_storage_dir = temp_storage_dir / "test_replay_mode"
# First, record a response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Now test replay mode - should not call the original method
with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Verify we got the recorded response
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
# Verify the original method was NOT called
mock_create_patch.assert_not_called()
async def test_replay_missing_recording(self, temp_storage_dir):
"""Test that replay mode fails when no recording is found."""
temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
with patch("openai.resources.chat.completions.AsyncCompletions.create"):
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
with pytest.raises(RuntimeError, match="No recorded response found"):
await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
)
async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
"""Test recording and replay of embeddings calls."""
async def mock_create(*args, **kwargs):
return real_embeddings_response
temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
# Record
with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.embeddings.create(
model="nomic-embed-text", input=["Hello world", "Test embedding"]
)
assert len(response.data) == 2
# Replay
with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.embeddings.create(
model="nomic-embed-text", input=["Hello world", "Test embedding"]
)
# Verify we got the recorded response
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
# Verify original method was not called
mock_create_patch.assert_not_called()
async def test_live_mode(self, real_openai_chat_response):
"""Test that live mode passes through to original methods."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.LIVE):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
)
# Verify the response was returned
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."