mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 13:44:38 +00:00
feat(tests): make inference_recorder into api_recorder (include tool_invoke) (#3403)
Renames `inference_recorder.py` to `api_recorder.py` and extends it to support recording/replaying tool invocations in addition to inference calls. This allows us to record tool calls (web search, etc.) and then replay those recordings in `tests/integration/responses`.

## Test Plan

```
export OPENAI_API_KEY=...
export TAVILY_SEARCH_API_KEY=...
./scripts/integration-tests.sh --stack-config ci-tests \
  --suite responses --inference-mode record-if-missing
```
This commit is contained in:
parent
26fd5dbd34
commit
f50ce11a3b
284 changed files with 296191 additions and 631 deletions
42
llama_stack/core/id_generation.py
Normal file
42
llama_stack/core/id_generation.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
IdFactory = Callable[[], str]
|
||||
IdOverride = Callable[[str, IdFactory], str]
|
||||
|
||||
_id_override: IdOverride | None = None
|
||||
|
||||
|
||||
def generate_object_id(kind: str, factory: IdFactory) -> str:
|
||||
"""Generate an identifier for the given kind using the provided factory.
|
||||
|
||||
Allows tests to override ID generation deterministically by installing an
|
||||
override callback via :func:`set_id_override`.
|
||||
"""
|
||||
|
||||
override = _id_override
|
||||
if override is not None:
|
||||
return override(kind, factory)
|
||||
return factory()
|
||||
|
||||
|
||||
def set_id_override(override: IdOverride) -> IdOverride | None:
|
||||
"""Install an override used to generate deterministic identifiers."""
|
||||
|
||||
global _id_override
|
||||
|
||||
previous = _id_override
|
||||
_id_override = override
|
||||
return previous
|
||||
|
||||
|
||||
def reset_id_override(previous: IdOverride | None) -> None:
|
||||
"""Restore the previous override returned by :func:`set_id_override`."""
|
||||
|
||||
global _id_override
|
||||
_id_override = previous
|
|
@ -232,14 +232,25 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
test_context_token = None
|
||||
|
||||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user):
|
||||
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
|
||||
from llama_stack.core.testing_context import (
|
||||
TEST_CONTEXT,
|
||||
reset_test_context,
|
||||
sync_test_context_from_provider_data,
|
||||
)
|
||||
|
||||
test_context_token = sync_test_context_from_provider_data()
|
||||
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
try:
|
||||
if is_streaming:
|
||||
gen = preserve_contexts_async_generator(
|
||||
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR, TEST_CONTEXT]
|
||||
)
|
||||
return StreamingResponse(gen, media_type="text/event-stream")
|
||||
else:
|
||||
|
@ -258,6 +269,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
else:
|
||||
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
if test_context_token is not None:
|
||||
reset_test_context(test_context_token)
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
|
|
|
@ -316,13 +316,13 @@ class Stack:
|
|||
# asked for in the run config.
|
||||
async def initialize(self):
|
||||
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
|
||||
from llama_stack.testing.inference_recorder import setup_inference_recording
|
||||
from llama_stack.testing.api_recorder import setup_api_recording
|
||||
|
||||
global TEST_RECORDING_CONTEXT
|
||||
TEST_RECORDING_CONTEXT = setup_inference_recording()
|
||||
TEST_RECORDING_CONTEXT = setup_api_recording()
|
||||
if TEST_RECORDING_CONTEXT:
|
||||
TEST_RECORDING_CONTEXT.__enter__()
|
||||
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
|
||||
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
|
||||
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
|
||||
|
@ -381,7 +381,7 @@ class Stack:
|
|||
try:
|
||||
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference recording cleanup: {e}")
|
||||
logger.error(f"Error during API recording cleanup: {e}")
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
if REGISTRY_REFRESH_TASK:
|
||||
|
|
44
llama_stack/core/testing_context.py
Normal file
44
llama_stack/core/testing_context.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
|
||||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
|
||||
|
||||
TEST_CONTEXT: ContextVar[str | None] = ContextVar("llama_stack_test_context", default=None)
|
||||
|
||||
|
||||
def get_test_context() -> str | None:
|
||||
return TEST_CONTEXT.get()
|
||||
|
||||
|
||||
def set_test_context(value: str | None):
|
||||
return TEST_CONTEXT.set(value)
|
||||
|
||||
|
||||
def reset_test_context(token) -> None:
|
||||
TEST_CONTEXT.reset(token)
|
||||
|
||||
|
||||
def sync_test_context_from_provider_data():
|
||||
"""Sync test context from provider data when running in server test mode."""
|
||||
if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
|
||||
return None
|
||||
|
||||
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
|
||||
if stack_config_type != "server":
|
||||
return None
|
||||
|
||||
try:
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
except LookupError:
|
||||
provider_data = None
|
||||
|
||||
if provider_data and "__test_id" in provider_data:
|
||||
return TEST_CONTEXT.set(provider_data["__test_id"])
|
||||
|
||||
return None
|
|
@ -108,7 +108,7 @@ class OpenAIResponsesImpl:
|
|||
# Use stored messages directly and convert only new input
|
||||
message_adapter = TypeAdapter(list[OpenAIMessageParam])
|
||||
messages = message_adapter.validate_python(previous_response.messages)
|
||||
new_messages = await convert_response_input_to_chat_messages(input)
|
||||
new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
|
||||
messages.extend(new_messages)
|
||||
else:
|
||||
# Backward compatibility: reconstruct from inputs
|
||||
|
|
|
@ -103,9 +103,13 @@ async def convert_response_content_to_chat_content(
|
|||
|
||||
async def convert_response_input_to_chat_messages(
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_messages: list[OpenAIMessageParam] | None = None,
|
||||
) -> list[OpenAIMessageParam]:
|
||||
"""
|
||||
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
|
||||
|
||||
:param input: The input to convert
|
||||
:param previous_messages: Optional previous messages to check for function_call references
|
||||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if isinstance(input, list):
|
||||
|
@ -169,16 +173,53 @@ async def convert_response_input_to_chat_messages(
|
|||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
# Skip user messages that duplicate the last user message in previous_messages
|
||||
# This handles cases where input includes context for function_call_outputs
|
||||
if previous_messages and input_item.role == "user":
|
||||
last_user_msg = None
|
||||
for msg in reversed(previous_messages):
|
||||
if isinstance(msg, OpenAIUserMessageParam):
|
||||
last_user_msg = msg
|
||||
break
|
||||
if last_user_msg:
|
||||
last_user_content = getattr(last_user_msg, "content", None)
|
||||
if last_user_content == content:
|
||||
continue # Skip duplicate user message
|
||||
messages.append(message_type(content=content))
|
||||
if len(tool_call_results):
|
||||
raise ValueError(
|
||||
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
|
||||
)
|
||||
# Check if unpaired function_call_outputs reference function_calls from previous messages
|
||||
if previous_messages:
|
||||
previous_call_ids = _extract_tool_call_ids(previous_messages)
|
||||
for call_id in list(tool_call_results.keys()):
|
||||
if call_id in previous_call_ids:
|
||||
# Valid: this output references a call from previous messages
|
||||
# Add the tool message
|
||||
messages.append(tool_call_results[call_id])
|
||||
del tool_call_results[call_id]
|
||||
|
||||
# If still have unpaired outputs, error
|
||||
if len(tool_call_results):
|
||||
raise ValueError(
|
||||
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
|
||||
)
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
||||
|
||||
def _extract_tool_call_ids(messages: list[OpenAIMessageParam]) -> set[str]:
    """Collect the ``id`` of every tool call attached to an assistant message.

    Messages that are not ``OpenAIAssistantMessageParam`` instances, and
    assistant messages without tool calls, contribute nothing.
    """
    found: set[str] = set()
    for message in messages:
        if not isinstance(message, OpenAIAssistantMessageParam):
            continue
        # tool_calls may be missing or empty; both cases yield no IDs.
        for call in getattr(message, "tool_calls", None) or ():
            # Tool calls are accessed by attribute rather than by key.
            found.add(call.id)
    return found
|
||||
|
||||
|
||||
async def convert_response_text_to_chat_response_format(
|
||||
text: OpenAIResponseText,
|
||||
) -> OpenAIResponseFormatParam:
|
||||
|
|
|
@ -22,6 +22,7 @@ from llama_stack.apis.files import (
|
|||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
|
@ -65,7 +66,7 @@ class LocalfsFilesImpl(Files):
|
|||
|
||||
def _generate_file_id(self) -> str:
|
||||
"""Generate a unique file ID for OpenAI API."""
|
||||
return f"file-{uuid.uuid4().hex}"
|
||||
return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
|
||||
|
||||
def _get_file_path(self, file_id: str) -> Path:
|
||||
"""Get the filesystem path for a file ID."""
|
||||
|
@ -95,7 +96,9 @@ class LocalfsFilesImpl(Files):
|
|||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
if expires_after is not None:
|
||||
raise NotImplementedError("File expiration is not supported by this provider")
|
||||
logger.warning(
|
||||
f"File expiration is not supported by this provider, ignoring expires_after: {expires_after}"
|
||||
)
|
||||
|
||||
file_id = self._generate_file_id()
|
||||
file_path = self._get_file_path(file_id)
|
||||
|
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.files import (
|
|||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
|
@ -198,7 +199,7 @@ class S3FilesImpl(Files):
|
|||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
|
||||
) -> OpenAIFileObject:
|
||||
file_id = f"file-{uuid.uuid4().hex}"
|
||||
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
|
||||
|
||||
filename = getattr(file, "filename", None) or "uploaded_file"
|
||||
|
||||
|
|
|
@ -40,6 +40,7 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
|
@ -352,7 +353,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
"""Creates a vector store."""
|
||||
created_at = int(time.time())
|
||||
# Derive the canonical vector_db_id (allow override, else generate)
|
||||
vector_db_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
|
||||
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
|
||||
|
||||
if provider_id is None:
|
||||
raise ValueError("Provider ID is required")
|
||||
|
@ -986,7 +987,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
|
||||
|
||||
created_at = int(time.time())
|
||||
batch_id = f"batch_{uuid.uuid4()}"
|
||||
batch_id = generate_object_id("vector_store_file_batch", lambda: f"batch_{uuid.uuid4()}")
|
||||
# File batches expire after 7 days
|
||||
expires_at = created_at + (7 * 24 * 60 * 60)
|
||||
|
||||
|
|
|
@ -9,7 +9,8 @@ from __future__ import annotations # for forward references
|
|||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
import re
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
@ -17,6 +18,7 @@ from typing import Any, Literal, cast
|
|||
|
||||
from openai import NOT_GIVEN, OpenAI
|
||||
|
||||
from llama_stack.core.id_generation import reset_id_override, set_id_override
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="testing")
|
||||
|
@ -29,13 +31,14 @@ _current_mode: str | None = None
|
|||
_current_storage: ResponseStorage | None = None
|
||||
_original_methods: dict[str, Any] = {}
|
||||
|
||||
# Per-test deterministic ID counters (test_id -> id_kind -> counter)
|
||||
_id_counters: dict[str, dict[str, int]] = {}
|
||||
|
||||
# Test context uses ContextVar since it changes per-test and needs async isolation
|
||||
from contextvars import ContextVar
|
||||
|
||||
_test_context: ContextVar[str | None] = ContextVar("_test_context", default=None)
|
||||
|
||||
from openai.types.completion_choice import CompletionChoice
|
||||
|
||||
from llama_stack.core.testing_context import get_test_context
|
||||
|
||||
# update the "finish_reason" field, since its type definition is wrong (no None is accepted)
|
||||
CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None
|
||||
CompletionChoice.model_rebuild()
|
||||
|
@ -44,14 +47,89 @@ REPO_ROOT = Path(__file__).parent.parent.parent
|
|||
DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/common"
|
||||
|
||||
|
||||
class InferenceMode(StrEnum):
|
||||
class APIRecordingMode(StrEnum):
|
||||
LIVE = "live"
|
||||
RECORD = "record"
|
||||
REPLAY = "replay"
|
||||
RECORD_IF_MISSING = "record-if-missing"
|
||||
|
||||
|
||||
def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
|
||||
_ID_KIND_PREFIXES: dict[str, str] = {
|
||||
"file": "file-",
|
||||
"vector_store": "vs_",
|
||||
"vector_store_file_batch": "batch_",
|
||||
"tool_call": "call_",
|
||||
}
|
||||
|
||||
|
||||
# Matches an optionally negative decimal literal with at least four fractional digits.
_FLOAT_IN_STRING_PATTERN = re.compile(r"(-?\d+\.\d{4,})")


def _normalize_numeric_literal_strings(value: str) -> str:
    """Rewrite long decimal literals inside ``value`` with five fractional digits.

    Every substring matched by ``_FLOAT_IN_STRING_PATTERN`` is parsed as a
    float and re-rendered with exactly five digits after the decimal point, so
    small differences in trailing digits do not change the request hash. Text
    outside the matches is left untouched.
    """
    return _FLOAT_IN_STRING_PATTERN.sub(
        lambda found: f"{float(found.group(0)):.5f}",
        value,
    )
|
||||
|
||||
|
||||
def _normalize_body_for_hash(value: Any) -> Any:
|
||||
"""Recursively normalize a JSON-like value to improve hash stability."""
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {key: _normalize_body_for_hash(item) for key, item in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_normalize_body_for_hash(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_normalize_body_for_hash(item) for item in value)
|
||||
if isinstance(value, float):
|
||||
return round(value, 5)
|
||||
if isinstance(value, str):
|
||||
return _normalize_numeric_literal_strings(value)
|
||||
return value
|
||||
|
||||
|
||||
def _allocate_test_scoped_id(kind: str) -> str | None:
    """Hand out the next deterministic ID of ``kind`` for the current test.

    Returns ``None`` when ``kind`` has no entry in ``_ID_KIND_PREFIXES``.
    Otherwise returns the kind's prefix followed by a counter that increases
    by one on each call for the same test and kind.

    Raises:
        ValueError: if ``kind`` is known but no test ID is set in the context.
    """
    test_id = get_test_context()
    prefix = _ID_KIND_PREFIXES.get(kind)

    if prefix is None:
        return None

    if not test_id:
        raise ValueError(f"Test ID is required for {kind} ID allocation")

    per_test_counters = _id_counters.setdefault(test_id, {})

    # Each test starts counting from an offset derived from a hash of its ID,
    # so that every test gets its own contiguous block of IDs. Otherwise IDs
    # from different tests would collide inside other systems (like file
    # storage) which expect IDs to be unique.
    digest = hashlib.sha256(test_id.encode()).hexdigest()
    starting_point = int(digest, 16) % 1000000000000

    next_value = per_test_counters.get(kind, starting_point) + 1
    per_test_counters[kind] = next_value

    return f"{prefix}{next_value}"
|
||||
|
||||
|
||||
def _deterministic_id_override(kind: str, factory: Callable[[], str]) -> str:
    """ID hook that prefers test-scoped deterministic IDs.

    Uses the ID from ``_allocate_test_scoped_id`` when it provides one for
    ``kind``; when it returns ``None``, falls back to calling ``factory``.
    """
    allocated = _allocate_test_scoped_id(kind)
    return factory() if allocated is None else allocated
|
||||
|
||||
|
||||
def normalize_inference_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
|
||||
"""Create a normalized hash of the request for consistent matching.
|
||||
|
||||
Includes test_id from context to ensure test isolation - identical requests
|
||||
|
@ -60,50 +138,39 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
|
|||
Exception: Model list endpoints (/v1/models, /api/tags) exclude test_id since
|
||||
they are infrastructure/shared and need to work across session setup and tests.
|
||||
"""
|
||||
|
||||
# Extract just the endpoint path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(url)
|
||||
|
||||
body_for_hash = _normalize_body_for_hash(body)
|
||||
|
||||
normalized: dict[str, Any] = {
|
||||
"method": method.upper(),
|
||||
"endpoint": parsed.path,
|
||||
"body": body,
|
||||
"body": body_for_hash,
|
||||
}
|
||||
|
||||
# Include test_id for isolation, except for shared infrastructure endpoints
|
||||
if parsed.path not in ("/api/tags", "/v1/models"):
|
||||
normalized["test_id"] = _test_context.get()
|
||||
normalized["test_id"] = get_test_context()
|
||||
|
||||
# Create hash - sort_keys=True ensures deterministic ordering
|
||||
normalized_json = json.dumps(normalized, sort_keys=True)
|
||||
return hashlib.sha256(normalized_json.encode()).hexdigest()
|
||||
|
||||
|
||||
def _sync_test_context_from_provider_data():
|
||||
"""In server mode, sync test ID from provider_data to _test_context.
|
||||
def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
|
||||
"""Create a normalized hash of the tool request for consistent matching."""
|
||||
normalized = {
|
||||
"provider": provider_name,
|
||||
"tool_name": tool_name,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
This ensures that storage operations (which read from _test_context) work correctly
|
||||
in server mode where the test ID arrives via HTTP header → provider_data.
|
||||
|
||||
Returns a token to reset _test_context, or None if no sync was needed.
|
||||
"""
|
||||
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
|
||||
|
||||
if stack_config_type != "server":
|
||||
return None
|
||||
|
||||
try:
|
||||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
|
||||
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
|
||||
if provider_data and "__test_id" in provider_data:
|
||||
test_id = provider_data["__test_id"]
|
||||
return _test_context.set(test_id)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return None
|
||||
# Create hash - sort_keys=True ensures deterministic ordering
|
||||
normalized_json = json.dumps(normalized, sort_keys=True)
|
||||
return hashlib.sha256(normalized_json.encode()).hexdigest()
|
||||
|
||||
|
||||
def patch_httpx_for_test_id():
|
||||
|
@ -127,13 +194,12 @@ def patch_httpx_for_test_id():
|
|||
def patched_prepare_request(self, request):
|
||||
# Call original first (it's a sync method that returns None)
|
||||
# Determine which original to call based on client type
|
||||
if "llama_stack_client" in self.__class__.__module__:
|
||||
_original_methods["llama_stack_client_prepare_request"](self, request)
|
||||
_original_methods["openai_prepare_request"](self, request)
|
||||
_original_methods["llama_stack_client_prepare_request"](self, request)
|
||||
_original_methods["openai_prepare_request"](self, request)
|
||||
|
||||
# Only inject test ID in server mode
|
||||
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
|
||||
test_id = _test_context.get()
|
||||
test_id = get_test_context()
|
||||
|
||||
if stack_config_type == "server" and test_id:
|
||||
provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
|
||||
|
@ -162,23 +228,22 @@ def unpatch_httpx_for_test_id():
|
|||
|
||||
LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"]
|
||||
del _original_methods["llama_stack_client_prepare_request"]
|
||||
|
||||
# Also restore OpenAI client if it was patched
|
||||
if "openai_prepare_request" in _original_methods:
|
||||
OpenAI._prepare_request = _original_methods["openai_prepare_request"]
|
||||
del _original_methods["openai_prepare_request"]
|
||||
OpenAI._prepare_request = _original_methods["openai_prepare_request"]
|
||||
del _original_methods["openai_prepare_request"]
|
||||
|
||||
|
||||
def get_inference_mode() -> InferenceMode:
|
||||
return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
|
||||
def get_api_recording_mode() -> APIRecordingMode:
|
||||
return APIRecordingMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
|
||||
|
||||
|
||||
def setup_inference_recording():
|
||||
def setup_api_recording():
|
||||
"""
|
||||
Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
|
||||
to increase their reliability and reduce reliance on expensive, external services.
|
||||
Returns a context manager that can be used to record or replay API requests (inference and tools).
|
||||
This is to be used in tests to increase their reliability and reduce reliance on expensive, external services.
|
||||
|
||||
Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
|
||||
Currently supports:
|
||||
- Inference: OpenAI and Ollama clients
|
||||
- Tools: Search providers (Tavily)
|
||||
|
||||
Two environment variables are supported:
|
||||
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'record-if-missing'. Default is 'replay'.
|
||||
|
@ -190,15 +255,15 @@ def setup_inference_recording():
|
|||
|
||||
The recordings are stored as JSON files.
|
||||
"""
|
||||
mode = get_inference_mode()
|
||||
if mode == InferenceMode.LIVE:
|
||||
mode = get_api_recording_mode()
|
||||
if mode == APIRecordingMode.LIVE:
|
||||
return None
|
||||
|
||||
storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR)
|
||||
return inference_recording(mode=mode, storage_dir=storage_dir)
|
||||
return api_recording(mode=mode, storage_dir=storage_dir)
|
||||
|
||||
|
||||
def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[str, Any]:
|
||||
def _normalize_response(data: dict[str, Any], request_hash: str) -> dict[str, Any]:
|
||||
"""Normalize fields that change between recordings but don't affect functionality.
|
||||
|
||||
This reduces noise in git diffs by making IDs deterministic and timestamps constant.
|
||||
|
@ -234,7 +299,7 @@ def _serialize_response(response: Any, request_hash: str = "") -> Any:
|
|||
if hasattr(response, "model_dump"):
|
||||
data = response.model_dump(mode="json")
|
||||
# Normalize fields to reduce noise
|
||||
data = _normalize_response_data(data, request_hash)
|
||||
data = _normalize_response(data, request_hash)
|
||||
return {
|
||||
"__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}",
|
||||
"__data__": data,
|
||||
|
@ -282,7 +347,7 @@ class ResponseStorage:
|
|||
For test at "tests/integration/inference/test_foo.py::test_bar",
|
||||
returns "tests/integration/inference/recordings/".
|
||||
"""
|
||||
test_id = _test_context.get()
|
||||
test_id = get_test_context()
|
||||
if test_id:
|
||||
# Extract the directory path from the test nodeid
|
||||
# e.g., "tests/integration/inference/test_basic.py::test_foo[params]"
|
||||
|
@ -297,7 +362,7 @@ class ResponseStorage:
|
|||
# Fallback for non-test contexts
|
||||
return self.base_dir / "recordings"
|
||||
|
||||
def _ensure_directories(self):
|
||||
def _ensure_directory(self):
|
||||
"""Ensure test-specific directories exist."""
|
||||
test_dir = self._get_test_dir()
|
||||
test_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -305,7 +370,7 @@ class ResponseStorage:
|
|||
|
||||
def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
|
||||
"""Store a request/response pair."""
|
||||
responses_dir = self._ensure_directories()
|
||||
responses_dir = self._ensure_directory()
|
||||
|
||||
# Use FULL hash (not truncated)
|
||||
response_file = f"{request_hash}.json"
|
||||
|
@ -334,9 +399,10 @@ class ResponseStorage:
|
|||
with open(response_path, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"test_id": _test_context.get(),
|
||||
"test_id": get_test_context(),
|
||||
"request": request,
|
||||
"response": serialized_response,
|
||||
"id_normalization_mapping": {},
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
|
@ -394,6 +460,14 @@ def _recording_from_file(response_path) -> dict[str, Any]:
|
|||
with open(response_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
mapping = data.get("id_normalization_mapping") or {}
|
||||
if mapping:
|
||||
serialized = json.dumps(data)
|
||||
for normalized, original in mapping.items():
|
||||
serialized = serialized.replace(original, normalized)
|
||||
data = json.loads(serialized)
|
||||
data["id_normalization_mapping"] = {}
|
||||
|
||||
# Deserialize response body if needed
|
||||
if "response" in data and "body" in data["response"]:
|
||||
if isinstance(data["response"]["body"], list):
|
||||
|
@ -464,131 +538,168 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
|
|||
return {"request": canonical_req, "response": {"body": body, "is_streaming": False}}
|
||||
|
||||
|
||||
async def _patched_tool_invoke_method(
    original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
):
    """Patched version of tool runtime invoke_tool method for recording/replay.

    Behavior by recording mode:

    - LIVE, or no storage configured: call ``original_method`` and return its result.
    - REPLAY: return the stored response body for this request; raise if none exists.
    - RECORD_IF_MISSING: return the stored response body if one exists, otherwise
      invoke the tool, store the result and return it.
    - RECORD: always invoke the tool, store the result and return it.

    :param original_method: The unpatched ``invoke_tool`` implementation.
    :param provider_name: Name of the tool provider; part of the request hash.
    :param self: The tool runtime instance the method was called on.
    :param tool_name: Name of the tool being invoked.
    :param kwargs: Tool arguments, passed through to ``original_method``.
    :raises RuntimeError: In REPLAY mode when no recording matches the request.
    :raises AssertionError: If the current mode is not a known recording mode.
    """
    # Snapshot the module-level recorder state once (as _patched_inference_method
    # does). The globals may be changed while the tool call below is awaited, and
    # the recording must be stored with the state that was active when the call
    # started rather than dereferencing a storage that has since been cleared.
    mode = _current_mode
    storage = _current_storage

    if mode == APIRecordingMode.LIVE or storage is None:
        # Normal operation
        return await original_method(self, tool_name, kwargs)

    request_hash = normalize_tool_request(provider_name, tool_name, kwargs)

    if mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING):
        recording = storage.find_recording(request_hash)
        if recording:
            return recording["response"]["body"]
        if mode == APIRecordingMode.REPLAY:
            raise RuntimeError(
                f"No recorded tool result found for {provider_name}.{tool_name}\n"
                f"Request: {kwargs}\n"
                f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
            )
        # RECORD_IF_MISSING with no recording found: fall through to record

    if mode not in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
        raise AssertionError(f"Invalid mode: {mode}")

    # Make the tool call and record it
    result = await original_method(self, tool_name, kwargs)

    request_data = {
        "test_id": get_test_context(),
        "provider": provider_name,
        "tool_name": tool_name,
        "kwargs": kwargs,
    }
    response_data = {"body": result, "is_streaming": False}

    # Store the recording
    storage.store_recording(request_hash, request_data, response_data)
    return result
|
||||
|
||||
|
||||
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
    """Record or replay a patched OpenAI / Ollama client call.

    Args:
        original_method: The unpatched client method.
        self: The client (or client resource) instance the call was made on.
        client_type: ``"openai"`` or ``"ollama"``; selects how the base URL
            is derived from ``self``.
        endpoint: API path of the call (e.g. ``"/v1/models"``); appended to
            the base URL to build the recorded request URL.
        *args: Positional arguments forwarded to ``original_method``.
        **kwargs: Keyword arguments forwarded to ``original_method``; they
            also serve as the request body used for hashing and recording.

    Returns:
        The live response, the stored response body, or (for streaming
        responses) an async generator yielding the stored chunks.

    Raises:
        ValueError: If ``client_type`` is not recognised.
        RuntimeError: In REPLAY mode when no recording matches the request.
        AssertionError: If the active mode is not a known recording mode.
    """
    # Snapshot module state so a mode/storage change during an await cannot
    # produce inconsistent decisions within a single call.
    mode = _current_mode
    storage = _current_storage

    if mode == APIRecordingMode.LIVE or storage is None:
        # The "/v1/models" method is called without await: its result is
        # consumed as an async iterator (see the record branch below).
        if endpoint == "/v1/models":
            return original_method(self, *args, **kwargs)
        return await original_method(self, *args, **kwargs)

    # In server mode, sync test ID from provider_data to _test_context for storage operations
    test_context_token = _sync_test_context_from_provider_data()

    try:
        # Get base URL based on client type
        if client_type == "openai":
            base_url = str(self._client.base_url)

            # the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
            kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
        elif client_type == "ollama":
            # Get base URL from the client (Ollama client uses host attribute)
            base_url = getattr(self, "host", "http://localhost:11434")
            if not base_url.startswith("http"):
                base_url = f"http://{base_url}"
        else:
            raise ValueError(f"Unknown client type: {client_type}")

        url = base_url.rstrip("/") + endpoint
        # Special handling for Databricks URLs to avoid leaking workspace info:
        # everything up to and including "cloud.databricks.com" is replaced by
        # a fixed "__databricks__" prefix, keeping only the trailing path.
        if "cloud.databricks.com" in url:
            url = "__databricks__" + url.split("cloud.databricks.com")[-1]
        method = "POST"
        headers: dict[str, Any] = {}
        body = kwargs

        request_hash = normalize_inference_request(method, url, headers, body)

        # Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
        recording = None
        if mode == APIRecordingMode.REPLAY or mode == APIRecordingMode.RECORD_IF_MISSING:
            # Special handling for model-list endpoints: merge all recordings with this hash
            if endpoint in ("/api/tags", "/v1/models"):
                records = storage._model_list_responses(request_hash)
                recording = _combine_model_list_responses(endpoint, records)
            else:
                recording = storage.find_recording(request_hash)

            if recording:
                response_body = recording["response"]["body"]

                if recording["response"].get("is_streaming", False):

                    async def replay_stream():
                        for chunk in response_body:
                            yield chunk

                    return replay_stream()
                return response_body

            if mode == APIRecordingMode.REPLAY:
                # REPLAY mode requires recording to exist
                raise RuntimeError(
                    f"No recorded response found for request hash: {request_hash}\n"
                    f"Request: {method} {url} {body}\n"
                    f"Model: {body.get('model', 'unknown')}\n"
                    "To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
                )

        if mode == APIRecordingMode.RECORD or (mode == APIRecordingMode.RECORD_IF_MISSING and not recording):
            if endpoint == "/v1/models":
                response = original_method(self, *args, **kwargs)
            else:
                response = await original_method(self, *args, **kwargs)

            # we want to store the result of the iterator, not the iterator itself
            if endpoint == "/v1/models":
                response = [m async for m in response]

            request_data = {
                "method": method,
                "url": url,
                "headers": headers,
                "body": body,
                "endpoint": endpoint,
                "model": body.get("model", ""),
            }

            # Determine if this is a streaming request based on request parameters
            is_streaming = body.get("stream", False)

            if is_streaming:
                # For streaming responses, we need to collect all chunks immediately before yielding
                # This ensures the recording is saved even if the generator isn't fully consumed
                chunks: list[Any] = []
                async for chunk in response:
                    chunks.append(chunk)

                # Store the recording immediately
                response_data = {"body": chunks, "is_streaming": True}
                storage.store_recording(request_hash, request_data, response_data)

                # Return a generator that replays the stored chunks
                async def replay_recorded_stream():
                    for chunk in chunks:
                        yield chunk

                return replay_recorded_stream()

            response_data = {"body": response, "is_streaming": False}
            storage.store_recording(request_hash, request_data, response_data)
            return response

        raise AssertionError(f"Invalid mode: {mode}")
    finally:
        # Undo the test-context sync performed above, on every exit path.
        if test_context_token:
            _test_context.reset(test_context_token)
|
||||
|
||||
|
||||
def patch_inference_clients():
|
||||
"""Install monkey patches for OpenAI client methods and Ollama AsyncClient methods."""
|
||||
"""Install monkey patches for OpenAI client methods, Ollama AsyncClient methods, and tool runtime methods."""
|
||||
global _original_methods
|
||||
|
||||
from ollama import AsyncClient as OllamaAsyncClient
|
||||
|
@ -597,7 +708,9 @@ def patch_inference_clients():
|
|||
from openai.resources.embeddings import AsyncEmbeddings
|
||||
from openai.resources.models import AsyncModels
|
||||
|
||||
# Store original methods for both OpenAI and Ollama clients
|
||||
from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
|
||||
|
||||
# Store original methods for OpenAI, Ollama clients, and tool runtimes
|
||||
_original_methods = {
|
||||
"chat_completions_create": AsyncChatCompletions.create,
|
||||
"completions_create": AsyncCompletions.create,
|
||||
|
@ -609,6 +722,7 @@ def patch_inference_clients():
|
|||
"ollama_ps": OllamaAsyncClient.ps,
|
||||
"ollama_pull": OllamaAsyncClient.pull,
|
||||
"ollama_list": OllamaAsyncClient.list,
|
||||
"tavily_invoke_tool": TavilySearchToolRuntimeImpl.invoke_tool,
|
||||
}
|
||||
|
||||
# Create patched methods for OpenAI client
|
||||
|
@ -681,9 +795,18 @@ def patch_inference_clients():
|
|||
OllamaAsyncClient.pull = patched_ollama_pull
|
||||
OllamaAsyncClient.list = patched_ollama_list
|
||||
|
||||
# Create patched methods for tool runtimes
|
||||
async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any]):
|
||||
return await _patched_tool_invoke_method(
|
||||
_original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs
|
||||
)
|
||||
|
||||
# Apply tool runtime patches
|
||||
TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool
|
||||
|
||||
|
||||
def unpatch_inference_clients():
|
||||
"""Remove monkey patches and restore original OpenAI and Ollama client methods."""
|
||||
"""Remove monkey patches and restore original OpenAI, Ollama client, and tool runtime methods."""
|
||||
global _original_methods
|
||||
|
||||
if not _original_methods:
|
||||
|
@ -696,6 +819,8 @@ def unpatch_inference_clients():
|
|||
from openai.resources.embeddings import AsyncEmbeddings
|
||||
from openai.resources.models import AsyncModels
|
||||
|
||||
from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
|
||||
|
||||
# Restore OpenAI client methods
|
||||
AsyncChatCompletions.create = _original_methods["chat_completions_create"]
|
||||
AsyncCompletions.create = _original_methods["completions_create"]
|
||||
|
@ -710,17 +835,21 @@ def unpatch_inference_clients():
|
|||
OllamaAsyncClient.pull = _original_methods["ollama_pull"]
|
||||
OllamaAsyncClient.list = _original_methods["ollama_list"]
|
||||
|
||||
# Restore tool runtime methods
|
||||
TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"]
|
||||
|
||||
_original_methods.clear()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
|
||||
"""Context manager for inference recording/replaying."""
|
||||
def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
|
||||
"""Context manager for API recording/replaying (inference and tools)."""
|
||||
global _current_mode, _current_storage
|
||||
|
||||
# Store previous state
|
||||
prev_mode = _current_mode
|
||||
prev_storage = _current_storage
|
||||
previous_override = None
|
||||
|
||||
try:
|
||||
_current_mode = mode
|
||||
|
@ -729,7 +858,9 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
|
|||
if storage_dir is None:
|
||||
raise ValueError("storage_dir is required for record, replay, and record-if-missing modes")
|
||||
_current_storage = ResponseStorage(Path(storage_dir))
|
||||
_id_counters.clear()
|
||||
patch_inference_clients()
|
||||
previous_override = set_id_override(_deterministic_id_override)
|
||||
|
||||
yield
|
||||
|
||||
|
@ -737,6 +868,7 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
|
|||
# Restore previous state
|
||||
if mode in ["record", "replay", "record-if-missing"]:
|
||||
unpatch_inference_clients()
|
||||
reset_id_override(previous_override)
|
||||
|
||||
_current_mode = prev_mode
|
||||
_current_storage = prev_storage
|
Loading…
Add table
Add a link
Reference in a new issue