feat(tests): make inference_recorder into api_recorder (include tool_invoke) (#3403)

Renames `inference_recorder.py` to `api_recorder.py` and extends it to
support recording/replaying tool invocations in addition to inference
calls.

This allows us to record web-search and other tool calls, and then replay
those recordings for `tests/integration/responses`.

## Test Plan

```
export OPENAI_API_KEY=...
export TAVILY_SEARCH_API_KEY=...

./scripts/integration-tests.sh --stack-config ci-tests \
   --suite responses --inference-mode record-if-missing
```
Ashwin Bharambe 2025-10-09 14:27:51 -07:00 committed by GitHub
parent 26fd5dbd34
commit f50ce11a3b
284 changed files with 296191 additions and 631 deletions

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable
IdFactory = Callable[[], str]
IdOverride = Callable[[str, IdFactory], str]
_id_override: IdOverride | None = None
def generate_object_id(kind: str, factory: IdFactory) -> str:
"""Generate an identifier for the given kind using the provided factory.
Allows tests to override ID generation deterministically by installing an
override callback via :func:`set_id_override`.
"""
override = _id_override
if override is not None:
return override(kind, factory)
return factory()
def set_id_override(override: IdOverride) -> IdOverride | None:
"""Install an override used to generate deterministic identifiers."""
global _id_override
previous = _id_override
_id_override = override
return previous
def reset_id_override(previous: IdOverride | None) -> None:
"""Restore the previous override returned by :func:`set_id_override`."""
global _id_override
_id_override = previous
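
This indirection is what lets the test recorder make object IDs deterministic without touching provider code. A minimal sketch of installing and removing an override; the `_fixed_ids` callback and the pinned ID are hypothetical:

```python
from llama_stack.core.id_generation import generate_object_id, reset_id_override, set_id_override


def _fixed_ids(kind: str, factory) -> str:
    # Hypothetical override: pin file IDs, defer every other kind to the real factory.
    return "file-0001" if kind == "file" else factory()


previous = set_id_override(_fixed_ids)
try:
    assert generate_object_id("file", lambda: "file-unpredictable") == "file-0001"
    assert generate_object_id("vector_store", lambda: "vs_unpredictable") == "vs_unpredictable"
finally:
    reset_id_override(previous)
```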

View file

@ -232,14 +232,25 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
await log_request_pre_validation(request)
test_context_token = None
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user):
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
from llama_stack.core.testing_context import (
TEST_CONTEXT,
reset_test_context,
sync_test_context_from_provider_data,
)
test_context_token = sync_test_context_from_provider_data()
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
gen = preserve_contexts_async_generator(
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR, TEST_CONTEXT]
)
return StreamingResponse(gen, media_type="text/event-stream")
else:
@ -258,6 +269,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
else:
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
raise translate_exception(e) from e
finally:
if test_context_token is not None:
reset_test_context(test_context_token)
sig = inspect.signature(func)

View file

@ -316,13 +316,13 @@ class Stack:
# asked for in the run config.
async def initialize(self):
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording
from llama_stack.testing.api_recorder import setup_api_recording
global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_inference_recording()
TEST_RECORDING_CONTEXT = setup_api_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
@ -381,7 +381,7 @@ class Stack:
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
logger.error(f"Error during API recording cleanup: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:

View file

@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from contextvars import ContextVar
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
TEST_CONTEXT: ContextVar[str | None] = ContextVar("llama_stack_test_context", default=None)
def get_test_context() -> str | None:
return TEST_CONTEXT.get()
def set_test_context(value: str | None):
return TEST_CONTEXT.set(value)
def reset_test_context(token) -> None:
TEST_CONTEXT.reset(token)
def sync_test_context_from_provider_data():
"""Sync test context from provider data when running in server test mode."""
if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
return None
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
if stack_config_type != "server":
return None
try:
provider_data = PROVIDER_DATA_VAR.get()
except LookupError:
provider_data = None
if provider_data and "__test_id" in provider_data:
return TEST_CONTEXT.set(provider_data["__test_id"])
return None
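
`TEST_CONTEXT` is an ordinary `ContextVar`, so the usual token-based set/reset discipline applies. A minimal usage sketch (the test id string is made up):

```python
from llama_stack.core.testing_context import get_test_context, reset_test_context, set_test_context

token = set_test_context("tests/integration/responses/test_basic.py::test_example")
try:
    assert get_test_context() == "tests/integration/responses/test_basic.py::test_example"
finally:
    reset_test_context(token)

# Outside the set/reset window the context is back to its default.
assert get_test_context() is None
```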

View file

@ -108,7 +108,7 @@ class OpenAIResponsesImpl:
# Use stored messages directly and convert only new input
message_adapter = TypeAdapter(list[OpenAIMessageParam])
messages = message_adapter.validate_python(previous_response.messages)
new_messages = await convert_response_input_to_chat_messages(input)
new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
messages.extend(new_messages)
else:
# Backward compatibility: reconstruct from inputs

View file

@ -103,9 +103,13 @@ async def convert_response_content_to_chat_content(
async def convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
previous_messages: list[OpenAIMessageParam] | None = None,
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
:param input: The input to convert
:param previous_messages: Optional previous messages to check for function_call references
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
@ -169,16 +173,53 @@ async def convert_response_input_to_chat_messages(
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
# Skip user messages that duplicate the last user message in previous_messages
# This handles cases where input includes context for function_call_outputs
if previous_messages and input_item.role == "user":
last_user_msg = None
for msg in reversed(previous_messages):
if isinstance(msg, OpenAIUserMessageParam):
last_user_msg = msg
break
if last_user_msg:
last_user_content = getattr(last_user_msg, "content", None)
if last_user_content == content:
continue # Skip duplicate user message
messages.append(message_type(content=content))
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
# Check if unpaired function_call_outputs reference function_calls from previous messages
if previous_messages:
previous_call_ids = _extract_tool_call_ids(previous_messages)
for call_id in list(tool_call_results.keys()):
if call_id in previous_call_ids:
# Valid: this output references a call from previous messages
# Add the tool message
messages.append(tool_call_results[call_id])
del tool_call_results[call_id]
# If still have unpaired outputs, error
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages
def _extract_tool_call_ids(messages: list[OpenAIMessageParam]) -> set[str]:
"""Extract all tool_call IDs from messages."""
call_ids = set()
for msg in messages:
if isinstance(msg, OpenAIAssistantMessageParam):
tool_calls = getattr(msg, "tool_calls", None)
if tool_calls:
for tool_call in tool_calls:
# tool_call is a Pydantic model, use attribute access
call_ids.add(tool_call.id)
return call_ids
async def convert_response_text_to_chat_response_format(
text: OpenAIResponseText,
) -> OpenAIResponseFormatParam:

View file

@ -22,6 +22,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.log import get_logger
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
@ -65,7 +66,7 @@ class LocalfsFilesImpl(Files):
def _generate_file_id(self) -> str:
"""Generate a unique file ID for OpenAI API."""
return f"file-{uuid.uuid4().hex}"
return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
def _get_file_path(self, file_id: str) -> Path:
"""Get the filesystem path for a file ID."""
@ -95,7 +96,9 @@ class LocalfsFilesImpl(Files):
raise RuntimeError("Files provider not initialized")
if expires_after is not None:
raise NotImplementedError("File expiration is not supported by this provider")
logger.warning(
f"File expiration is not supported by this provider, ignoring expires_after: {expires_after}"
)
file_id = self._generate_file_id()
file_path = self._get_file_path(file_id)

View file

@ -23,6 +23,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
@ -198,7 +199,7 @@ class S3FilesImpl(Files):
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
filename = getattr(file, "filename", None) or "uploaded_file"

View file

@ -40,6 +40,7 @@ from llama_stack.apis.vector_io import (
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.core.id_generation import generate_object_id
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import (
@ -352,7 +353,7 @@ class OpenAIVectorStoreMixin(ABC):
"""Creates a vector store."""
created_at = int(time.time())
# Derive the canonical vector_db_id (allow override, else generate)
vector_db_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
if provider_id is None:
raise ValueError("Provider ID is required")
@ -986,7 +987,7 @@ class OpenAIVectorStoreMixin(ABC):
chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
created_at = int(time.time())
batch_id = f"batch_{uuid.uuid4()}"
batch_id = generate_object_id("vector_store_file_batch", lambda: f"batch_{uuid.uuid4()}")
# File batches expire after 7 days
expires_at = created_at + (7 * 24 * 60 * 60)

View file

@ -9,7 +9,8 @@ from __future__ import annotations # for forward references
import hashlib
import json
import os
from collections.abc import Generator
import re
from collections.abc import Callable, Generator
from contextlib import contextmanager
from enum import StrEnum
from pathlib import Path
@ -17,6 +18,7 @@ from typing import Any, Literal, cast
from openai import NOT_GIVEN, OpenAI
from llama_stack.core.id_generation import reset_id_override, set_id_override
from llama_stack.log import get_logger
logger = get_logger(__name__, category="testing")
@ -29,13 +31,14 @@ _current_mode: str | None = None
_current_storage: ResponseStorage | None = None
_original_methods: dict[str, Any] = {}
# Per-test deterministic ID counters (test_id -> id_kind -> counter)
_id_counters: dict[str, dict[str, int]] = {}
# Test context uses ContextVar since it changes per-test and needs async isolation
from contextvars import ContextVar
_test_context: ContextVar[str | None] = ContextVar("_test_context", default=None)
from openai.types.completion_choice import CompletionChoice
from llama_stack.core.testing_context import get_test_context
# update the "finish_reason" field, since its type definition is wrong (no None is accepted)
CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None
CompletionChoice.model_rebuild()
@ -44,14 +47,89 @@ REPO_ROOT = Path(__file__).parent.parent.parent
DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/common"
class InferenceMode(StrEnum):
class APIRecordingMode(StrEnum):
LIVE = "live"
RECORD = "record"
REPLAY = "replay"
RECORD_IF_MISSING = "record-if-missing"
def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
_ID_KIND_PREFIXES: dict[str, str] = {
"file": "file-",
"vector_store": "vs_",
"vector_store_file_batch": "batch_",
"tool_call": "call_",
}
_FLOAT_IN_STRING_PATTERN = re.compile(r"(-?\d+\.\d{4,})")
def _normalize_numeric_literal_strings(value: str) -> str:
"""Round any long decimal literals embedded in strings for stable hashing."""
def _replace(match: re.Match[str]) -> str:
number = float(match.group(0))
return f"{number:.5f}"
return _FLOAT_IN_STRING_PATTERN.sub(_replace, value)
def _normalize_body_for_hash(value: Any) -> Any:
"""Recursively normalize a JSON-like value to improve hash stability."""
if isinstance(value, dict):
return {key: _normalize_body_for_hash(item) for key, item in value.items()}
if isinstance(value, list):
return [_normalize_body_for_hash(item) for item in value]
if isinstance(value, tuple):
return tuple(_normalize_body_for_hash(item) for item in value)
if isinstance(value, float):
return round(value, 5)
if isinstance(value, str):
return _normalize_numeric_literal_strings(value)
return value
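
The rounding exists because providers echo floats (sampling parameters, scores) with slightly different precision between runs, which would otherwise perturb the request hash. A self-contained sketch that mirrors the helpers above rather than importing them; the request bodies are illustrative:

```python
import json
import re

_FLOATS = re.compile(r"(-?\d+\.\d{4,})")


def normalize(value):
    # Same idea as _normalize_body_for_hash: round floats, including ones embedded in strings.
    if isinstance(value, dict):
        return {k: normalize(v) for k, v in value.items()}
    if isinstance(value, list):
        return [normalize(v) for v in value]
    if isinstance(value, float):
        return round(value, 5)
    if isinstance(value, str):
        return _FLOATS.sub(lambda m: f"{float(m.group(0)):.5f}", value)
    return value


a = {"temperature": 0.699999988079071, "prompt": "score=0.123456789"}
b = {"temperature": 0.7, "prompt": "score=0.1234568"}
assert json.dumps(normalize(a), sort_keys=True) == json.dumps(normalize(b), sort_keys=True)
```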
def _allocate_test_scoped_id(kind: str) -> str | None:
"""Return the next deterministic ID for the given kind within the current test."""
global _id_counters
test_id = get_test_context()
prefix = _ID_KIND_PREFIXES.get(kind)
if prefix is None:
return None
if not test_id:
raise ValueError(f"Test ID is required for {kind} ID allocation")
key = test_id
if key not in _id_counters:
_id_counters[key] = {}
# each test should get a contiguous block of IDs otherwise we will get
# collisions between tests inside other systems (like file storage) which
# expect IDs to be unique
test_hash = hashlib.sha256(test_id.encode()).hexdigest()
test_hash_int = int(test_hash, 16)
counter = test_hash_int % 1000000000000
counter = _id_counters[key].get(kind, counter) + 1
_id_counters[key][kind] = counter
return f"{prefix}{counter}"
def _deterministic_id_override(kind: str, factory: Callable[[], str]) -> str:
deterministic_id = _allocate_test_scoped_id(kind)
if deterministic_id is not None:
return deterministic_id
return factory()
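
Concretely, each test derives a counter base from a SHA-256 of its test id, so IDs are stable across reruns of the same test but land in separate blocks for different tests. A standalone sketch of the arithmetic above (test ids are made up):

```python
import hashlib

test_id = "tests/integration/responses/test_basic.py::test_example"
base = int(hashlib.sha256(test_id.encode()).hexdigest(), 16) % 1_000_000_000_000

# First three "file" IDs allocated inside this test, per the counter logic above.
file_ids = [f"file-{base + i}" for i in (1, 2, 3)]
print(file_ids)

# A different test seeds a different block, so its IDs land far away from this one's.
other = int(hashlib.sha256(b"tests/integration/files/test_files.py::test_upload").hexdigest(), 16) % 1_000_000_000_000
assert other != base
```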
def normalize_inference_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
"""Create a normalized hash of the request for consistent matching.
Includes test_id from context to ensure test isolation - identical requests
@ -60,50 +138,39 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
Exception: Model list endpoints (/v1/models, /api/tags) exclude test_id since
they are infrastructure/shared and need to work across session setup and tests.
"""
# Extract just the endpoint path
from urllib.parse import urlparse
parsed = urlparse(url)
body_for_hash = _normalize_body_for_hash(body)
normalized: dict[str, Any] = {
"method": method.upper(),
"endpoint": parsed.path,
"body": body,
"body": body_for_hash,
}
# Include test_id for isolation, except for shared infrastructure endpoints
if parsed.path not in ("/api/tags", "/v1/models"):
normalized["test_id"] = _test_context.get()
normalized["test_id"] = get_test_context()
# Create hash - sort_keys=True ensures deterministic ordering
normalized_json = json.dumps(normalized, sort_keys=True)
return hashlib.sha256(normalized_json.encode()).hexdigest()
def _sync_test_context_from_provider_data():
"""In server mode, sync test ID from provider_data to _test_context.
def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
"""Create a normalized hash of the tool request for consistent matching."""
normalized = {
"provider": provider_name,
"tool_name": tool_name,
"kwargs": kwargs,
}
This ensures that storage operations (which read from _test_context) work correctly
in server mode where the test ID arrives via HTTP header provider_data.
Returns a token to reset _test_context, or None if no sync was needed.
"""
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
if stack_config_type != "server":
return None
try:
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
provider_data = PROVIDER_DATA_VAR.get()
if provider_data and "__test_id" in provider_data:
test_id = provider_data["__test_id"]
return _test_context.set(test_id)
except ImportError:
pass
return None
# Create hash - sort_keys=True ensures deterministic ordering
normalized_json = json.dumps(normalized, sort_keys=True)
return hashlib.sha256(normalized_json.encode()).hexdigest()
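
Both normalizers boil down to the same recipe: build a small dict, `json.dumps(..., sort_keys=True)`, then SHA-256. A standalone sketch of the tool-request variant; the provider, tool name, and kwargs are illustrative:

```python
import hashlib
import json

normalized = {
    "provider": "tavily",
    "tool_name": "web_search",
    "kwargs": {"query": "latest llama stack release"},
}
request_hash = hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()
print(request_hash)  # 64-hex-character key; also used as the recording's file name
```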
def patch_httpx_for_test_id():
@ -127,13 +194,12 @@ def patch_httpx_for_test_id():
def patched_prepare_request(self, request):
# Call original first (it's a sync method that returns None)
# Determine which original to call based on client type
if "llama_stack_client" in self.__class__.__module__:
_original_methods["llama_stack_client_prepare_request"](self, request)
_original_methods["openai_prepare_request"](self, request)
_original_methods["llama_stack_client_prepare_request"](self, request)
_original_methods["openai_prepare_request"](self, request)
# Only inject test ID in server mode
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
test_id = _test_context.get()
test_id = get_test_context()
if stack_config_type == "server" and test_id:
provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
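
In server mode the test id has to cross the HTTP boundary, so the patched `_prepare_request` tucks it into the `X-LlamaStack-Provider-Data` header, and `sync_test_context_from_provider_data` reads it back on the server side. A minimal sketch of the header payload; the test id is made up, and merging with any pre-existing provider data is elided:

```python
import json

test_id = "tests/integration/responses/test_basic.py::test_example"  # hypothetical
provider_data = {"__test_id": test_id}  # in practice merged into whatever provider data the test already set
header_value = json.dumps(provider_data)
# request.headers["X-LlamaStack-Provider-Data"] = header_value
print(header_value)
```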
@ -162,23 +228,22 @@ def unpatch_httpx_for_test_id():
LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"]
del _original_methods["llama_stack_client_prepare_request"]
# Also restore OpenAI client if it was patched
if "openai_prepare_request" in _original_methods:
OpenAI._prepare_request = _original_methods["openai_prepare_request"]
del _original_methods["openai_prepare_request"]
OpenAI._prepare_request = _original_methods["openai_prepare_request"]
del _original_methods["openai_prepare_request"]
def get_inference_mode() -> InferenceMode:
return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
def get_api_recording_mode() -> APIRecordingMode:
return APIRecordingMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
def setup_inference_recording():
def setup_api_recording():
"""
Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
to increase their reliability and reduce reliance on expensive, external services.
Returns a context manager that can be used to record or replay API requests (inference and tools).
This is to be used in tests to increase their reliability and reduce reliance on expensive, external services.
Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
Currently supports:
- Inference: OpenAI and Ollama clients
- Tools: Search providers (Tavily)
Two environment variables are supported:
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'record-if-missing'. Default is 'replay'.
@ -190,15 +255,15 @@ def setup_inference_recording():
The recordings are stored as JSON files.
"""
mode = get_inference_mode()
if mode == InferenceMode.LIVE:
mode = get_api_recording_mode()
if mode == APIRecordingMode.LIVE:
return None
storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR)
return inference_recording(mode=mode, storage_dir=storage_dir)
return api_recording(mode=mode, storage_dir=storage_dir)
def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[str, Any]:
def _normalize_response(data: dict[str, Any], request_hash: str) -> dict[str, Any]:
"""Normalize fields that change between recordings but don't affect functionality.
This reduces noise in git diffs by making IDs deterministic and timestamps constant.
@ -234,7 +299,7 @@ def _serialize_response(response: Any, request_hash: str = "") -> Any:
if hasattr(response, "model_dump"):
data = response.model_dump(mode="json")
# Normalize fields to reduce noise
data = _normalize_response_data(data, request_hash)
data = _normalize_response(data, request_hash)
return {
"__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}",
"__data__": data,
@ -282,7 +347,7 @@ class ResponseStorage:
For test at "tests/integration/inference/test_foo.py::test_bar",
returns "tests/integration/inference/recordings/".
"""
test_id = _test_context.get()
test_id = get_test_context()
if test_id:
# Extract the directory path from the test nodeid
# e.g., "tests/integration/inference/test_basic.py::test_foo[params]"
@ -297,7 +362,7 @@ class ResponseStorage:
# Fallback for non-test contexts
return self.base_dir / "recordings"
def _ensure_directories(self):
def _ensure_directory(self):
"""Ensure test-specific directories exist."""
test_dir = self._get_test_dir()
test_dir.mkdir(parents=True, exist_ok=True)
@ -305,7 +370,7 @@ class ResponseStorage:
def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
"""Store a request/response pair."""
responses_dir = self._ensure_directories()
responses_dir = self._ensure_directory()
# Use FULL hash (not truncated)
response_file = f"{request_hash}.json"
@ -334,9 +399,10 @@ class ResponseStorage:
with open(response_path, "w") as f:
json.dump(
{
"test_id": _test_context.get(),
"test_id": get_test_context(),
"request": request,
"response": serialized_response,
"id_normalization_mapping": {},
},
f,
indent=2,
@ -394,6 +460,14 @@ def _recording_from_file(response_path) -> dict[str, Any]:
with open(response_path) as f:
data = json.load(f)
mapping = data.get("id_normalization_mapping") or {}
if mapping:
serialized = json.dumps(data)
for normalized, original in mapping.items():
serialized = serialized.replace(original, normalized)
data = json.loads(serialized)
data["id_normalization_mapping"] = {}
# Deserialize response body if needed
if "response" in data and "body" in data["response"]:
if isinstance(data["response"]["body"], list):
@ -464,131 +538,168 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
return {"request": canonical_req, "response": {"body": body, "is_streaming": False}}
async def _patched_tool_invoke_method(
original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
):
"""Patched version of tool runtime invoke_tool method for recording/replay."""
global _current_mode, _current_storage
if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
# Normal operation
return await original_method(self, tool_name, kwargs)
request_hash = normalize_tool_request(provider_name, tool_name, kwargs)
if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING):
recording = _current_storage.find_recording(request_hash)
if recording:
return recording["response"]["body"]
elif _current_mode == APIRecordingMode.REPLAY:
raise RuntimeError(
f"No recorded tool result found for {provider_name}.{tool_name}\n"
f"Request: {kwargs}\n"
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)
# If RECORD_IF_MISSING and no recording found, fall through to record
if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
# Make the tool call and record it
result = await original_method(self, tool_name, kwargs)
request_data = {
"test_id": get_test_context(),
"provider": provider_name,
"tool_name": tool_name,
"kwargs": kwargs,
}
response_data = {"body": result, "is_streaming": False}
# Store the recording
_current_storage.store_recording(request_hash, request_data, response_data)
return result
else:
raise AssertionError(f"Invalid mode: {_current_mode}")
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
global _current_mode, _current_storage
mode = _current_mode
storage = _current_storage
if mode == InferenceMode.LIVE or storage is None:
if mode == APIRecordingMode.LIVE or storage is None:
if endpoint == "/v1/models":
return original_method(self, *args, **kwargs)
else:
return await original_method(self, *args, **kwargs)
# In server mode, sync test ID from provider_data to _test_context for storage operations
test_context_token = _sync_test_context_from_provider_data()
# Get base URL based on client type
if client_type == "openai":
base_url = str(self._client.base_url)
try:
# Get base URL based on client type
if client_type == "openai":
base_url = str(self._client.base_url)
# the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
elif client_type == "ollama":
# Get base URL from the client (Ollama client uses host attribute)
base_url = getattr(self, "host", "http://localhost:11434")
if not base_url.startswith("http"):
base_url = f"http://{base_url}"
else:
raise ValueError(f"Unknown client type: {client_type}")
# the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
elif client_type == "ollama":
# Get base URL from the client (Ollama client uses host attribute)
base_url = getattr(self, "host", "http://localhost:11434")
if not base_url.startswith("http"):
base_url = f"http://{base_url}"
url = base_url.rstrip("/") + endpoint
# Special handling for Databricks URLs to avoid leaking workspace info
# e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com
if "cloud.databricks.com" in url:
url = "__databricks__" + url.split("cloud.databricks.com")[-1]
method = "POST"
headers = {}
body = kwargs
request_hash = normalize_inference_request(method, url, headers, body)
# Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
recording = None
if mode == APIRecordingMode.REPLAY or mode == APIRecordingMode.RECORD_IF_MISSING:
# Special handling for model-list endpoints: merge all recordings with this hash
if endpoint in ("/api/tags", "/v1/models"):
records = storage._model_list_responses(request_hash)
recording = _combine_model_list_responses(endpoint, records)
else:
raise ValueError(f"Unknown client type: {client_type}")
recording = storage.find_recording(request_hash)
url = base_url.rstrip("/") + endpoint
# Special handling for Databricks URLs to avoid leaking workspace info
# e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com
if "cloud.databricks.com" in url:
url = "__databricks__" + url.split("cloud.databricks.com")[-1]
method = "POST"
headers = {}
body = kwargs
if recording:
response_body = recording["response"]["body"]
request_hash = normalize_request(method, url, headers, body)
if recording["response"].get("is_streaming", False):
# Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
recording = None
if mode == InferenceMode.REPLAY or mode == InferenceMode.RECORD_IF_MISSING:
# Special handling for model-list endpoints: merge all recordings with this hash
if endpoint in ("/api/tags", "/v1/models"):
records = storage._model_list_responses(request_hash)
recording = _combine_model_list_responses(endpoint, records)
else:
recording = storage.find_recording(request_hash)
if recording:
response_body = recording["response"]["body"]
if recording["response"].get("is_streaming", False):
async def replay_stream():
for chunk in response_body:
yield chunk
return replay_stream()
else:
return response_body
elif mode == InferenceMode.REPLAY:
# REPLAY mode requires recording to exist
raise RuntimeError(
f"No recorded response found for request hash: {request_hash}\n"
f"Request: {method} {url} {body}\n"
f"Model: {body.get('model', 'unknown')}\n"
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)
if mode == InferenceMode.RECORD or (mode == InferenceMode.RECORD_IF_MISSING and not recording):
if endpoint == "/v1/models":
response = original_method(self, *args, **kwargs)
else:
response = await original_method(self, *args, **kwargs)
# we want to store the result of the iterator, not the iterator itself
if endpoint == "/v1/models":
response = [m async for m in response]
request_data = {
"method": method,
"url": url,
"headers": headers,
"body": body,
"endpoint": endpoint,
"model": body.get("model", ""),
}
# Determine if this is a streaming request based on request parameters
is_streaming = body.get("stream", False)
if is_streaming:
# For streaming responses, we need to collect all chunks immediately before yielding
# This ensures the recording is saved even if the generator isn't fully consumed
chunks = []
async for chunk in response:
chunks.append(chunk)
# Store the recording immediately
response_data = {"body": chunks, "is_streaming": True}
storage.store_recording(request_hash, request_data, response_data)
# Return a generator that replays the stored chunks
async def replay_recorded_stream():
for chunk in chunks:
async def replay_stream():
for chunk in response_body:
yield chunk
return replay_recorded_stream()
return replay_stream()
else:
response_data = {"body": response, "is_streaming": False}
storage.store_recording(request_hash, request_data, response_data)
return response
return response_body
elif mode == APIRecordingMode.REPLAY:
# REPLAY mode requires recording to exist
raise RuntimeError(
f"No recorded response found for request hash: {request_hash}\n"
f"Request: {method} {url} {body}\n"
f"Model: {body.get('model', 'unknown')}\n"
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)
if mode == APIRecordingMode.RECORD or (mode == APIRecordingMode.RECORD_IF_MISSING and not recording):
if endpoint == "/v1/models":
response = original_method(self, *args, **kwargs)
else:
raise AssertionError(f"Invalid mode: {mode}")
finally:
if test_context_token:
_test_context.reset(test_context_token)
response = await original_method(self, *args, **kwargs)
# we want to store the result of the iterator, not the iterator itself
if endpoint == "/v1/models":
response = [m async for m in response]
request_data = {
"method": method,
"url": url,
"headers": headers,
"body": body,
"endpoint": endpoint,
"model": body.get("model", ""),
}
# Determine if this is a streaming request based on request parameters
is_streaming = body.get("stream", False)
if is_streaming:
# For streaming responses, we need to collect all chunks immediately before yielding
# This ensures the recording is saved even if the generator isn't fully consumed
chunks: list[Any] = []
async for chunk in response:
chunks.append(chunk)
# Store the recording immediately
response_data = {"body": chunks, "is_streaming": True}
storage.store_recording(request_hash, request_data, response_data)
# Return a generator that replays the stored chunks
async def replay_recorded_stream():
for chunk in chunks:
yield chunk
return replay_recorded_stream()
else:
response_data = {"body": response, "is_streaming": False}
storage.store_recording(request_hash, request_data, response_data)
return response
else:
raise AssertionError(f"Invalid mode: {mode}")
def patch_inference_clients():
"""Install monkey patches for OpenAI client methods and Ollama AsyncClient methods."""
"""Install monkey patches for OpenAI client methods, Ollama AsyncClient methods, and tool runtime methods."""
global _original_methods
from ollama import AsyncClient as OllamaAsyncClient
@ -597,7 +708,9 @@ def patch_inference_clients():
from openai.resources.embeddings import AsyncEmbeddings
from openai.resources.models import AsyncModels
# Store original methods for both OpenAI and Ollama clients
from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
# Store original methods for OpenAI, Ollama clients, and tool runtimes
_original_methods = {
"chat_completions_create": AsyncChatCompletions.create,
"completions_create": AsyncCompletions.create,
@ -609,6 +722,7 @@ def patch_inference_clients():
"ollama_ps": OllamaAsyncClient.ps,
"ollama_pull": OllamaAsyncClient.pull,
"ollama_list": OllamaAsyncClient.list,
"tavily_invoke_tool": TavilySearchToolRuntimeImpl.invoke_tool,
}
# Create patched methods for OpenAI client
@ -681,9 +795,18 @@ def patch_inference_clients():
OllamaAsyncClient.pull = patched_ollama_pull
OllamaAsyncClient.list = patched_ollama_list
# Create patched methods for tool runtimes
async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any]):
return await _patched_tool_invoke_method(
_original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs
)
# Apply tool runtime patches
TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool
def unpatch_inference_clients():
"""Remove monkey patches and restore original OpenAI and Ollama client methods."""
"""Remove monkey patches and restore original OpenAI, Ollama client, and tool runtime methods."""
global _original_methods
if not _original_methods:
@ -696,6 +819,8 @@ def unpatch_inference_clients():
from openai.resources.embeddings import AsyncEmbeddings
from openai.resources.models import AsyncModels
from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
# Restore OpenAI client methods
AsyncChatCompletions.create = _original_methods["chat_completions_create"]
AsyncCompletions.create = _original_methods["completions_create"]
@ -710,17 +835,21 @@ def unpatch_inference_clients():
OllamaAsyncClient.pull = _original_methods["ollama_pull"]
OllamaAsyncClient.list = _original_methods["ollama_list"]
# Restore tool runtime methods
TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"]
_original_methods.clear()
@contextmanager
def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
"""Context manager for inference recording/replaying."""
def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
"""Context manager for API recording/replaying (inference and tools)."""
global _current_mode, _current_storage
# Store previous state
prev_mode = _current_mode
prev_storage = _current_storage
previous_override = None
try:
_current_mode = mode
@ -729,7 +858,9 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
if storage_dir is None:
raise ValueError("storage_dir is required for record, replay, and record-if-missing modes")
_current_storage = ResponseStorage(Path(storage_dir))
_id_counters.clear()
patch_inference_clients()
previous_override = set_id_override(_deterministic_id_override)
yield
@ -737,6 +868,7 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
# Restore previous state
if mode in ["record", "replay", "record-if-missing"]:
unpatch_inference_clients()
reset_id_override(previous_override)
_current_mode = prev_mode
_current_storage = prev_storage
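
Finally, a sketch of driving the recorder directly, mirroring what `Stack.initialize` does when the environment variable is set; normally the integration-test script handles this:

```python
import os

from llama_stack.testing.api_recorder import setup_api_recording

# Assumes recordings already exist under the default storage dir; "replay" never hits the network.
os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"

ctx = setup_api_recording()  # returns None in "live" mode
if ctx is not None:
    with ctx:
        ...  # exercise code that calls OpenAI/Ollama inference or Tavily search
```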