Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-13 19:32:37 +00:00)
feat(tests): make inference_recorder into api_recorder (include tool_invoke)
This commit is contained in: parent b96640eca3, commit 9205731cd6
19 changed files with 849 additions and 666 deletions
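The new unit tests added below exercise the renamed recorder roughly as follows. This is a condensed sketch using only names that appear in this diff; the storage path is a placeholder, and the full flow is in tests/unit/distribution/test_api_recordings.py.

from llama_stack.testing.api_recorder import APIRecordingMode, api_recording

# Record once against a (mocked or live) backend, then replay from disk without new calls.
with api_recording(mode=APIRecordingMode.RECORD, storage_dir="/tmp/recordings"):
    ...  # make OpenAI-compatible chat/embeddings calls; responses are stored

with api_recording(mode=APIRecordingMode.REPLAY, storage_dir="/tmp/recordings"):
    ...  # the same requests are served from the stored recordings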
@@ -159,7 +159,6 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal
     import threading
     import time

-    import httpx
     import uvicorn
     from mcp.server.fastmcp import FastMCP
     from mcp.server.sse import SseServerTransport
@@ -171,6 +170,11 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal

     server = FastMCP("FastMCP Test Server", log_level="WARNING")

+    # Silence verbose MCP server logs
+    import logging  # allow-direct-logging
+
+    logging.getLogger("mcp.server.lowlevel.server").setLevel(logging.WARNING)
+
     tools = tools or default_tools()

     # Register all tools with the server
@@ -234,29 +238,25 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal
     logger.debug(f"Starting MCP server thread on port {port}")
     server_thread.start()

-    # Polling until the server is ready
-    timeout = 10
+    # Wait for the server thread to be running
+    # Note: We can't use a simple HTTP GET health check on /sse because it's an SSE endpoint
+    # that expects a long-lived connection, not a simple request/response
+    timeout = 2
     start_time = time.time()

     server_url = f"http://localhost:{port}/sse"
-    logger.debug(f"Waiting for MCP server to be ready at {server_url}")
+    logger.debug(f"Waiting for MCP server thread to start on port {port}")

     while time.time() - start_time < timeout:
-        try:
-            response = httpx.get(server_url)
-            if response.status_code in [200, 401]:
-                logger.debug(f"MCP server is ready on port {port} (status: {response.status_code})")
-                break
-        except httpx.RequestError as e:
-            logger.debug(f"Server not ready yet, retrying... ({e})")
-            pass
-        time.sleep(0.1)
+        if server_thread.is_alive():
+            # Give the server a moment to bind to the port
+            time.sleep(0.1)
+            logger.debug(f"MCP server is ready on port {port}")
+            break
+        time.sleep(0.05)
     else:
         # If we exit the loop due to timeout
-        logger.error(f"MCP server failed to start within {timeout} seconds on port {port}")
-        logger.error(f"Thread alive: {server_thread.is_alive()}")
-        if server_thread.is_alive():
-            logger.error("Server thread is still running but not responding to HTTP requests")
+        logger.error(f"MCP server thread failed to start within {timeout} seconds on port {port}")

     try:
         yield {"server_url": server_url}
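For orientation, make_mcp_server is a context manager that yields the SSE URL once the server thread is up, so a caller looks roughly like this (a sketch only; the import path is an assumption, not shown in this diff):

# Sketch only - the module path for make_mcp_server is assumed.
from tests.common.mcp import make_mcp_server

with make_mcp_server() as server_info:
    mcp_url = server_info["server_url"]  # the f"http://localhost:{port}/sse" value yielded above
    ...  # point an MCP tool test at mcp_url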
@@ -6,6 +6,7 @@
 import inspect
 import itertools
 import os
 import tempfile
 import textwrap
 import time
 from pathlib import Path
@@ -14,6 +15,7 @@ import pytest
 from dotenv import load_dotenv

 from llama_stack.log import get_logger
+from llama_stack.testing.api_recorder import patch_httpx_for_test_id

 from .suites import SETUP_DEFINITIONS, SUITE_DEFINITIONS
@@ -35,6 +37,10 @@ def pytest_sessionstart(session):
     if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
         os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"

     if "SQLITE_STORE_DIR" not in os.environ:
         os.environ["SQLITE_STORE_DIR"] = tempfile.mkdtemp()

+    # Set test stack config type for api_recorder test isolation
+    stack_config = session.config.getoption("--stack-config", default=None)
+    if stack_config and stack_config.startswith("server:"):
+        os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "server"
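In other words, the stack config string decides which value the recorder sees (the "server:<distro>" form below is illustrative, not taken from this diff):

# --stack-config=server:<distro>  -> LLAMA_STACK_TEST_STACK_CONFIG_TYPE=server
# any other stack config          -> LLAMA_STACK_TEST_STACK_CONFIG_TYPE=library_client (see the next hunk)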
@@ -43,8 +49,6 @@
         os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "library_client"
         logger.info(f"Test stack config type: library_client (stack_config={stack_config})")

-    from llama_stack.testing.inference_recorder import patch_httpx_for_test_id
-
     patch_httpx_for_test_id()
@@ -55,7 +59,7 @@ def _track_test_context(request):
     This fixture runs for every test and stores the test's nodeid in a contextvar
     that the recording system can access to determine which subdirectory to use.
     """
-    from llama_stack.testing.inference_recorder import _test_context
+    from llama_stack.testing.api_recorder import _test_context

     # Store the test nodeid (e.g., "tests/integration/responses/test_basic.py::test_foo[params]")
     token = _test_context.set(request.node.nodeid)
@@ -121,9 +125,13 @@ def pytest_configure(config):
     # Apply defaults if not provided explicitly
     for dest, value in setup_obj.defaults.items():
         current = getattr(config.option, dest, None)
-        if not current:
+        if current is None:
             setattr(config.option, dest, value)

+    # Apply global fallback for embedding_dimension if still not set
+    if getattr(config.option, "embedding_dimension", None) is None:
+        config.option.embedding_dimension = 384
+

 def pytest_addoption(parser):
     parser.addoption(
@@ -161,8 +169,8 @@ def pytest_addoption(parser):
     parser.addoption(
         "--embedding-dimension",
         type=int,
-        default=384,
-        help="Output dimensionality of the embedding model to use for testing. Default: 384",
+        default=None,
+        help="Output dimensionality of the embedding model to use for testing. Default: 384 (or setup-specific)",
     )

     parser.addoption(
@@ -236,7 +244,9 @@ def pytest_generate_tests(metafunc):
            continue

        params.append(fixture_name)
-       val = metafunc.config.getoption(option)
+       # Use getattr on config.option to see values set by pytest_configure fallbacks
+       dest = option.lstrip("-").replace("-", "_")
+       val = getattr(metafunc.config.option, dest, None)

        values = [v.strip() for v in str(val).split(",")] if val else [None]
        param_values[fixture_name] = values
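The dest conversion mirrors argparse's option-name-to-attribute mapping; for example:

option = "--embedding-dimension"
dest = option.lstrip("-").replace("-", "_")
assert dest == "embedding_dimension"  # the attribute pytest_configure may already have filled in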
@@ -183,6 +183,12 @@ def llama_stack_client(request):
     # would be forced to use llama_stack_client, which is not what we want.
     print("\ninstantiating llama_stack_client")
     start_time = time.time()
+
+    # Patch httpx to inject test ID for server-mode test isolation
+    from llama_stack.testing.api_recorder import patch_httpx_for_test_id
+
+    patch_httpx_for_test_id()
+
     client = instantiate_llama_stack_client(request.session)
     print(f"llama_stack_client instantiated in {time.time() - start_time:.3f}s")
     return client
@@ -7,7 +7,7 @@
 import time


-def new_vector_store(openai_client, name):
+def new_vector_store(openai_client, name, embedding_model, embedding_dimension):
     """Create a new vector store, cleaning up any existing one with the same name."""
     # Ensure we don't reuse an existing vector store
     vector_stores = openai_client.vector_stores.list()
@@ -16,7 +16,21 @@ def new_vector_store(openai_client, name):
             openai_client.vector_stores.delete(vector_store_id=vector_store.id)

     # Create a new vector store
-    vector_store = openai_client.vector_stores.create(name=name)
+    # OpenAI SDK client uses extra_body for non-standard parameters
+    from openai import OpenAI
+
+    if isinstance(openai_client, OpenAI):
+        # OpenAI SDK client - use extra_body
+        vector_store = openai_client.vector_stores.create(
+            name=name,
+            extra_body={"embedding_model": embedding_model, "embedding_dimension": embedding_dimension},
+        )
+    else:
+        # LlamaStack client - direct parameter
+        vector_store = openai_client.vector_stores.create(
+            name=name, embedding_model=embedding_model, embedding_dimension=embedding_dimension
+        )

     return vector_store
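Call sites now pass the embedding settings through explicitly, as the updated responses tests below do; a minimal sketch (the fixture values come from whatever setup is active):

vector_store = new_vector_store(
    compat_client,            # OpenAI SDK client or LlamaStack client
    "test_vector_store",
    embedding_model_id,       # e.g. the setup's embedding_model
    embedding_dimension,      # e.g. 384, or 1536 for the gpt setup further below
)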
@@ -16,6 +16,7 @@ import pytest
 from llama_stack_client import APIStatusError


+@pytest.mark.xfail(reason="Shields are not yet implemented inside responses")
 def test_shields_via_extra_body(compat_client, text_model_id):
     """Test that shields parameter is received by the server and raises NotImplementedError."""
@@ -47,12 +47,14 @@ def test_response_text_format(compat_client, text_model_id, text_format):


 @pytest.fixture
-def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_factory):
-    """Create a vector store with multiple files that have different attributes for filtering tests."""
+def vector_store_with_filtered_files(compat_client, embedding_model_id, embedding_dimension, tmp_path_factory):
+    # """Create a vector store with multiple files that have different attributes for filtering tests."""
     if isinstance(compat_client, LlamaStackAsLibraryClient):
-        pytest.skip("Responses API file search is not yet supported in library client.")
+        pytest.skip("upload_file() is not yet supported in library client somehow?")

-    vector_store = new_vector_store(compat_client, "test_vector_store_with_filters")
+    vector_store = new_vector_store(
+        compat_client, "test_vector_store_with_filters", embedding_model_id, embedding_dimension
+    )
     tmp_path = tmp_path_factory.mktemp("filter_test_files")

     # Create multiple files with different attributes
@@ -46,11 +46,13 @@ def test_response_non_streaming_web_search(compat_client, text_model_id, case):


 @pytest.mark.parametrize("case", file_search_test_cases)
-def test_response_non_streaming_file_search(compat_client, text_model_id, tmp_path, case):
+def test_response_non_streaming_file_search(
+    compat_client, text_model_id, embedding_model_id, embedding_dimension, tmp_path, case
+):
     if isinstance(compat_client, LlamaStackAsLibraryClient):
         pytest.skip("Responses API file search is not yet supported in library client.")

-    vector_store = new_vector_store(compat_client, "test_vector_store")
+    vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)

     if case.file_content:
         file_name = "test_response_non_streaming_file_search.txt"
@@ -101,11 +103,13 @@ def test_response_non_streaming_file_search(compat_client, text_model_id, tmp_pa
     assert case.expected.lower() in response.output_text.lower().strip()


-def test_response_non_streaming_file_search_empty_vector_store(compat_client, text_model_id):
+def test_response_non_streaming_file_search_empty_vector_store(
+    compat_client, text_model_id, embedding_model_id, embedding_dimension
+):
     if isinstance(compat_client, LlamaStackAsLibraryClient):
         pytest.skip("Responses API file search is not yet supported in library client.")

-    vector_store = new_vector_store(compat_client, "test_vector_store")
+    vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)

     # Create the response request, which should query our vector store
     response = compat_client.responses.create(
@@ -127,12 +131,14 @@ def test_response_non_streaming_file_search_empty_vector_store(compat_client, te
     assert response.output_text


-def test_response_sequential_file_search(compat_client, text_model_id, tmp_path):
+def test_response_sequential_file_search(
+    compat_client, text_model_id, embedding_model_id, embedding_dimension, tmp_path
+):
     """Test file search with sequential responses using previous_response_id."""
     if isinstance(compat_client, LlamaStackAsLibraryClient):
         pytest.skip("Responses API file search is not yet supported in library client.")

-    vector_store = new_vector_store(compat_client, "test_vector_store")
+    vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)

     # Create a test file with content
     file_content = "The Llama 4 Maverick model has 128 experts in its mixture of experts architecture."
@@ -39,7 +39,7 @@ class Setup(BaseModel):

     name: str
     description: str
-    defaults: dict[str, str] = Field(default_factory=dict)
+    defaults: dict[str, str | int] = Field(default_factory=dict)
     env: dict[str, str] = Field(default_factory=dict)
@@ -88,6 +88,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
         defaults={
             "text_model": "openai/gpt-4o",
             "embedding_model": "openai/text-embedding-3-small",
+            "embedding_dimension": 1536,
         },
     ),
     "tgi": Setup(
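Combined with the conftest changes above, embedding_dimension now resolves as: an explicit --embedding-dimension flag wins, otherwise the setup default applies (1536 for this gpt setup), otherwise pytest_configure falls back to 384. Roughly (a sketch, not the actual code):

value = cli_value                                      # --embedding-dimension, default now None
if value is None:
    value = setup_defaults.get("embedding_dimension")  # e.g. 1536 for the gpt setup
if value is None:
    value = 384                                        # global fallback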
tests/unit/distribution/test_api_recordings.py (new file, 273 lines)
@@ -0,0 +1,273 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import tempfile
from pathlib import Path
from unittest.mock import patch

import pytest
from openai import AsyncOpenAI

# Import the real Pydantic response types instead of using Mocks
from llama_stack.apis.inference import (
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChoice,
    OpenAIEmbeddingData,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
)
from llama_stack.testing.api_recorder import (
    APIRecordingMode,
    ResponseStorage,
    api_recording,
    normalize_request,
)


@pytest.fixture
def temp_storage_dir():
    """Create a temporary directory for test recordings."""
    with tempfile.TemporaryDirectory() as temp_dir:
        yield Path(temp_dir)


@pytest.fixture
def real_openai_chat_response():
    """Real OpenAI chat completion response using proper Pydantic objects."""
    return OpenAIChatCompletion(
        id="chatcmpl-test123",
        choices=[
            OpenAIChoice(
                index=0,
                message=OpenAIAssistantMessageParam(
                    role="assistant", content="Hello! I'm doing well, thank you for asking."
                ),
                finish_reason="stop",
            )
        ],
        created=1234567890,
        model="llama3.2:3b",
    )


@pytest.fixture
def real_embeddings_response():
    """Real OpenAI embeddings response using proper Pydantic objects."""
    return OpenAIEmbeddingsResponse(
        object="list",
        data=[
            OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
            OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
        ],
        model="nomic-embed-text",
        usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
    )


class TestInferenceRecording:
    """Test the inference recording system."""

    def test_request_normalization(self):
        """Test that request normalization produces consistent hashes."""
        # Test basic normalization
        hash1 = normalize_request(
            "POST",
            "http://localhost:11434/v1/chat/completions",
            {},
            {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
        )

        # Same request should produce same hash
        hash2 = normalize_request(
            "POST",
            "http://localhost:11434/v1/chat/completions",
            {},
            {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
        )

        assert hash1 == hash2

        # Different content should produce different hash
        hash3 = normalize_request(
            "POST",
            "http://localhost:11434/v1/chat/completions",
            {},
            {
                "model": "llama3.2:3b",
                "messages": [{"role": "user", "content": "Different message"}],
                "temperature": 0.7,
            },
        )

        assert hash1 != hash3

    def test_request_normalization_edge_cases(self):
        """Test request normalization is precise about request content."""
        # Test that different whitespace produces different hashes (no normalization)
        hash1 = normalize_request(
            "POST",
            "http://test/v1/chat/completions",
            {},
            {"messages": [{"role": "user", "content": "Hello world\n\n"}]},
        )
        hash2 = normalize_request(
            "POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
        )
        assert hash1 != hash2  # Different whitespace should produce different hashes

        # Test that different float precision produces different hashes (no rounding)
        hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
        hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
        assert hash3 != hash4  # Different precision should produce different hashes

    def test_response_storage(self, temp_storage_dir):
        """Test the ResponseStorage class."""
        temp_storage_dir = temp_storage_dir / "test_response_storage"
        storage = ResponseStorage(temp_storage_dir)

        # Test storing and retrieving a recording
        request_hash = "test_hash_123"
        request_data = {
            "method": "POST",
            "url": "http://localhost:11434/v1/chat/completions",
            "endpoint": "/v1/chat/completions",
            "model": "llama3.2:3b",
        }
        response_data = {"body": {"content": "test response"}, "is_streaming": False}

        storage.store_recording(request_hash, request_data, response_data)

        # Verify file storage and retrieval
        retrieved = storage.find_recording(request_hash)
        assert retrieved is not None
        assert retrieved["request"]["model"] == "llama3.2:3b"
        assert retrieved["response"]["body"]["content"] == "test response"

    async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
        """Test that recording mode captures and stores responses."""

        async def mock_create(*args, **kwargs):
            return real_openai_chat_response

        temp_storage_dir = temp_storage_dir / "test_recording_mode"
        with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
            with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                response = await client.chat.completions.create(
                    model="llama3.2:3b",
                    messages=[{"role": "user", "content": "Hello, how are you?"}],
                    temperature=0.7,
                    max_tokens=50,
                )

                # Verify the response was returned correctly
                assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."

        # Verify recording was stored
        storage = ResponseStorage(temp_storage_dir)
        assert storage._get_test_dir().exists()

    async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
        """Test that replay mode returns stored responses without making real calls."""

        async def mock_create(*args, **kwargs):
            return real_openai_chat_response

        temp_storage_dir = temp_storage_dir / "test_replay_mode"
        # First, record a response
        with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
            with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                response = await client.chat.completions.create(
                    model="llama3.2:3b",
                    messages=[{"role": "user", "content": "Hello, how are you?"}],
                    temperature=0.7,
                    max_tokens=50,
                )

        # Now test replay mode - should not call the original method
        with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
            with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                response = await client.chat.completions.create(
                    model="llama3.2:3b",
                    messages=[{"role": "user", "content": "Hello, how are you?"}],
                    temperature=0.7,
                    max_tokens=50,
                )

                # Verify we got the recorded response
                assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."

            # Verify the original method was NOT called
            mock_create_patch.assert_not_called()

    async def test_replay_missing_recording(self, temp_storage_dir):
        """Test that replay mode fails when no recording is found."""
        temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
        with patch("openai.resources.chat.completions.AsyncCompletions.create"):
            with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                with pytest.raises(RuntimeError, match="No recorded response found"):
                    await client.chat.completions.create(
                        model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
                    )

    async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
        """Test recording and replay of embeddings calls."""

        async def mock_create(*args, **kwargs):
            return real_embeddings_response

        temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
        # Record
        with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
            with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                response = await client.embeddings.create(
                    model="nomic-embed-text", input=["Hello world", "Test embedding"]
                )

                assert len(response.data) == 2

        # Replay
        with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
            with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                response = await client.embeddings.create(
                    model="nomic-embed-text", input=["Hello world", "Test embedding"]
                )

                # Verify we got the recorded response
                assert len(response.data) == 2
                assert response.data[0].embedding == [0.1, 0.2, 0.3]

            # Verify original method was not called
            mock_create_patch.assert_not_called()

    async def test_live_mode(self, real_openai_chat_response):
        """Test that live mode passes through to original methods."""

        async def mock_create(*args, **kwargs):
            return real_openai_chat_response

        with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
            with api_recording(mode=APIRecordingMode.LIVE, storage_dir="foo"):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                response = await client.chat.completions.create(
                    model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
                )

                # Verify the response was returned
                assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
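Since the new file lives under tests/unit, it should run with the usual unit-test invocation, e.g. pytest tests/unit/distribution/test_api_recordings.py (exact flags depend on the local setup).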
@@ -1,382 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch

import pytest
from openai import NOT_GIVEN, AsyncOpenAI
from openai.types.model import Model as OpenAIModel

# Import the real Pydantic response types instead of using Mocks
from llama_stack.apis.inference import (
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChoice,
    OpenAICompletion,
    OpenAIEmbeddingData,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
)
from llama_stack.testing.inference_recorder import (
    InferenceMode,
    ResponseStorage,
    inference_recording,
    normalize_request,
)


@pytest.fixture
def temp_storage_dir():
    """Create a temporary directory for test recordings."""
    with tempfile.TemporaryDirectory() as temp_dir:
        yield Path(temp_dir)


@pytest.fixture
def real_openai_chat_response():
    """Real OpenAI chat completion response using proper Pydantic objects."""
    return OpenAIChatCompletion(
        id="chatcmpl-test123",
        choices=[
            OpenAIChoice(
                index=0,
                message=OpenAIAssistantMessageParam(
                    role="assistant", content="Hello! I'm doing well, thank you for asking."
                ),
                finish_reason="stop",
            )
        ],
        created=1234567890,
        model="llama3.2:3b",
    )


@pytest.fixture
def real_embeddings_response():
    """Real OpenAI embeddings response using proper Pydantic objects."""
    return OpenAIEmbeddingsResponse(
        object="list",
        data=[
            OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
            OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
        ],
        model="nomic-embed-text",
        usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
    )


class TestInferenceRecording:
    """Test the inference recording system."""

    def test_request_normalization(self):
        """Test that request normalization produces consistent hashes."""
        # Test basic normalization
        hash1 = normalize_request(
            "POST",
            "http://localhost:11434/v1/chat/completions",
            {},
            {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
        )

        # Same request should produce same hash
        hash2 = normalize_request(
            "POST",
            "http://localhost:11434/v1/chat/completions",
            {},
            {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
        )

        assert hash1 == hash2

        # Different content should produce different hash
        hash3 = normalize_request(
            "POST",
            "http://localhost:11434/v1/chat/completions",
            {},
            {
                "model": "llama3.2:3b",
                "messages": [{"role": "user", "content": "Different message"}],
                "temperature": 0.7,
            },
        )

        assert hash1 != hash3

    def test_request_normalization_edge_cases(self):
        """Test request normalization is precise about request content."""
        # Test that different whitespace produces different hashes (no normalization)
        hash1 = normalize_request(
            "POST",
            "http://test/v1/chat/completions",
            {},
            {"messages": [{"role": "user", "content": "Hello world\n\n"}]},
        )
        hash2 = normalize_request(
            "POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
        )
        assert hash1 != hash2  # Different whitespace should produce different hashes

        # Test that different float precision produces different hashes (no rounding)
        hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
        hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
        assert hash3 != hash4  # Different precision should produce different hashes

    def test_response_storage(self, temp_storage_dir):
        """Test the ResponseStorage class."""
        temp_storage_dir = temp_storage_dir / "test_response_storage"
        storage = ResponseStorage(temp_storage_dir)

        # Test storing and retrieving a recording
        request_hash = "test_hash_123"
        request_data = {
            "method": "POST",
            "url": "http://localhost:11434/v1/chat/completions",
            "endpoint": "/v1/chat/completions",
            "model": "llama3.2:3b",
        }
        response_data = {"body": {"content": "test response"}, "is_streaming": False}

        storage.store_recording(request_hash, request_data, response_data)

        # Verify file storage and retrieval
        retrieved = storage.find_recording(request_hash)
        assert retrieved is not None
        assert retrieved["request"]["model"] == "llama3.2:3b"
        assert retrieved["response"]["body"]["content"] == "test response"

    async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
        """Test that recording mode captures and stores responses."""
        temp_storage_dir = temp_storage_dir / "test_recording_mode"
        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)

            response = await client.chat.completions.create(
                model="llama3.2:3b",
                messages=[{"role": "user", "content": "Hello, how are you?"}],
                temperature=0.7,
                max_tokens=50,
                user=NOT_GIVEN,
            )

            # Verify the response was returned correctly
            assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
            client.chat.completions._post.assert_called_once()

        # Verify recording was stored
        storage = ResponseStorage(temp_storage_dir)
        dir = storage._get_test_dir()
        assert dir.exists()

    async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
        """Test that replay mode returns stored responses without making real calls."""
        temp_storage_dir = temp_storage_dir / "test_replay_mode"
        # First, record a response
        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)

            response = await client.chat.completions.create(
                model="llama3.2:3b",
                messages=[{"role": "user", "content": "Hello, how are you?"}],
                temperature=0.7,
                max_tokens=50,
                user=NOT_GIVEN,
            )
            client.chat.completions._post.assert_called_once()

        # Now test replay mode - should not call the original method
        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)

            response = await client.chat.completions.create(
                model="llama3.2:3b",
                messages=[{"role": "user", "content": "Hello, how are you?"}],
                temperature=0.7,
                max_tokens=50,
            )

            # Verify we got the recorded response
            assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."

            # Verify the original method was NOT called
            client.chat.completions._post.assert_not_called()

    async def test_replay_mode_models(self, temp_storage_dir):
        """Test that replay mode returns stored responses without making real model listing calls."""

        async def _async_iterator(models):
            for model in models:
                yield model

        models = [
            OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
            OpenAIModel(id="bar", created=2, object="model", owned_by="test"),
        ]

        expected_ids = {m.id for m in models}

        temp_storage_dir = temp_storage_dir / "test_replay_mode_models"

        # baseline - mock works without recording
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        client.models._get_api_list = Mock(return_value=_async_iterator(models))
        assert {m.id async for m in client.models.list()} == expected_ids
        client.models._get_api_list.assert_called_once()

        # record the call
        with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
            client.models._get_api_list = Mock(return_value=_async_iterator(models))
            assert {m.id async for m in client.models.list()} == expected_ids
            client.models._get_api_list.assert_called_once()

        # replay the call
        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
            client.models._get_api_list = Mock(return_value=_async_iterator(models))
            assert {m.id async for m in client.models.list()} == expected_ids
            client.models._get_api_list.assert_not_called()

    async def test_replay_missing_recording(self, temp_storage_dir):
        """Test that replay mode fails when no recording is found."""
        temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
        with patch("openai.resources.chat.completions.AsyncCompletions.create"):
            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                with pytest.raises(RuntimeError, match="No recorded response found"):
                    await client.chat.completions.create(
                        model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
                    )

    async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
        """Test recording and replay of embeddings calls."""

        # baseline - mock works without recording
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
        response = await client.embeddings.create(
            model=real_embeddings_response.model,
            input=["Hello world", "Test embedding"],
            encoding_format=NOT_GIVEN,
        )
        assert len(response.data) == 2
        assert response.data[0].embedding == [0.1, 0.2, 0.3]
        client.embeddings._post.assert_called_once()

        temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
        # Record
        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
            client.embeddings._post = AsyncMock(return_value=real_embeddings_response)

            response = await client.embeddings.create(
                model=real_embeddings_response.model,
                input=["Hello world", "Test embedding"],
                encoding_format=NOT_GIVEN,
                dimensions=NOT_GIVEN,
                user=NOT_GIVEN,
            )

            assert len(response.data) == 2

        # Replay
        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
            client.embeddings._post = AsyncMock(return_value=real_embeddings_response)

            response = await client.embeddings.create(
                model=real_embeddings_response.model,
                input=["Hello world", "Test embedding"],
            )

            # Verify we got the recorded response
            assert len(response.data) == 2
            assert response.data[0].embedding == [0.1, 0.2, 0.3]

            # Verify original method was not called
            client.embeddings._post.assert_not_called()

    async def test_completions_recording(self, temp_storage_dir):
        real_completions_response = OpenAICompletion(
            id="test_completion",
            object="text_completion",
            created=1234567890,
            model="llama3.2:3b",
            choices=[
                {
                    "text": "Hello! I'm doing well, thank you for asking.",
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "stop",
                }
            ],
        )

        temp_storage_dir = temp_storage_dir / "test_completions_recording"

        # baseline - mock works without recording
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        client.completions._post = AsyncMock(return_value=real_completions_response)
        response = await client.completions.create(
            model=real_completions_response.model,
            prompt="Hello, how are you?",
            temperature=0.7,
            max_tokens=50,
            user=NOT_GIVEN,
        )
        assert response.choices[0].text == real_completions_response.choices[0].text
        client.completions._post.assert_called_once()

        # Record
        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
            client.completions._post = AsyncMock(return_value=real_completions_response)

            response = await client.completions.create(
                model=real_completions_response.model,
                prompt="Hello, how are you?",
                temperature=0.7,
                max_tokens=50,
                user=NOT_GIVEN,
            )

            assert response.choices[0].text == real_completions_response.choices[0].text
            client.completions._post.assert_called_once()

        # Replay
        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
            client.completions._post = AsyncMock(return_value=real_completions_response)
            response = await client.completions.create(
                model=real_completions_response.model,
                prompt="Hello, how are you?",
                temperature=0.7,
                max_tokens=50,
            )
            assert response.choices[0].text == real_completions_response.choices[0].text
            client.completions._post.assert_not_called()

    async def test_live_mode(self, real_openai_chat_response):
        """Test that live mode passes through to original methods."""

        async def mock_create(*args, **kwargs):
            return real_openai_chat_response

        with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
            with inference_recording(mode=InferenceMode.LIVE, storage_dir="foo"):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                response = await client.chat.completions.create(
                    model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
                )

                # Verify the response was returned
                assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."