feat(tests): make inference_recorder into api_recorder (include tool_invoke) (#3403)

Renames `inference_recorder.py` to `api_recorder.py` and extends it to
support recording/replaying tool invocations in addition to inference
calls.

This allows us to record web-search, etc. tool calls and thereafter
apply recordings for `tests/integration/responses`

## Test Plan

```
export OPENAI_API_KEY=...
export TAVILY_SEARCH_API_KEY=...

./scripts/integration-tests.sh --stack-config ci-tests \
   --suite responses --inference-mode record-if-missing
```
This commit is contained in:
Ashwin Bharambe 2025-10-09 14:27:51 -07:00 committed by GitHub
parent 26fd5dbd34
commit f50ce11a3b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
284 changed files with 296191 additions and 631 deletions

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable
IdFactory = Callable[[], str]
IdOverride = Callable[[str, IdFactory], str]
_id_override: IdOverride | None = None
def generate_object_id(kind: str, factory: IdFactory) -> str:
"""Generate an identifier for the given kind using the provided factory.
Allows tests to override ID generation deterministically by installing an
override callback via :func:`set_id_override`.
"""
override = _id_override
if override is not None:
return override(kind, factory)
return factory()
def set_id_override(override: IdOverride) -> IdOverride | None:
"""Install an override used to generate deterministic identifiers."""
global _id_override
previous = _id_override
_id_override = override
return previous
def reset_id_override(previous: IdOverride | None) -> None:
"""Restore the previous override returned by :func:`set_id_override`."""
global _id_override
_id_override = previous

View file

@ -232,14 +232,25 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
await log_request_pre_validation(request)
test_context_token = None
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user):
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
from llama_stack.core.testing_context import (
TEST_CONTEXT,
reset_test_context,
sync_test_context_from_provider_data,
)
test_context_token = sync_test_context_from_provider_data()
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
gen = preserve_contexts_async_generator(
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR, TEST_CONTEXT]
)
return StreamingResponse(gen, media_type="text/event-stream")
else:
@ -258,6 +269,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
else:
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
raise translate_exception(e) from e
finally:
if test_context_token is not None:
reset_test_context(test_context_token)
sig = inspect.signature(func)

View file

@ -316,13 +316,13 @@ class Stack:
# asked for in the run config.
async def initialize(self):
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording
from llama_stack.testing.api_recorder import setup_api_recording
global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_inference_recording()
TEST_RECORDING_CONTEXT = setup_api_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
@ -381,7 +381,7 @@ class Stack:
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
logger.error(f"Error during API recording cleanup: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:

View file

@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from contextvars import ContextVar
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
TEST_CONTEXT: ContextVar[str | None] = ContextVar("llama_stack_test_context", default=None)
def get_test_context() -> str | None:
return TEST_CONTEXT.get()
def set_test_context(value: str | None):
return TEST_CONTEXT.set(value)
def reset_test_context(token) -> None:
TEST_CONTEXT.reset(token)
def sync_test_context_from_provider_data():
"""Sync test context from provider data when running in server test mode."""
if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
return None
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
if stack_config_type != "server":
return None
try:
provider_data = PROVIDER_DATA_VAR.get()
except LookupError:
provider_data = None
if provider_data and "__test_id" in provider_data:
return TEST_CONTEXT.set(provider_data["__test_id"])
return None