feat(tests): make inference_recorder into api_recorder (include tool_invoke)

This commit is contained in:
Ashwin Bharambe 2025-10-04 11:53:44 -07:00
parent b96640eca3
commit 9205731cd6
19 changed files with 849 additions and 666 deletions

View file

@ -54,14 +54,14 @@ jobs:
# Define (setup, suite) pairs - they are always matched and cannot be independent
# Weekly schedule (Sun 1 AM): vllm+base
# Input test-setup=ollama-vision: ollama-vision+vision
# Default (including test-setup=ollama): both ollama+base and ollama-vision+vision
# Default (including test-setup=ollama): ollama+base, ollama-vision+vision, gpt+responses
config: >-
${{
github.event.schedule == '1 0 * * 0'
&& fromJSON('[{"setup": "vllm", "suite": "base"}]')
|| github.event.inputs.test-setup == 'ollama-vision'
&& fromJSON('[{"setup": "ollama-vision", "suite": "vision"}]')
|| fromJSON('[{"setup": "ollama", "suite": "base"}, {"setup": "ollama-vision", "suite": "vision"}]')
|| fromJSON('[{"setup": "ollama", "suite": "base"}, {"setup": "ollama-vision", "suite": "vision"}, {"setup": "gpt", "suite": "responses"}]')
}}
steps:

View file

@ -61,6 +61,9 @@ jobs:
- name: Run and record tests
uses: ./.github/actions/run-and-record-tests
env:
# Set OPENAI_API_KEY if using gpt setup
OPENAI_API_KEY: ${{ inputs.test-setup == 'gpt' && secrets.OPENAI_API_KEY || '' }}
with:
stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
setup: ${{ inputs.test-setup || 'ollama' }}

View file

@ -68,7 +68,9 @@ recordings/
Direct API calls with no recording or replay:
```python
with inference_recording(mode=InferenceMode.LIVE):
from llama_stack.testing.api_recorder import api_recording, APIRecordingMode
with api_recording(mode=APIRecordingMode.LIVE):
response = await client.chat.completions.create(...)
```
@ -79,7 +81,7 @@ Use for initial development and debugging against real APIs.
Captures API interactions while passing through real responses:
```python
with inference_recording(mode=InferenceMode.RECORD, storage_dir="./recordings"):
with api_recording(mode=APIRecordingMode.RECORD, storage_dir="./recordings"):
response = await client.chat.completions.create(...)
# Real API call made, response captured AND returned
```
@ -96,7 +98,7 @@ The recording process:
Returns stored responses instead of making API calls:
```python
with inference_recording(mode=InferenceMode.REPLAY, storage_dir="./recordings"):
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir="./recordings"):
response = await client.chat.completions.create(...)
# No API call made, cached response returned instantly
```
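The enum above also defines a `record-if-missing` mode that replays when a recording exists and records otherwise. A minimal sketch in the same style as the examples above (call arguments are placeholders):
```python
with api_recording(mode=APIRecordingMode.RECORD_IF_MISSING, storage_dir="./recordings"):
    response = await client.chat.completions.create(...)
    # Replayed from ./recordings when a matching recording exists,
    # otherwise executed live and recorded there
```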

View file

@ -316,13 +316,13 @@ class Stack:
# asked for in the run config.
async def initialize(self):
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording
from llama_stack.testing.api_recorder import setup_api_recording
global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_inference_recording()
TEST_RECORDING_CONTEXT = setup_api_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
@ -381,7 +381,7 @@ class Stack:
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
logger.error(f"Error during API recording cleanup: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:

View file

@ -108,7 +108,7 @@ class OpenAIResponsesImpl:
# Use stored messages directly and convert only new input
message_adapter = TypeAdapter(list[OpenAIMessageParam])
messages = message_adapter.validate_python(previous_response.messages)
new_messages = await convert_response_input_to_chat_messages(input)
new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
messages.extend(new_messages)
else:
# Backward compatibility: reconstruct from inputs

View file

@ -103,9 +103,13 @@ async def convert_response_content_to_chat_content(
async def convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
previous_messages: list[OpenAIMessageParam] | None = None,
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
:param input: The input to convert
:param previous_messages: Optional previous messages to check for function_call references
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
@ -169,16 +173,53 @@ async def convert_response_input_to_chat_messages(
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
# Skip user messages that duplicate the last user message in previous_messages
# This handles cases where input includes context for function_call_outputs
if previous_messages and input_item.role == "user":
last_user_msg = None
for msg in reversed(previous_messages):
if isinstance(msg, OpenAIUserMessageParam):
last_user_msg = msg
break
if last_user_msg:
last_user_content = getattr(last_user_msg, "content", None)
if last_user_content == content:
continue # Skip duplicate user message
messages.append(message_type(content=content))
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
# Check if unpaired function_call_outputs reference function_calls from previous messages
if previous_messages:
previous_call_ids = _extract_tool_call_ids(previous_messages)
for call_id in list(tool_call_results.keys()):
if call_id in previous_call_ids:
# Valid: this output references a call from previous messages
# Add the tool message
messages.append(tool_call_results[call_id])
del tool_call_results[call_id]
# If there are still unpaired outputs, raise an error
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages
def _extract_tool_call_ids(messages: list[OpenAIMessageParam]) -> set[str]:
"""Extract all tool_call IDs from messages."""
call_ids = set()
for msg in messages:
if isinstance(msg, OpenAIAssistantMessageParam):
tool_calls = getattr(msg, "tool_calls", None)
if tool_calls:
for tool_call in tool_calls:
# tool_call is a Pydantic model, use attribute access
call_ids.add(tool_call.id)
return call_ids
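For illustration, a small sketch of the helper above, assuming `OpenAIAssistantMessageParam` (imported from `llama_stack.apis.inference`) validates an OpenAI-style `tool_calls` payload; all values are made up:
```python
from llama_stack.apis.inference import OpenAIAssistantMessageParam

# Hypothetical assistant turn carrying a single tool call
assistant_msg = OpenAIAssistantMessageParam.model_validate(
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "call_abc123",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ],
    }
)

assert _extract_tool_call_ids([assistant_msg]) == {"call_abc123"}
```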
async def convert_response_text_to_chat_response_format(
text: OpenAIResponseText,
) -> OpenAIResponseFormatParam:

View file

@ -95,7 +95,9 @@ class LocalfsFilesImpl(Files):
raise RuntimeError("Files provider not initialized")
if expires_after is not None:
raise NotImplementedError("File expiration is not supported by this provider")
logger.warning(
f"File expiration is not supported by this provider, ignoring expires_after: {expires_after}"
)
file_id = self._generate_file_id()
file_path = self._get_file_path(file_id)

View file

@ -15,19 +15,20 @@ from enum import StrEnum
from pathlib import Path
from typing import Any, Literal, cast
from openai import NOT_GIVEN, OpenAI
from openai import NOT_GIVEN
from llama_stack.log import get_logger
logger = get_logger(__name__, category="testing")
# Global state for the recording system
# Global state for the API recording system
# Note: Using module globals instead of ContextVars because the session-scoped
# client initialization happens in one async context, but tests run in different
# contexts, and we need the mode/storage to persist across all contexts.
_current_mode: str | None = None
_current_storage: ResponseStorage | None = None
_original_methods: dict[str, Any] = {}
_memory_cache: dict[str, dict[str, Any]] = {}
# Test context uses ContextVar since it changes per-test and needs async isolation
from contextvars import ContextVar
@ -44,158 +45,33 @@ REPO_ROOT = Path(__file__).parent.parent.parent
DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/common"
class InferenceMode(StrEnum):
class APIRecordingMode(StrEnum):
LIVE = "live"
RECORD = "record"
REPLAY = "replay"
RECORD_IF_MISSING = "record-if-missing"
def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
"""Create a normalized hash of the request for consistent matching.
def _normalize_file_ids(obj: Any) -> Any:
"""Recursively replace file IDs with a canonical placeholder for consistent hashing."""
import re
Includes test_id from context to ensure test isolation - identical requests
from different tests will have different hashes.
Exception: Model list endpoints (/v1/models, /api/tags) exclude test_id since
they are infrastructure/shared and need to work across session setup and tests.
"""
# Extract just the endpoint path
from urllib.parse import urlparse
parsed = urlparse(url)
normalized: dict[str, Any] = {
"method": method.upper(),
"endpoint": parsed.path,
"body": body,
}
# Include test_id for isolation, except for shared infrastructure endpoints
if parsed.path not in ("/api/tags", "/v1/models"):
normalized["test_id"] = _test_context.get()
# Create hash - sort_keys=True ensures deterministic ordering
normalized_json = json.dumps(normalized, sort_keys=True)
return hashlib.sha256(normalized_json.encode()).hexdigest()
def _sync_test_context_from_provider_data():
"""In server mode, sync test ID from provider_data to _test_context.
This ensures that storage operations (which read from _test_context) work correctly
in server mode where the test ID arrives via HTTP header provider_data.
Returns a token to reset _test_context, or None if no sync was needed.
"""
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
if stack_config_type != "server":
return None
try:
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
provider_data = PROVIDER_DATA_VAR.get()
if provider_data and "__test_id" in provider_data:
test_id = provider_data["__test_id"]
return _test_context.set(test_id)
except ImportError:
pass
return None
def patch_httpx_for_test_id():
"""Patch client _prepare_request methods to inject test ID into provider data header.
This is needed for server mode where the test ID must be transported from
client to server via HTTP headers. In library_client mode, this patch is a no-op
since everything runs in the same process.
We use the _prepare_request hook that Stainless clients provide for mutating
requests after construction but before sending.
"""
from llama_stack_client import LlamaStackClient
if "llama_stack_client_prepare_request" in _original_methods:
return
_original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request
_original_methods["openai_prepare_request"] = OpenAI._prepare_request
def patched_prepare_request(self, request):
# Call original first (it's a sync method that returns None)
# Determine which original to call based on client type
if "llama_stack_client" in self.__class__.__module__:
_original_methods["llama_stack_client_prepare_request"](self, request)
_original_methods["openai_prepare_request"](self, request)
# Only inject test ID in server mode
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
test_id = _test_context.get()
if stack_config_type == "server" and test_id:
provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
if provider_data_header:
provider_data = json.loads(provider_data_header)
if isinstance(obj, dict):
result = {}
for k, v in obj.items():
# Normalize file IDs in attribute dictionaries
if k == "document_id" and isinstance(v, str) and v.startswith("file-"):
result[k] = "file-NORMALIZED"
else:
provider_data = {}
provider_data["__test_id"] = test_id
request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
return None
LlamaStackClient._prepare_request = patched_prepare_request
OpenAI._prepare_request = patched_prepare_request
# currently, unpatch is never called
def unpatch_httpx_for_test_id():
"""Remove client _prepare_request patches for test ID injection."""
if "llama_stack_client_prepare_request" not in _original_methods:
return
from llama_stack_client import LlamaStackClient
LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"]
del _original_methods["llama_stack_client_prepare_request"]
# Also restore OpenAI client if it was patched
if "openai_prepare_request" in _original_methods:
OpenAI._prepare_request = _original_methods["openai_prepare_request"]
del _original_methods["openai_prepare_request"]
def get_inference_mode() -> InferenceMode:
return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
def setup_inference_recording():
"""
Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
to increase their reliability and reduce reliance on expensive, external services.
Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
Two environment variables are supported:
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'record-if-missing'. Default is 'replay'.
- 'live': Make all requests live without recording
- 'record': Record all requests (overwrites existing recordings)
- 'replay': Use only recorded responses (fails if recording not found)
- 'record-if-missing': Use recorded responses when available, record new ones when not found
- LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'.
The recordings are stored as JSON files.
"""
mode = get_inference_mode()
if mode == InferenceMode.LIVE:
return None
storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR)
return inference_recording(mode=mode, storage_dir=storage_dir)
result[k] = _normalize_file_ids(v)
return result
elif isinstance(obj, list):
return [_normalize_file_ids(item) for item in obj]
elif isinstance(obj, str):
# Replace file-<uuid> patterns in strings (like in text content)
return re.sub(r"file-[a-f0-9]{32}", "file-NORMALIZED", obj)
else:
return obj
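To make the normalization above concrete, a short illustrative sketch (the 32-hex-character IDs stand in for real file IDs):
```python
body = {
    "document_id": "file-" + "a" * 32,
    "messages": [{"role": "user", "content": "See attachment file-" + "b" * 32}],
}
normalized = _normalize_file_ids(body)

# Both the attribute value and the inline mention collapse to the placeholder,
# so the request hash stays stable across test runs that upload fresh files.
assert normalized["document_id"] == "file-NORMALIZED"
assert "file-NORMALIZED" in normalized["messages"][0]["content"]
```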
def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[str, Any]:
@ -230,11 +106,184 @@ def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[st
return data
def _serialize_response(response: Any, request_hash: str = "") -> Any:
def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
"""Create a normalized hash of the request for consistent matching.
Includes test_id from context to ensure test isolation - identical requests
from different tests will have different hashes.
"""
# Extract just the endpoint path
from urllib.parse import urlparse
parsed = urlparse(url)
# Normalize file IDs in the body to ensure consistent hashing across test runs
normalized_body = _normalize_file_ids(body)
normalized: dict[str, Any] = {"method": method.upper(), "endpoint": parsed.path, "body": normalized_body}
# Include test_id for isolation, except for shared infrastructure endpoints
if parsed.path not in ("/api/tags", "/v1/models"):
# Server mode: test ID was synced from provider_data to _test_context
# by _sync_test_context_from_provider_data() at the start of the request.
# We read from _test_context because it's available in all contexts (including
# when making outgoing API calls), whereas PROVIDER_DATA_VAR is only set
# for incoming HTTP requests.
#
# Library client mode: test ID in same-process ContextVar
test_id = _test_context.get()
normalized["test_id"] = test_id
# Create hash - sort_keys=True ensures deterministic ordering
normalized_json = json.dumps(normalized, sort_keys=True)
request_hash = hashlib.sha256(normalized_json.encode()).hexdigest()
return request_hash
def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
"""Create a normalized hash of the tool request for consistent matching."""
normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": kwargs}
# Create hash - sort_keys=True ensures deterministic ordering
normalized_json = json.dumps(normalized, sort_keys=True)
return hashlib.sha256(normalized_json.encode()).hexdigest()
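A quick sketch of the hashing contract: identical (provider, tool_name, kwargs) triples hash identically, and any change to kwargs produces a different digest (values are illustrative):
```python
h1 = normalize_tool_request("tavily", "web_search", {"query": "llama stack"})
h2 = normalize_tool_request("tavily", "web_search", {"query": "llama stack"})
h3 = normalize_tool_request("tavily", "web_search", {"query": "something else"})

assert h1 == h2
assert h1 != h3
assert len(h1) == 64  # sha256 hex digest
```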
@contextmanager
def set_test_context(test_id: str) -> Generator[None, None, None]:
"""Set the test context for recording isolation.
Usage:
with set_test_context("test_basic_completion"):
# Make API calls that will be recorded with this test_id
response = client.chat.completions.create(...)
"""
token = _test_context.set(test_id)
try:
yield
finally:
_test_context.reset(token)
def patch_httpx_for_test_id():
"""Patch client _prepare_request methods to inject test ID into provider data header.
Patches both LlamaStackClient and OpenAI client to ensure test ID is transported
from client to server via HTTP headers in server mode.
This is needed for server mode where the test ID must be transported from
client to server via HTTP headers. In library_client mode, this patch is a no-op
since everything runs in the same process.
We use the _prepare_request hook that Stainless clients provide for mutating
requests after construction but before sending.
"""
from llama_stack_client import LlamaStackClient
if "llama_stack_client_prepare_request" in _original_methods:
# Already patched
return
# Save original methods
_original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request
# Also patch OpenAI client if available (used in compat tests)
try:
from openai import OpenAI
_original_methods["openai_prepare_request"] = OpenAI._prepare_request
except ImportError:
pass
def patched_prepare_request(self, request):
# Call original first (it's a sync method that returns None)
# Determine which original to call based on client type
if "llama_stack_client" in self.__class__.__module__:
_original_methods["llama_stack_client_prepare_request"](self, request)
elif "openai_prepare_request" in _original_methods:
_original_methods["openai_prepare_request"](self, request)
# Only inject test ID in server mode
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
test_id = _test_context.get()
if stack_config_type == "server" and test_id:
# Get existing provider data header or create new dict
provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
if provider_data_header:
provider_data = json.loads(provider_data_header)
else:
provider_data = {}
# Inject test ID
provider_data["__test_id"] = test_id
request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
# Sync version returns None
return None
# Apply patches
LlamaStackClient._prepare_request = patched_prepare_request
if "openai_prepare_request" in _original_methods:
from openai import OpenAI
OpenAI._prepare_request = patched_prepare_request
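For reference, the header injected by the patch above would look roughly like this in server mode (the test id is a hypothetical pytest nodeid):
```python
import json

provider_data = {"__test_id": "tests/integration/inference/test_basic.py::test_foo[ollama]"}
header_value = json.dumps(provider_data)
# Sent as: X-LlamaStack-Provider-Data: {"__test_id": "tests/integration/inference/test_basic.py::test_foo[ollama]"}
```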
def unpatch_httpx_for_test_id():
"""Remove client _prepare_request patches for test ID injection."""
if "llama_stack_client_prepare_request" not in _original_methods:
return
from llama_stack_client import LlamaStackClient
LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"]
del _original_methods["llama_stack_client_prepare_request"]
# Also restore OpenAI client if it was patched
if "openai_prepare_request" in _original_methods:
from openai import OpenAI
OpenAI._prepare_request = _original_methods["openai_prepare_request"]
del _original_methods["openai_prepare_request"]
def get_api_recording_mode() -> APIRecordingMode:
return APIRecordingMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
def setup_api_recording():
"""
Returns a context manager that can be used to record or replay API requests (inference and tools).
This is to be used in tests to increase their reliability and reduce reliance on expensive, external services.
Currently supports:
- Inference: OpenAI, Ollama, and LiteLLM clients
- Tools: Search providers (Tavily for now)
Two environment variables are supported:
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Options:
- 'live': Make real API calls, no recording
- 'record': Record all API interactions (overwrites existing)
- 'replay': Use recorded responses only (default)
- 'record-if-missing': Replay when possible, record when recording doesn't exist
- LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'.
The recordings are stored as JSON files.
"""
mode = get_api_recording_mode()
if mode == APIRecordingMode.LIVE:
return None
storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR)
return api_recording(mode=mode, storage_dir=storage_dir)
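A minimal usage sketch for the helper above, driven by the two environment variables it documents (mode and directory are placeholders):
```python
import os

os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "record-if-missing"
os.environ["LLAMA_STACK_TEST_RECORDING_DIR"] = "./recordings"

ctx = setup_api_recording()  # returns None when the mode is 'live'
if ctx is not None:
    with ctx:
        ...  # code that issues inference or tool calls is replayed or recorded here
```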
def _serialize_response(response: Any) -> Any:
if hasattr(response, "model_dump"):
data = response.model_dump(mode="json")
# Normalize fields to reduce noise
data = _normalize_response_data(data, request_hash)
return {
"__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}",
"__data__": data,
@ -259,22 +308,17 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
return cls.model_validate(data["__data__"])
except (ImportError, AttributeError, TypeError, ValueError) as e:
logger.warning(f"Failed to deserialize object of type {data['__type__']} with model_validate: {e}")
try:
return cls.model_construct(**data["__data__"])
except Exception as e:
logger.warning(f"Failed to deserialize object of type {data['__type__']} with model_construct: {e}")
return data["__data__"]
logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}")
return data["__data__"]
return data
class ResponseStorage:
"""Handles SQLite index + JSON file storage/retrieval for inference recordings."""
"""Handles storage/retrieval for API recordings (inference and tools)."""
def __init__(self, base_dir: Path):
self.base_dir = base_dir
# Don't create responses_dir here - determine it per-test at runtime
def _get_test_dir(self) -> Path:
"""Get the recordings directory in the test file's parent directory.
@ -283,6 +327,7 @@ class ResponseStorage:
returns "tests/integration/inference/recordings/".
"""
test_id = _test_context.get()
if test_id:
# Extract the directory path from the test nodeid
# e.g., "tests/integration/inference/test_basic.py::test_foo[params]"
@ -297,17 +342,21 @@ class ResponseStorage:
# Fallback for non-test contexts
return self.base_dir / "recordings"
def _ensure_directories(self):
"""Ensure test-specific directories exist."""
def _ensure_directories(self) -> Path:
test_dir = self._get_test_dir()
test_dir.mkdir(parents=True, exist_ok=True)
return test_dir
def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
"""Store a request/response pair."""
"""Store a request/response pair both in memory cache and on disk."""
global _memory_cache
# Store in memory cache first
_memory_cache[request_hash] = {"request": request, "response": response}
responses_dir = self._ensure_directories()
# Use FULL hash (not truncated)
# Generate unique response filename using full hash
response_file = f"{request_hash}.json"
# Serialize response body if needed
@ -315,45 +364,32 @@ class ResponseStorage:
if "body" in serialized_response:
if isinstance(serialized_response["body"], list):
# Handle streaming responses (list of chunks)
serialized_response["body"] = [
_serialize_response(chunk, request_hash) for chunk in serialized_response["body"]
]
serialized_response["body"] = [_serialize_response(chunk) for chunk in serialized_response["body"]]
else:
# Handle single response
serialized_response["body"] = _serialize_response(serialized_response["body"], request_hash)
serialized_response["body"] = _serialize_response(serialized_response["body"])
# For model-list endpoints, include digest in filename to distinguish different model sets
# If this is a model-list endpoint recording, include models digest in filename to distinguish variants
endpoint = request.get("endpoint")
test_id = _test_context.get()
if endpoint in ("/api/tags", "/v1/models"):
test_id = None
digest = _model_identifiers_digest(endpoint, response)
response_file = f"models-{request_hash}-{digest}.json"
response_path = responses_dir / response_file
# Save response to JSON file with metadata
# Save response to JSON file
with open(response_path, "w") as f:
json.dump(
{
"test_id": _test_context.get(),
"request": request,
"response": serialized_response,
},
f,
indent=2,
)
json.dump({"test_id": test_id, "request": request, "response": serialized_response}, f, indent=2)
f.write("\n")
f.flush()
def find_recording(self, request_hash: str) -> dict[str, Any] | None:
"""Find a recorded response by request hash.
Uses fallback: first checks test-specific dir, then falls back to base recordings dir.
This handles cases where recordings happen during session setup (no test context) but
are requested during tests (with test context).
"""
"""Find a recorded response by request hash."""
response_file = f"{request_hash}.json"
# Try test-specific directory first
# Check test-specific directory first
test_dir = self._get_test_dir()
response_path = test_dir / response_file
@ -464,15 +500,97 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
return {"request": canonical_req, "response": {"body": body, "is_streaming": False}}
async def _patched_tool_invoke_method(
original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
):
"""Patched version of tool runtime invoke_tool method for recording/replay."""
global _current_mode, _current_storage
if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
# Normal operation
return await original_method(self, tool_name, kwargs)
# In server mode, sync test ID from provider_data to _test_context for storage operations
test_context_token = _sync_test_context_from_provider_data()
try:
request_hash = normalize_tool_request(provider_name, tool_name, kwargs)
if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING):
recording = _current_storage.find_recording(request_hash)
if recording:
return recording["response"]["body"]
elif _current_mode == APIRecordingMode.REPLAY:
raise RuntimeError(
f"No recorded tool result found for {provider_name}.{tool_name}\n"
f"Request: {kwargs}\n"
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)
# If RECORD_IF_MISSING and no recording found, fall through to record
if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
# Check in-memory cache first (collision detection)
global _memory_cache
if request_hash in _memory_cache:
# Return the cached response instead of making a new tool call
return _memory_cache[request_hash]["response"]["body"]
# No cached response, make the tool call and record it
result = await original_method(self, tool_name, kwargs)
request_data = {
"provider": provider_name,
"tool_name": tool_name,
"kwargs": kwargs,
}
response_data = {"body": result, "is_streaming": False}
# Store the recording (both in memory and on disk)
_current_storage.store_recording(request_hash, request_data, response_data)
return result
else:
raise AssertionError(f"Invalid mode: {_current_mode}")
finally:
# Reset test context if we set it in server mode
if test_context_token is not None:
_test_context.reset(test_context_token)
def _sync_test_context_from_provider_data():
"""In server mode, sync test ID from provider_data to _test_context.
This ensures that storage operations (which read from _test_context) work correctly
in server mode where the test ID arrives via HTTP header provider_data.
Returns a token to reset _test_context, or None if no sync was needed.
"""
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
if stack_config_type != "server":
return None
try:
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
provider_data = PROVIDER_DATA_VAR.get()
if provider_data and "__test_id" in provider_data:
test_id = provider_data["__test_id"]
return _test_context.set(test_id)
except ImportError:
pass
return None
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
global _current_mode, _current_storage
mode = _current_mode
storage = _current_storage
if mode == InferenceMode.LIVE or storage is None:
if endpoint == "/v1/models":
return original_method(self, *args, **kwargs)
if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
# Normal operation
if client_type == "litellm":
return await original_method(*args, **kwargs)
else:
return await original_method(self, *args, **kwargs)
@ -491,34 +609,30 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
base_url = getattr(self, "host", "http://localhost:11434")
if not base_url.startswith("http"):
base_url = f"http://{base_url}"
elif client_type == "litellm":
# For LiteLLM, extract base URL from kwargs if available
base_url = kwargs.get("api_base", "https://api.openai.com")
else:
raise ValueError(f"Unknown client type: {client_type}")
url = base_url.rstrip("/") + endpoint
# Special handling for Databricks URLs to avoid leaking workspace info
# e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com
if "cloud.databricks.com" in url:
url = "__databricks__" + url.split("cloud.databricks.com")[-1]
method = "POST"
headers = {}
body = kwargs
request_hash = normalize_request(method, url, headers, body)
# Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
recording = None
if mode == InferenceMode.REPLAY or mode == InferenceMode.RECORD_IF_MISSING:
# Special handling for model-list endpoints: merge all recordings with this hash
if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING):
# Special handling for model-list endpoints: return union of all responses
if endpoint in ("/api/tags", "/v1/models"):
records = storage._model_list_responses(request_hash)
records = _current_storage._model_list_responses(request_hash)
recording = _combine_model_list_responses(endpoint, records)
else:
recording = storage.find_recording(request_hash)
recording = _current_storage.find_recording(request_hash)
if recording:
response_body = recording["response"]["body"]
if recording["response"].get("is_streaming", False):
if recording["response"].get("is_streaming", False) or recording["response"].get("is_paginated", False):
async def replay_stream():
for chunk in response_body:
@ -527,25 +641,41 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
return replay_stream()
else:
return response_body
elif mode == InferenceMode.REPLAY:
# REPLAY mode requires recording to exist
elif _current_mode == APIRecordingMode.REPLAY:
raise RuntimeError(
f"No recorded response found for request hash: {request_hash}\n"
f"Request: {method} {url} {body}\n"
f"Model: {body.get('model', 'unknown')}\n"
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)
# If RECORD_IF_MISSING and no recording found, fall through to record
if mode == InferenceMode.RECORD or (mode == InferenceMode.RECORD_IF_MISSING and not recording):
if endpoint == "/v1/models":
response = original_method(self, *args, **kwargs)
if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
# Check in-memory cache first (collision detection)
global _memory_cache
if request_hash in _memory_cache:
# Return the cached response instead of making a new API call
cached_recording = _memory_cache[request_hash]
response_body = cached_recording["response"]["body"]
if cached_recording["response"].get("is_streaming", False) or cached_recording["response"].get(
"is_paginated", False
):
async def replay_cached_stream():
for chunk in response_body:
yield chunk
return replay_cached_stream()
else:
return response_body
# No cached response, make the API call and record it
if client_type == "litellm":
response = await original_method(*args, **kwargs)
else:
response = await original_method(self, *args, **kwargs)
# we want to store the result of the iterator, not the iterator itself
if endpoint == "/v1/models":
response = [m async for m in response]
request_data = {
"method": method,
"url": url,
@ -558,16 +688,20 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
# Determine if this is a streaming request based on request parameters
is_streaming = body.get("stream", False)
if is_streaming:
# For streaming responses, we need to collect all chunks immediately before yielding
# Special case: /v1/models is a paginated endpoint that returns an async iterator
is_paginated = endpoint == "/v1/models"
if is_streaming or is_paginated:
# For streaming/paginated responses, we need to collect all chunks immediately before yielding
# This ensures the recording is saved even if the generator isn't fully consumed
chunks = []
async for chunk in response:
chunks.append(chunk)
# Store the recording immediately
response_data = {"body": chunks, "is_streaming": True}
storage.store_recording(request_hash, request_data, response_data)
# Store the recording immediately (both in memory and on disk)
# For paginated endpoints, mark as paginated rather than streaming
response_data = {"body": chunks, "is_streaming": is_streaming, "is_paginated": is_paginated}
_current_storage.store_recording(request_hash, request_data, response_data)
# Return a generator that replays the stored chunks
async def replay_recorded_stream():
@ -577,27 +711,34 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
return replay_recorded_stream()
else:
response_data = {"body": response, "is_streaming": False}
storage.store_recording(request_hash, request_data, response_data)
# Store the response (both in memory and on disk)
_current_storage.store_recording(request_hash, request_data, response_data)
return response
else:
raise AssertionError(f"Invalid mode: {mode}")
raise AssertionError(f"Invalid mode: {_current_mode}")
finally:
if test_context_token:
# Reset test context if we set it in server mode
if test_context_token is not None:
_test_context.reset(test_context_token)
def patch_inference_clients():
"""Install monkey patches for OpenAI client methods and Ollama AsyncClient methods."""
def patch_api_clients():
"""Install monkey patches for inference clients and tool runtime methods."""
global _original_methods
import litellm
from ollama import AsyncClient as OllamaAsyncClient
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
from openai.resources.completions import AsyncCompletions
from openai.resources.embeddings import AsyncEmbeddings
from openai.resources.models import AsyncModels
# Store original methods for both OpenAI and Ollama clients
from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
# Store original methods for OpenAI, Ollama, LiteLLM clients, and tool runtimes
_original_methods = {
"chat_completions_create": AsyncChatCompletions.create,
"completions_create": AsyncCompletions.create,
@ -609,6 +750,10 @@ def patch_inference_clients():
"ollama_ps": OllamaAsyncClient.ps,
"ollama_pull": OllamaAsyncClient.pull,
"ollama_list": OllamaAsyncClient.list,
"litellm_acompletion": litellm.acompletion,
"litellm_atext_completion": litellm.atext_completion,
"litellm_openai_mixin_get_api_key": LiteLLMOpenAIMixin.get_api_key,
"tavily_invoke_tool": TavilySearchToolRuntimeImpl.invoke_tool,
}
# Create patched methods for OpenAI client
@ -629,10 +774,18 @@ def patch_inference_clients():
def patched_models_list(self, *args, **kwargs):
async def _iter():
for item in await _patched_inference_method(
result = await _patched_inference_method(
_original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
):
yield item
)
# The result is either an async generator (streaming/paginated) or a list
# If it's an async generator, iterate through it
if hasattr(result, "__aiter__"):
async for item in result:
yield item
else:
# It's a list, yield each item
for item in result:
yield item
return _iter()
@ -681,21 +834,61 @@ def patch_inference_clients():
OllamaAsyncClient.pull = patched_ollama_pull
OllamaAsyncClient.list = patched_ollama_list
# Create patched methods for LiteLLM
async def patched_litellm_acompletion(*args, **kwargs):
return await _patched_inference_method(
_original_methods["litellm_acompletion"], None, "litellm", "/chat/completions", *args, **kwargs
)
def unpatch_inference_clients():
"""Remove monkey patches and restore original OpenAI and Ollama client methods."""
global _original_methods
async def patched_litellm_atext_completion(*args, **kwargs):
return await _patched_inference_method(
_original_methods["litellm_atext_completion"], None, "litellm", "/completions", *args, **kwargs
)
# Apply LiteLLM patches
litellm.acompletion = patched_litellm_acompletion
litellm.atext_completion = patched_litellm_atext_completion
# Create patched method for LiteLLMOpenAIMixin.get_api_key
def patched_litellm_get_api_key(self):
global _current_mode
if _current_mode != APIRecordingMode.REPLAY:
return _original_methods["litellm_openai_mixin_get_api_key"](self)
else:
# For record/replay modes, return a fake API key to avoid exposing real credentials
return "fake-api-key-for-testing"
# Apply LiteLLMOpenAIMixin patch
LiteLLMOpenAIMixin.get_api_key = patched_litellm_get_api_key
# Create patched methods for tool runtimes
async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any]):
return await _patched_tool_invoke_method(
_original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs
)
# Apply tool runtime patches
TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool
def unpatch_api_clients():
"""Remove monkey patches and restore original client methods and tool runtimes."""
global _original_methods, _memory_cache
if not _original_methods:
return
# Import here to avoid circular imports
import litellm
from ollama import AsyncClient as OllamaAsyncClient
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
from openai.resources.completions import AsyncCompletions
from openai.resources.embeddings import AsyncEmbeddings
from openai.resources.models import AsyncModels
from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
# Restore OpenAI client methods
AsyncChatCompletions.create = _original_methods["chat_completions_create"]
AsyncCompletions.create = _original_methods["completions_create"]
@ -710,12 +903,23 @@ def unpatch_inference_clients():
OllamaAsyncClient.pull = _original_methods["ollama_pull"]
OllamaAsyncClient.list = _original_methods["ollama_list"]
# Restore LiteLLM methods
litellm.acompletion = _original_methods["litellm_acompletion"]
litellm.atext_completion = _original_methods["litellm_atext_completion"]
LiteLLMOpenAIMixin.get_api_key = _original_methods["litellm_openai_mixin_get_api_key"]
# Restore tool runtime methods
TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"]
_original_methods.clear()
# Clear memory cache to prevent memory leaks
_memory_cache.clear()
@contextmanager
def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
"""Context manager for inference recording/replaying."""
def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
"""Context manager for API recording/replaying (inference and tools)."""
global _current_mode, _current_storage
# Store previous state
@ -729,14 +933,14 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
if storage_dir is None:
raise ValueError("storage_dir is required for record, replay, and record-if-missing modes")
_current_storage = ResponseStorage(Path(storage_dir))
patch_inference_clients()
patch_api_clients()
yield
finally:
# Restore previous state
if mode in ["record", "replay", "record-if-missing"]:
unpatch_inference_clients()
unpatch_api_clients()
_current_mode = prev_mode
_current_storage = prev_storage
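Putting the pieces together, a hedged end-to-end sketch of how a harness might drive the recorder for both inference and tool calls; the test id and subdirectory are illustrative, and `tests/integration/common` mirrors `DEFAULT_STORAGE_DIR` above:
```python
from llama_stack.testing.api_recorder import APIRecordingMode, api_recording, set_test_context

with api_recording(mode=APIRecordingMode.RECORD_IF_MISSING, storage_dir="tests/integration/common"):
    with set_test_context("tests/integration/tools/test_search.py::test_web_search"):
        # Patched inference-client calls and tool-runtime invoke_tool calls made here are
        # replayed from tests/integration/tools/recordings/<hash>.json when available,
        # or executed live and recorded there otherwise.
        ...
```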

View file

@ -29,7 +29,7 @@ Options:
--stack-config STRING Stack configuration to use (required)
--suite STRING Test suite to run (default: 'base')
--setup STRING Test setup (models, env) to use (e.g., 'ollama', 'ollama-vision', 'gpt', 'vllm')
--inference-mode STRING Inference mode: record or replay (default: replay)
--inference-mode STRING Inference mode: replay, record-if-missing or record (default: replay)
--subdirs STRING Comma-separated list of test subdirectories to run (overrides suite)
--pattern STRING Regex pattern to pass to pytest -k
--help Show this help message
@ -102,7 +102,7 @@ if [[ -z "$STACK_CONFIG" ]]; then
fi
if [[ -z "$TEST_SETUP" && -n "$TEST_SUBDIRS" ]]; then
echo "Error: --test-setup is required when --test-subdirs is provided"
echo "Error: --test-setup is required when --test-subdirs is not provided"
usage
exit 1
fi

View file

@ -159,7 +159,6 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal
import threading
import time
import httpx
import uvicorn
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
@ -171,6 +170,11 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal
server = FastMCP("FastMCP Test Server", log_level="WARNING")
# Silence verbose MCP server logs
import logging # allow-direct-logging
logging.getLogger("mcp.server.lowlevel.server").setLevel(logging.WARNING)
tools = tools or default_tools()
# Register all tools with the server
@ -234,29 +238,25 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal
logger.debug(f"Starting MCP server thread on port {port}")
server_thread.start()
# Polling until the server is ready
timeout = 10
# Wait for the server thread to be running
# Note: We can't use a simple HTTP GET health check on /sse because it's an SSE endpoint
# that expects a long-lived connection, not a simple request/response
timeout = 2
start_time = time.time()
server_url = f"http://localhost:{port}/sse"
logger.debug(f"Waiting for MCP server to be ready at {server_url}")
logger.debug(f"Waiting for MCP server thread to start on port {port}")
while time.time() - start_time < timeout:
try:
response = httpx.get(server_url)
if response.status_code in [200, 401]:
logger.debug(f"MCP server is ready on port {port} (status: {response.status_code})")
break
except httpx.RequestError as e:
logger.debug(f"Server not ready yet, retrying... ({e})")
pass
time.sleep(0.1)
if server_thread.is_alive():
# Give the server a moment to bind to the port
time.sleep(0.1)
logger.debug(f"MCP server is ready on port {port}")
break
time.sleep(0.05)
else:
# If we exit the loop due to timeout
logger.error(f"MCP server failed to start within {timeout} seconds on port {port}")
logger.error(f"Thread alive: {server_thread.is_alive()}")
if server_thread.is_alive():
logger.error("Server thread is still running but not responding to HTTP requests")
logger.error(f"MCP server thread failed to start within {timeout} seconds on port {port}")
try:
yield {"server_url": server_url}

View file

@ -6,6 +6,7 @@
import inspect
import itertools
import os
import tempfile
import textwrap
import time
from pathlib import Path
@ -14,6 +15,7 @@ import pytest
from dotenv import load_dotenv
from llama_stack.log import get_logger
from llama_stack.testing.api_recorder import patch_httpx_for_test_id
from .suites import SETUP_DEFINITIONS, SUITE_DEFINITIONS
@ -35,6 +37,10 @@ def pytest_sessionstart(session):
if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"
if "SQLITE_STORE_DIR" not in os.environ:
os.environ["SQLITE_STORE_DIR"] = tempfile.mkdtemp()
# Set test stack config type for api_recorder test isolation
stack_config = session.config.getoption("--stack-config", default=None)
if stack_config and stack_config.startswith("server:"):
os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "server"
@ -43,8 +49,6 @@ def pytest_sessionstart(session):
os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "library_client"
logger.info(f"Test stack config type: library_client (stack_config={stack_config})")
from llama_stack.testing.inference_recorder import patch_httpx_for_test_id
patch_httpx_for_test_id()
@ -55,7 +59,7 @@ def _track_test_context(request):
This fixture runs for every test and stores the test's nodeid in a contextvar
that the recording system can access to determine which subdirectory to use.
"""
from llama_stack.testing.inference_recorder import _test_context
from llama_stack.testing.api_recorder import _test_context
# Store the test nodeid (e.g., "tests/integration/responses/test_basic.py::test_foo[params]")
token = _test_context.set(request.node.nodeid)
@ -121,9 +125,13 @@ def pytest_configure(config):
# Apply defaults if not provided explicitly
for dest, value in setup_obj.defaults.items():
current = getattr(config.option, dest, None)
if not current:
if current is None:
setattr(config.option, dest, value)
# Apply global fallback for embedding_dimension if still not set
if getattr(config.option, "embedding_dimension", None) is None:
config.option.embedding_dimension = 384
def pytest_addoption(parser):
parser.addoption(
@ -161,8 +169,8 @@ def pytest_addoption(parser):
parser.addoption(
"--embedding-dimension",
type=int,
default=384,
help="Output dimensionality of the embedding model to use for testing. Default: 384",
default=None,
help="Output dimensionality of the embedding model to use for testing. Default: 384 (or setup-specific)",
)
parser.addoption(
@ -236,7 +244,9 @@ def pytest_generate_tests(metafunc):
continue
params.append(fixture_name)
val = metafunc.config.getoption(option)
# Use getattr on config.option to see values set by pytest_configure fallbacks
dest = option.lstrip("-").replace("-", "_")
val = getattr(metafunc.config.option, dest, None)
values = [v.strip() for v in str(val).split(",")] if val else [None]
param_values[fixture_name] = values
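The dest lookup added above simply maps the CLI option name onto the attribute pytest stores on `config.option`; a tiny sketch:
```python
option = "--embedding-dimension"
dest = option.lstrip("-").replace("-", "_")
assert dest == "embedding_dimension"
```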

View file

@ -183,6 +183,12 @@ def llama_stack_client(request):
# would be forced to use llama_stack_client, which is not what we want.
print("\ninstantiating llama_stack_client")
start_time = time.time()
# Patch httpx to inject test ID for server-mode test isolation
from llama_stack.testing.api_recorder import patch_httpx_for_test_id
patch_httpx_for_test_id()
client = instantiate_llama_stack_client(request.session)
print(f"llama_stack_client instantiated in {time.time() - start_time:.3f}s")
return client

View file

@ -7,7 +7,7 @@
import time
def new_vector_store(openai_client, name):
def new_vector_store(openai_client, name, embedding_model, embedding_dimension):
"""Create a new vector store, cleaning up any existing one with the same name."""
# Ensure we don't reuse an existing vector store
vector_stores = openai_client.vector_stores.list()
@ -16,7 +16,21 @@ def new_vector_store(openai_client, name):
openai_client.vector_stores.delete(vector_store_id=vector_store.id)
# Create a new vector store
vector_store = openai_client.vector_stores.create(name=name)
# OpenAI SDK client uses extra_body for non-standard parameters
from openai import OpenAI
if isinstance(openai_client, OpenAI):
# OpenAI SDK client - use extra_body
vector_store = openai_client.vector_stores.create(
name=name,
extra_body={"embedding_model": embedding_model, "embedding_dimension": embedding_dimension},
)
else:
# LlamaStack client - direct parameter
vector_store = openai_client.vector_stores.create(
name=name, embedding_model=embedding_model, embedding_dimension=embedding_dimension
)
return vector_store
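A usage sketch for the updated helper, using the gpt setup defaults added later in this commit (`openai/text-embedding-3-small`, dimension 1536); `compat_client` stands for either client flavor:
```python
vector_store = new_vector_store(
    compat_client,
    "test_vector_store_with_filters",
    "openai/text-embedding-3-small",
    1536,
)
```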

View file

@ -16,6 +16,7 @@ import pytest
from llama_stack_client import APIStatusError
@pytest.mark.xfail(reason="Shields are not yet implemented inside responses")
def test_shields_via_extra_body(compat_client, text_model_id):
"""Test that shields parameter is received by the server and raises NotImplementedError."""

View file

@ -47,12 +47,14 @@ def test_response_text_format(compat_client, text_model_id, text_format):
@pytest.fixture
def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_factory):
"""Create a vector store with multiple files that have different attributes for filtering tests."""
def vector_store_with_filtered_files(compat_client, embedding_model_id, embedding_dimension, tmp_path_factory):
# """Create a vector store with multiple files that have different attributes for filtering tests."""
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
pytest.skip("upload_file() is not yet supported in library client somehow?")
vector_store = new_vector_store(compat_client, "test_vector_store_with_filters")
vector_store = new_vector_store(
compat_client, "test_vector_store_with_filters", embedding_model_id, embedding_dimension
)
tmp_path = tmp_path_factory.mktemp("filter_test_files")
# Create multiple files with different attributes

View file

@ -46,11 +46,13 @@ def test_response_non_streaming_web_search(compat_client, text_model_id, case):
@pytest.mark.parametrize("case", file_search_test_cases)
def test_response_non_streaming_file_search(compat_client, text_model_id, tmp_path, case):
def test_response_non_streaming_file_search(
compat_client, text_model_id, embedding_model_id, embedding_dimension, tmp_path, case
):
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
vector_store = new_vector_store(compat_client, "test_vector_store")
vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)
if case.file_content:
file_name = "test_response_non_streaming_file_search.txt"
@ -101,11 +103,13 @@ def test_response_non_streaming_file_search(compat_client, text_model_id, tmp_pa
assert case.expected.lower() in response.output_text.lower().strip()
def test_response_non_streaming_file_search_empty_vector_store(compat_client, text_model_id):
def test_response_non_streaming_file_search_empty_vector_store(
compat_client, text_model_id, embedding_model_id, embedding_dimension
):
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
vector_store = new_vector_store(compat_client, "test_vector_store")
vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)
# Create the response request, which should query our vector store
response = compat_client.responses.create(
@ -127,12 +131,14 @@ def test_response_non_streaming_file_search_empty_vector_store(compat_client, te
assert response.output_text
def test_response_sequential_file_search(compat_client, text_model_id, tmp_path):
def test_response_sequential_file_search(
compat_client, text_model_id, embedding_model_id, embedding_dimension, tmp_path
):
"""Test file search with sequential responses using previous_response_id."""
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
vector_store = new_vector_store(compat_client, "test_vector_store")
vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)
# Create a test file with content
file_content = "The Llama 4 Maverick model has 128 experts in its mixture of experts architecture."

View file

@ -39,7 +39,7 @@ class Setup(BaseModel):
name: str
description: str
defaults: dict[str, str] = Field(default_factory=dict)
defaults: dict[str, str | int] = Field(default_factory=dict)
env: dict[str, str] = Field(default_factory=dict)
@ -88,6 +88,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
defaults={
"text_model": "openai/gpt-4o",
"embedding_model": "openai/text-embedding-3-small",
"embedding_dimension": 1536,
},
),
"tgi": Setup(

View file

@ -0,0 +1,273 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
from openai import AsyncOpenAI
# Import the real Pydantic response types instead of using Mocks
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.testing.api_recorder import (
APIRecordingMode,
ResponseStorage,
api_recording,
normalize_request,
)
@pytest.fixture
def temp_storage_dir():
"""Create a temporary directory for test recordings."""
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)
@pytest.fixture
def real_openai_chat_response():
"""Real OpenAI chat completion response using proper Pydantic objects."""
return OpenAIChatCompletion(
id="chatcmpl-test123",
choices=[
OpenAIChoice(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant", content="Hello! I'm doing well, thank you for asking."
),
finish_reason="stop",
)
],
created=1234567890,
model="llama3.2:3b",
)
@pytest.fixture
def real_embeddings_response():
"""Real OpenAI embeddings response using proper Pydantic objects."""
return OpenAIEmbeddingsResponse(
object="list",
data=[
OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
],
model="nomic-embed-text",
usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
)
class TestInferenceRecording:
"""Test the inference recording system."""
def test_request_normalization(self):
"""Test that request normalization produces consistent hashes."""
# Test basic normalization
hash1 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)
# Same request should produce same hash
hash2 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)
assert hash1 == hash2
# Different content should produce different hash
hash3 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{
"model": "llama3.2:3b",
"messages": [{"role": "user", "content": "Different message"}],
"temperature": 0.7,
},
)
assert hash1 != hash3
def test_request_normalization_edge_cases(self):
"""Test request normalization is precise about request content."""
# Test that different whitespace produces different hashes (no normalization)
hash1 = normalize_request(
"POST",
"http://test/v1/chat/completions",
{},
{"messages": [{"role": "user", "content": "Hello world\n\n"}]},
)
hash2 = normalize_request(
"POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
)
assert hash1 != hash2 # Different whitespace should produce different hashes
# Test that different float precision produces different hashes (no rounding)
hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
assert hash3 != hash4 # Different precision should produce different hashes
def test_response_storage(self, temp_storage_dir):
"""Test the ResponseStorage class."""
temp_storage_dir = temp_storage_dir / "test_response_storage"
storage = ResponseStorage(temp_storage_dir)
# Test storing and retrieving a recording
request_hash = "test_hash_123"
request_data = {
"method": "POST",
"url": "http://localhost:11434/v1/chat/completions",
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b",
}
response_data = {"body": {"content": "test response"}, "is_streaming": False}
storage.store_recording(request_hash, request_data, response_data)
# Verify file storage and retrieval
retrieved = storage.find_recording(request_hash)
assert retrieved is not None
assert retrieved["request"]["model"] == "llama3.2:3b"
assert retrieved["response"]["body"]["content"] == "test response"
async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that recording mode captures and stores responses."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
temp_storage_dir = temp_storage_dir / "test_recording_mode"
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Verify the response was returned correctly
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
# Verify recording was stored
storage = ResponseStorage(temp_storage_dir)
assert storage._get_test_dir().exists()
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that replay mode returns stored responses without making real calls."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
temp_storage_dir = temp_storage_dir / "test_replay_mode"
# First, record a response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Now test replay mode - should not call the original method
with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Verify we got the recorded response
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
# Verify the original method was NOT called
mock_create_patch.assert_not_called()
async def test_replay_missing_recording(self, temp_storage_dir):
"""Test that replay mode fails when no recording is found."""
temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
with patch("openai.resources.chat.completions.AsyncCompletions.create"):
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
with pytest.raises(RuntimeError, match="No recorded response found"):
await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
)
async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
"""Test recording and replay of embeddings calls."""
async def mock_create(*args, **kwargs):
return real_embeddings_response
temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
# Record
with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.embeddings.create(
model="nomic-embed-text", input=["Hello world", "Test embedding"]
)
assert len(response.data) == 2
# Replay
with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.embeddings.create(
model="nomic-embed-text", input=["Hello world", "Test embedding"]
)
# Verify we got the recorded response
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
# Verify original method was not called
mock_create_patch.assert_not_called()
async def test_live_mode(self, real_openai_chat_response):
"""Test that live mode passes through to original methods."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with api_recording(mode=APIRecordingMode.LIVE, storage_dir="foo"):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
)
# Verify the response was returned
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."

View file

@@ -1,382 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch
import pytest
from openai import NOT_GIVEN, AsyncOpenAI
from openai.types.model import Model as OpenAIModel
# Import the real Pydantic response types instead of using Mocks
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.testing.inference_recorder import (
InferenceMode,
ResponseStorage,
inference_recording,
normalize_request,
)
@pytest.fixture
def temp_storage_dir():
"""Create a temporary directory for test recordings."""
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)
@pytest.fixture
def real_openai_chat_response():
"""Real OpenAI chat completion response using proper Pydantic objects."""
return OpenAIChatCompletion(
id="chatcmpl-test123",
choices=[
OpenAIChoice(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant", content="Hello! I'm doing well, thank you for asking."
),
finish_reason="stop",
)
],
created=1234567890,
model="llama3.2:3b",
)
@pytest.fixture
def real_embeddings_response():
"""Real OpenAI embeddings response using proper Pydantic objects."""
return OpenAIEmbeddingsResponse(
object="list",
data=[
OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
],
model="nomic-embed-text",
usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
)
class TestInferenceRecording:
"""Test the inference recording system."""
def test_request_normalization(self):
"""Test that request normalization produces consistent hashes."""
# Test basic normalization
hash1 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)
# Same request should produce same hash
hash2 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)
assert hash1 == hash2
# Different content should produce different hash
hash3 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{
"model": "llama3.2:3b",
"messages": [{"role": "user", "content": "Different message"}],
"temperature": 0.7,
},
)
assert hash1 != hash3
def test_request_normalization_edge_cases(self):
"""Test request normalization is precise about request content."""
# Test that different whitespace produces different hashes (no normalization)
hash1 = normalize_request(
"POST",
"http://test/v1/chat/completions",
{},
{"messages": [{"role": "user", "content": "Hello world\n\n"}]},
)
hash2 = normalize_request(
"POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
)
assert hash1 != hash2 # Different whitespace should produce different hashes
# Test that different float precision produces different hashes (no rounding)
hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
assert hash3 != hash4 # Different precision should produce different hashes
def test_response_storage(self, temp_storage_dir):
"""Test the ResponseStorage class."""
temp_storage_dir = temp_storage_dir / "test_response_storage"
storage = ResponseStorage(temp_storage_dir)
# Test storing and retrieving a recording
request_hash = "test_hash_123"
request_data = {
"method": "POST",
"url": "http://localhost:11434/v1/chat/completions",
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b",
}
response_data = {"body": {"content": "test response"}, "is_streaming": False}
storage.store_recording(request_hash, request_data, response_data)
# Verify file storage and retrieval
retrieved = storage.find_recording(request_hash)
assert retrieved is not None
assert retrieved["request"]["model"] == "llama3.2:3b"
assert retrieved["response"]["body"]["content"] == "test response"
async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that recording mode captures and stores responses."""
temp_storage_dir = temp_storage_dir / "test_recording_mode"
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
# Verify the response was returned correctly
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
client.chat.completions._post.assert_called_once()
# Verify recording was stored
storage = ResponseStorage(temp_storage_dir)
dir = storage._get_test_dir()
assert dir.exists()
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that replay mode returns stored responses without making real calls."""
temp_storage_dir = temp_storage_dir / "test_replay_mode"
# First, record a response
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
client.chat.completions._post.assert_called_once()
# Now test replay mode - should not call the original method
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Verify we got the recorded response
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
# Verify the original method was NOT called
client.chat.completions._post.assert_not_called()
async def test_replay_mode_models(self, temp_storage_dir):
"""Test that replay mode returns stored responses without making real model listing calls."""
async def _async_iterator(models):
for model in models:
yield model
models = [
OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
OpenAIModel(id="bar", created=2, object="model", owned_by="test"),
]
expected_ids = {m.id for m in models}
temp_storage_dir = temp_storage_dir / "test_replay_mode_models"
# baseline - mock works without recording
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.models._get_api_list = Mock(return_value=_async_iterator(models))
assert {m.id async for m in client.models.list()} == expected_ids
client.models._get_api_list.assert_called_once()
# record the call
with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.models._get_api_list = Mock(return_value=_async_iterator(models))
assert {m.id async for m in client.models.list()} == expected_ids
client.models._get_api_list.assert_called_once()
# replay the call
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.models._get_api_list = Mock(return_value=_async_iterator(models))
assert {m.id async for m in client.models.list()} == expected_ids
client.models._get_api_list.assert_not_called()
async def test_replay_missing_recording(self, temp_storage_dir):
"""Test that replay mode fails when no recording is found."""
temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
with patch("openai.resources.chat.completions.AsyncCompletions.create"):
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
with pytest.raises(RuntimeError, match="No recorded response found"):
await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
)
async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
"""Test recording and replay of embeddings calls."""
# baseline - mock works without recording
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
response = await client.embeddings.create(
model=real_embeddings_response.model,
input=["Hello world", "Test embedding"],
encoding_format=NOT_GIVEN,
)
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
client.embeddings._post.assert_called_once()
temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
# Record
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
response = await client.embeddings.create(
model=real_embeddings_response.model,
input=["Hello world", "Test embedding"],
encoding_format=NOT_GIVEN,
dimensions=NOT_GIVEN,
user=NOT_GIVEN,
)
assert len(response.data) == 2
# Replay
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
response = await client.embeddings.create(
model=real_embeddings_response.model,
input=["Hello world", "Test embedding"],
)
# Verify we got the recorded response
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
# Verify original method was not called
client.embeddings._post.assert_not_called()
async def test_completions_recording(self, temp_storage_dir):
real_completions_response = OpenAICompletion(
id="test_completion",
object="text_completion",
created=1234567890,
model="llama3.2:3b",
choices=[
{
"text": "Hello! I'm doing well, thank you for asking.",
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}
],
)
temp_storage_dir = temp_storage_dir / "test_completions_recording"
# baseline - mock works without recording
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.completions._post = AsyncMock(return_value=real_completions_response)
response = await client.completions.create(
model=real_completions_response.model,
prompt="Hello, how are you?",
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
assert response.choices[0].text == real_completions_response.choices[0].text
client.completions._post.assert_called_once()
# Record
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.completions._post = AsyncMock(return_value=real_completions_response)
response = await client.completions.create(
model=real_completions_response.model,
prompt="Hello, how are you?",
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
assert response.choices[0].text == real_completions_response.choices[0].text
client.completions._post.assert_called_once()
# Replay
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.completions._post = AsyncMock(return_value=real_completions_response)
response = await client.completions.create(
model=real_completions_response.model,
prompt="Hello, how are you?",
temperature=0.7,
max_tokens=50,
)
assert response.choices[0].text == real_completions_response.choices[0].text
client.completions._post.assert_not_called()
async def test_live_mode(self, real_openai_chat_response):
"""Test that live mode passes through to original methods."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.LIVE, storage_dir="foo"):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
)
# Verify the response was returned
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."