From 9205731cd6f41cace21a38fa1d78aefcacfa2e1a Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 4 Oct 2025 11:53:44 -0700 Subject: [PATCH] feat(tests): make inference_recorder into api_recorder (include tool_invoke) --- .github/workflows/integration-tests.yml | 4 +- .../workflows/record-integration-tests.yml | 3 + .../contributing/testing/record-replay.mdx | 8 +- llama_stack/core/stack.py | 8 +- .../responses/openai_responses.py | 2 +- .../agents/meta_reference/responses/utils.py | 47 +- .../providers/inline/files/localfs/files.py | 4 +- ...{inference_recorder.py => api_recorder.py} | 666 ++++++++++++------ scripts/integration-tests.sh | 4 +- tests/common/mcp.py | 34 +- tests/integration/conftest.py | 24 +- tests/integration/fixtures/common.py | 6 + tests/integration/responses/helpers.py | 18 +- .../responses/test_extra_body_shields.py | 1 + .../integration/responses/test_file_search.py | 10 +- .../responses/test_tool_responses.py | 18 +- tests/integration/suites.py | 3 +- .../unit/distribution/test_api_recordings.py | 273 +++++++ .../distribution/test_inference_recordings.py | 382 ---------- 19 files changed, 849 insertions(+), 666 deletions(-) rename llama_stack/testing/{inference_recorder.py => api_recorder.py} (62%) create mode 100644 tests/unit/distribution/test_api_recordings.py delete mode 100644 tests/unit/distribution/test_inference_recordings.py diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index ace1f4edc..5fb25f047 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -54,14 +54,14 @@ jobs: # Define (setup, suite) pairs - they are always matched and cannot be independent # Weekly schedule (Sun 1 AM): vllm+base # Input test-setup=ollama-vision: ollama-vision+vision - # Default (including test-setup=ollama): both ollama+base and ollama-vision+vision + # Default (including test-setup=ollama): ollama+base, ollama-vision+vision, gpt+responses config: >- ${{ github.event.schedule == '1 0 * * 0' && fromJSON('[{"setup": "vllm", "suite": "base"}]') || github.event.inputs.test-setup == 'ollama-vision' && fromJSON('[{"setup": "ollama-vision", "suite": "vision"}]') - || fromJSON('[{"setup": "ollama", "suite": "base"}, {"setup": "ollama-vision", "suite": "vision"}]') + || fromJSON('[{"setup": "ollama", "suite": "base"}, {"setup": "ollama-vision", "suite": "vision"}, {"setup": "gpt", "suite": "responses"}]') }} steps: diff --git a/.github/workflows/record-integration-tests.yml b/.github/workflows/record-integration-tests.yml index 65a04f125..57f95580e 100644 --- a/.github/workflows/record-integration-tests.yml +++ b/.github/workflows/record-integration-tests.yml @@ -61,6 +61,9 @@ jobs: - name: Run and record tests uses: ./.github/actions/run-and-record-tests + env: + # Set OPENAI_API_KEY if using gpt setup + OPENAI_API_KEY: ${{ inputs.test-setup == 'gpt' && secrets.OPENAI_API_KEY || '' }} with: stack-config: 'server:ci-tests' # recording must be done with server since more tests are run setup: ${{ inputs.test-setup || 'ollama' }} diff --git a/docs/docs/contributing/testing/record-replay.mdx b/docs/docs/contributing/testing/record-replay.mdx index 47803c150..cc3eb2b9d 100644 --- a/docs/docs/contributing/testing/record-replay.mdx +++ b/docs/docs/contributing/testing/record-replay.mdx @@ -68,7 +68,9 @@ recordings/ Direct API calls with no recording or replay: ```python -with inference_recording(mode=InferenceMode.LIVE): +from llama_stack.testing.api_recorder import api_recording, 
APIRecordingMode + +with api_recording(mode=APIRecordingMode.LIVE): response = await client.chat.completions.create(...) ``` @@ -79,7 +81,7 @@ Use for initial development and debugging against real APIs. Captures API interactions while passing through real responses: ```python -with inference_recording(mode=InferenceMode.RECORD, storage_dir="./recordings"): +with api_recording(mode=APIRecordingMode.RECORD, storage_dir="./recordings"): response = await client.chat.completions.create(...) # Real API call made, response captured AND returned ``` @@ -96,7 +98,7 @@ The recording process: Returns stored responses instead of making API calls: ```python -with inference_recording(mode=InferenceMode.REPLAY, storage_dir="./recordings"): +with api_recording(mode=APIRecordingMode.REPLAY, storage_dir="./recordings"): response = await client.chat.completions.create(...) # No API call made, cached response returned instantly ``` diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index acc02eeff..49f6b9cc9 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -316,13 +316,13 @@ class Stack: # asked for in the run config. async def initialize(self): if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ: - from llama_stack.testing.inference_recorder import setup_inference_recording + from llama_stack.testing.api_recorder import setup_api_recording global TEST_RECORDING_CONTEXT - TEST_RECORDING_CONTEXT = setup_inference_recording() + TEST_RECORDING_CONTEXT = setup_api_recording() if TEST_RECORDING_CONTEXT: TEST_RECORDING_CONTEXT.__enter__() - logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") + logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name) policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else [] @@ -381,7 +381,7 @@ class Stack: try: TEST_RECORDING_CONTEXT.__exit__(None, None, None) except Exception as e: - logger.error(f"Error during inference recording cleanup: {e}") + logger.error(f"Error during API recording cleanup: {e}") global REGISTRY_REFRESH_TASK if REGISTRY_REFRESH_TASK: diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index 245203f10..da8b01f40 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -108,7 +108,7 @@ class OpenAIResponsesImpl: # Use stored messages directly and convert only new input message_adapter = TypeAdapter(list[OpenAIMessageParam]) messages = message_adapter.validate_python(previous_response.messages) - new_messages = await convert_response_input_to_chat_messages(input) + new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages) messages.extend(new_messages) else: # Backward compatibility: reconstruct from inputs diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py index 5b013b9c4..a3316a635 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py @@ -103,9 +103,13 @@ async def convert_response_content_to_chat_content( 
async def convert_response_input_to_chat_messages( input: str | list[OpenAIResponseInput], + previous_messages: list[OpenAIMessageParam] | None = None, ) -> list[OpenAIMessageParam]: """ Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages. + + :param input: The input to convert + :param previous_messages: Optional previous messages to check for function_call references """ messages: list[OpenAIMessageParam] = [] if isinstance(input, list): @@ -169,16 +173,53 @@ async def convert_response_input_to_chat_messages( raise ValueError( f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" ) + # Skip user messages that duplicate the last user message in previous_messages + # This handles cases where input includes context for function_call_outputs + if previous_messages and input_item.role == "user": + last_user_msg = None + for msg in reversed(previous_messages): + if isinstance(msg, OpenAIUserMessageParam): + last_user_msg = msg + break + if last_user_msg: + last_user_content = getattr(last_user_msg, "content", None) + if last_user_content == content: + continue # Skip duplicate user message messages.append(message_type(content=content)) if len(tool_call_results): - raise ValueError( - f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call" - ) + # Check if unpaired function_call_outputs reference function_calls from previous messages + if previous_messages: + previous_call_ids = _extract_tool_call_ids(previous_messages) + for call_id in list(tool_call_results.keys()): + if call_id in previous_call_ids: + # Valid: this output references a call from previous messages + # Add the tool message + messages.append(tool_call_results[call_id]) + del tool_call_results[call_id] + + # If still have unpaired outputs, error + if len(tool_call_results): + raise ValueError( + f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call" + ) else: messages.append(OpenAIUserMessageParam(content=input)) return messages +def _extract_tool_call_ids(messages: list[OpenAIMessageParam]) -> set[str]: + """Extract all tool_call IDs from messages.""" + call_ids = set() + for msg in messages: + if isinstance(msg, OpenAIAssistantMessageParam): + tool_calls = getattr(msg, "tool_calls", None) + if tool_calls: + for tool_call in tool_calls: + # tool_call is a Pydantic model, use attribute access + call_ids.add(tool_call.id) + return call_ids + + async def convert_response_text_to_chat_response_format( text: OpenAIResponseText, ) -> OpenAIResponseFormatParam: diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index be1da291a..77af94681 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -95,7 +95,9 @@ class LocalfsFilesImpl(Files): raise RuntimeError("Files provider not initialized") if expires_after is not None: - raise NotImplementedError("File expiration is not supported by this provider") + logger.warning( + f"File expiration is not supported by this provider, ignoring expires_after: {expires_after}" + ) file_id = self._generate_file_id() file_path = self._get_file_path(file_id) diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/api_recorder.py similarity index 62% rename from llama_stack/testing/inference_recorder.py rename to llama_stack/testing/api_recorder.py 
index 16071f80f..cc7c65465 100644 --- a/llama_stack/testing/inference_recorder.py +++ b/llama_stack/testing/api_recorder.py @@ -15,19 +15,20 @@ from enum import StrEnum from pathlib import Path from typing import Any, Literal, cast -from openai import NOT_GIVEN, OpenAI +from openai import NOT_GIVEN from llama_stack.log import get_logger logger = get_logger(__name__, category="testing") -# Global state for the recording system +# Global state for the API recording system # Note: Using module globals instead of ContextVars because the session-scoped # client initialization happens in one async context, but tests run in different # contexts, and we need the mode/storage to persist across all contexts. _current_mode: str | None = None _current_storage: ResponseStorage | None = None _original_methods: dict[str, Any] = {} +_memory_cache: dict[str, dict[str, Any]] = {} # Test context uses ContextVar since it changes per-test and needs async isolation from contextvars import ContextVar @@ -44,158 +45,33 @@ REPO_ROOT = Path(__file__).parent.parent.parent DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/common" -class InferenceMode(StrEnum): +class APIRecordingMode(StrEnum): LIVE = "live" RECORD = "record" REPLAY = "replay" RECORD_IF_MISSING = "record-if-missing" -def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str: - """Create a normalized hash of the request for consistent matching. +def _normalize_file_ids(obj: Any) -> Any: + """Recursively replace file IDs with a canonical placeholder for consistent hashing.""" + import re - Includes test_id from context to ensure test isolation - identical requests - from different tests will have different hashes. - - Exception: Model list endpoints (/v1/models, /api/tags) exclude test_id since - they are infrastructure/shared and need to work across session setup and tests. - """ - # Extract just the endpoint path - from urllib.parse import urlparse - - parsed = urlparse(url) - normalized: dict[str, Any] = { - "method": method.upper(), - "endpoint": parsed.path, - "body": body, - } - - # Include test_id for isolation, except for shared infrastructure endpoints - if parsed.path not in ("/api/tags", "/v1/models"): - normalized["test_id"] = _test_context.get() - - # Create hash - sort_keys=True ensures deterministic ordering - normalized_json = json.dumps(normalized, sort_keys=True) - return hashlib.sha256(normalized_json.encode()).hexdigest() - - -def _sync_test_context_from_provider_data(): - """In server mode, sync test ID from provider_data to _test_context. - - This ensures that storage operations (which read from _test_context) work correctly - in server mode where the test ID arrives via HTTP header → provider_data. - - Returns a token to reset _test_context, or None if no sync was needed. - """ - stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client") - - if stack_config_type != "server": - return None - - try: - from llama_stack.core.request_headers import PROVIDER_DATA_VAR - - provider_data = PROVIDER_DATA_VAR.get() - - if provider_data and "__test_id" in provider_data: - test_id = provider_data["__test_id"] - return _test_context.set(test_id) - except ImportError: - pass - - return None - - -def patch_httpx_for_test_id(): - """Patch client _prepare_request methods to inject test ID into provider data header. - - This is needed for server mode where the test ID must be transported from - client to server via HTTP headers. 
In library_client mode, this patch is a no-op - since everything runs in the same process. - - We use the _prepare_request hook that Stainless clients provide for mutating - requests after construction but before sending. - """ - from llama_stack_client import LlamaStackClient - - if "llama_stack_client_prepare_request" in _original_methods: - return - - _original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request - _original_methods["openai_prepare_request"] = OpenAI._prepare_request - - def patched_prepare_request(self, request): - # Call original first (it's a sync method that returns None) - # Determine which original to call based on client type - if "llama_stack_client" in self.__class__.__module__: - _original_methods["llama_stack_client_prepare_request"](self, request) - _original_methods["openai_prepare_request"](self, request) - - # Only inject test ID in server mode - stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client") - test_id = _test_context.get() - - if stack_config_type == "server" and test_id: - provider_data_header = request.headers.get("X-LlamaStack-Provider-Data") - - if provider_data_header: - provider_data = json.loads(provider_data_header) + if isinstance(obj, dict): + result = {} + for k, v in obj.items(): + # Normalize file IDs in attribute dictionaries + if k == "document_id" and isinstance(v, str) and v.startswith("file-"): + result[k] = "file-NORMALIZED" else: - provider_data = {} - - provider_data["__test_id"] = test_id - request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data) - - return None - - LlamaStackClient._prepare_request = patched_prepare_request - OpenAI._prepare_request = patched_prepare_request - - -# currently, unpatch is never called -def unpatch_httpx_for_test_id(): - """Remove client _prepare_request patches for test ID injection.""" - if "llama_stack_client_prepare_request" not in _original_methods: - return - - from llama_stack_client import LlamaStackClient - - LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"] - del _original_methods["llama_stack_client_prepare_request"] - - # Also restore OpenAI client if it was patched - if "openai_prepare_request" in _original_methods: - OpenAI._prepare_request = _original_methods["openai_prepare_request"] - del _original_methods["openai_prepare_request"] - - -def get_inference_mode() -> InferenceMode: - return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower()) - - -def setup_inference_recording(): - """ - Returns a context manager that can be used to record or replay inference requests. This is to be used in tests - to increase their reliability and reduce reliance on expensive, external services. - - Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases. - - Two environment variables are supported: - - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'record-if-missing'. Default is 'replay'. - - 'live': Make all requests live without recording - - 'record': Record all requests (overwrites existing recordings) - - 'replay': Use only recorded responses (fails if recording not found) - - 'record-if-missing': Use recorded responses when available, record new ones when not found - - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'. - - The recordings are stored as JSON files. 
- """ - mode = get_inference_mode() - if mode == InferenceMode.LIVE: - return None - - storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR) - return inference_recording(mode=mode, storage_dir=storage_dir) + result[k] = _normalize_file_ids(v) + return result + elif isinstance(obj, list): + return [_normalize_file_ids(item) for item in obj] + elif isinstance(obj, str): + # Replace file- patterns in strings (like in text content) + return re.sub(r"file-[a-f0-9]{32}", "file-NORMALIZED", obj) + else: + return obj def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[str, Any]: @@ -230,11 +106,184 @@ def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[st return data -def _serialize_response(response: Any, request_hash: str = "") -> Any: +def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str: + """Create a normalized hash of the request for consistent matching. + + Includes test_id from context to ensure test isolation - identical requests + from different tests will have different hashes. + """ + # Extract just the endpoint path + from urllib.parse import urlparse + + parsed = urlparse(url) + + # Normalize file IDs in the body to ensure consistent hashing across test runs + normalized_body = _normalize_file_ids(body) + + normalized: dict[str, Any] = {"method": method.upper(), "endpoint": parsed.path, "body": normalized_body} + + # Include test_id for isolation, except for shared infrastructure endpoints + if parsed.path not in ("/api/tags", "/v1/models"): + # Server mode: test ID was synced from provider_data to _test_context + # by _sync_test_context_from_provider_data() at the start of the request. + # We read from _test_context because it's available in all contexts (including + # when making outgoing API calls), whereas PROVIDER_DATA_VAR is only set + # for incoming HTTP requests. + # + # Library client mode: test ID in same-process ContextVar + test_id = _test_context.get() + normalized["test_id"] = test_id + + # Create hash - sort_keys=True ensures deterministic ordering + normalized_json = json.dumps(normalized, sort_keys=True) + request_hash = hashlib.sha256(normalized_json.encode()).hexdigest() + + return request_hash + + +def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str: + """Create a normalized hash of the tool request for consistent matching.""" + normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": kwargs} + + # Create hash - sort_keys=True ensures deterministic ordering + normalized_json = json.dumps(normalized, sort_keys=True) + return hashlib.sha256(normalized_json.encode()).hexdigest() + + +@contextmanager +def set_test_context(test_id: str) -> Generator[None, None, None]: + """Set the test context for recording isolation. + + Usage: + with set_test_context("test_basic_completion"): + # Make API calls that will be recorded with this test_id + response = client.chat.completions.create(...) + """ + token = _test_context.set(test_id) + try: + yield + finally: + _test_context.reset(token) + + +def patch_httpx_for_test_id(): + """Patch client _prepare_request methods to inject test ID into provider data header. + + Patches both LlamaStackClient and OpenAI client to ensure test ID is transported + from client to server via HTTP headers in server mode. + + This is needed for server mode where the test ID must be transported from + client to server via HTTP headers. 
In library_client mode, this patch is a no-op + since everything runs in the same process. + + We use the _prepare_request hook that Stainless clients provide for mutating + requests after construction but before sending. + """ + from llama_stack_client import LlamaStackClient + + if "llama_stack_client_prepare_request" in _original_methods: + # Already patched + return + + # Save original methods + _original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request + + # Also patch OpenAI client if available (used in compat tests) + try: + from openai import OpenAI + + _original_methods["openai_prepare_request"] = OpenAI._prepare_request + except ImportError: + pass + + def patched_prepare_request(self, request): + # Call original first (it's a sync method that returns None) + # Determine which original to call based on client type + if "llama_stack_client" in self.__class__.__module__: + _original_methods["llama_stack_client_prepare_request"](self, request) + elif "openai_prepare_request" in _original_methods: + _original_methods["openai_prepare_request"](self, request) + + # Only inject test ID in server mode + stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client") + test_id = _test_context.get() + + if stack_config_type == "server" and test_id: + # Get existing provider data header or create new dict + provider_data_header = request.headers.get("X-LlamaStack-Provider-Data") + + if provider_data_header: + provider_data = json.loads(provider_data_header) + else: + provider_data = {} + + # Inject test ID + provider_data["__test_id"] = test_id + request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data) + + # Sync version returns None + return None + + # Apply patches + LlamaStackClient._prepare_request = patched_prepare_request + if "openai_prepare_request" in _original_methods: + from openai import OpenAI + + OpenAI._prepare_request = patched_prepare_request + + +def unpatch_httpx_for_test_id(): + """Remove client _prepare_request patches for test ID injection.""" + if "llama_stack_client_prepare_request" not in _original_methods: + return + + from llama_stack_client import LlamaStackClient + + LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"] + del _original_methods["llama_stack_client_prepare_request"] + + # Also restore OpenAI client if it was patched + if "openai_prepare_request" in _original_methods: + from openai import OpenAI + + OpenAI._prepare_request = _original_methods["openai_prepare_request"] + del _original_methods["openai_prepare_request"] + + +def get_api_recording_mode() -> APIRecordingMode: + return APIRecordingMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower()) + + +def setup_api_recording(): + """ + Returns a context manager that can be used to record or replay API requests (inference and tools). + This is to be used in tests to increase their reliability and reduce reliance on expensive, external services. + + Currently supports: + - Inference: OpenAI, Ollama, and LiteLLM clients + - Tools: Search providers (Tavily for now) + + Two environment variables are supported: + - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. 
Options: + - 'live': Make real API calls, no recording + - 'record': Record all API interactions (overwrites existing) + - 'replay': Use recorded responses only (default) + - 'record-if-missing': Replay when possible, record when recording doesn't exist + - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'. + + The recordings are stored as JSON files. + """ + mode = get_api_recording_mode() + if mode == APIRecordingMode.LIVE: + return None + + storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR) + return api_recording(mode=mode, storage_dir=storage_dir) + + +def _serialize_response(response: Any) -> Any: if hasattr(response, "model_dump"): data = response.model_dump(mode="json") - # Normalize fields to reduce noise - data = _normalize_response_data(data, request_hash) return { "__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}", "__data__": data, @@ -259,22 +308,17 @@ def _deserialize_response(data: dict[str, Any]) -> Any: return cls.model_validate(data["__data__"]) except (ImportError, AttributeError, TypeError, ValueError) as e: - logger.warning(f"Failed to deserialize object of type {data['__type__']} with model_validate: {e}") - try: - return cls.model_construct(**data["__data__"]) - except Exception as e: - logger.warning(f"Failed to deserialize object of type {data['__type__']} with model_construct: {e}") - return data["__data__"] + logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}") + return data["__data__"] return data class ResponseStorage: - """Handles SQLite index + JSON file storage/retrieval for inference recordings.""" + """Handles storage/retrieval for API recordings (inference and tools).""" def __init__(self, base_dir: Path): self.base_dir = base_dir - # Don't create responses_dir here - determine it per-test at runtime def _get_test_dir(self) -> Path: """Get the recordings directory in the test file's parent directory. @@ -283,6 +327,7 @@ class ResponseStorage: returns "tests/integration/inference/recordings/". 
""" test_id = _test_context.get() + if test_id: # Extract the directory path from the test nodeid # e.g., "tests/integration/inference/test_basic.py::test_foo[params]" @@ -297,17 +342,21 @@ class ResponseStorage: # Fallback for non-test contexts return self.base_dir / "recordings" - def _ensure_directories(self): - """Ensure test-specific directories exist.""" + def _ensure_directories(self) -> Path: test_dir = self._get_test_dir() test_dir.mkdir(parents=True, exist_ok=True) return test_dir def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]): - """Store a request/response pair.""" + """Store a request/response pair both in memory cache and on disk.""" + global _memory_cache + + # Store in memory cache first + _memory_cache[request_hash] = {"request": request, "response": response} + responses_dir = self._ensure_directories() - # Use FULL hash (not truncated) + # Generate unique response filename using full hash response_file = f"{request_hash}.json" # Serialize response body if needed @@ -315,45 +364,32 @@ class ResponseStorage: if "body" in serialized_response: if isinstance(serialized_response["body"], list): # Handle streaming responses (list of chunks) - serialized_response["body"] = [ - _serialize_response(chunk, request_hash) for chunk in serialized_response["body"] - ] + serialized_response["body"] = [_serialize_response(chunk) for chunk in serialized_response["body"]] else: # Handle single response - serialized_response["body"] = _serialize_response(serialized_response["body"], request_hash) + serialized_response["body"] = _serialize_response(serialized_response["body"]) - # For model-list endpoints, include digest in filename to distinguish different model sets + # If this is a model-list endpoint recording, include models digest in filename to distinguish variants endpoint = request.get("endpoint") + test_id = _test_context.get() if endpoint in ("/api/tags", "/v1/models"): + test_id = None digest = _model_identifiers_digest(endpoint, response) response_file = f"models-{request_hash}-{digest}.json" response_path = responses_dir / response_file - # Save response to JSON file with metadata + # Save response to JSON file with open(response_path, "w") as f: - json.dump( - { - "test_id": _test_context.get(), - "request": request, - "response": serialized_response, - }, - f, - indent=2, - ) + json.dump({"test_id": test_id, "request": request, "response": serialized_response}, f, indent=2) f.write("\n") f.flush() def find_recording(self, request_hash: str) -> dict[str, Any] | None: - """Find a recorded response by request hash. - - Uses fallback: first checks test-specific dir, then falls back to base recordings dir. - This handles cases where recordings happen during session setup (no test context) but - are requested during tests (with test context). 
- """ + """Find a recorded response by request hash.""" response_file = f"{request_hash}.json" - # Try test-specific directory first + # Check test-specific directory first test_dir = self._get_test_dir() response_path = test_dir / response_file @@ -464,15 +500,97 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]]) return {"request": canonical_req, "response": {"body": body, "is_streaming": False}} +async def _patched_tool_invoke_method( + original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any] +): + """Patched version of tool runtime invoke_tool method for recording/replay.""" + global _current_mode, _current_storage + + if _current_mode == APIRecordingMode.LIVE or _current_storage is None: + # Normal operation + return await original_method(self, tool_name, kwargs) + + # In server mode, sync test ID from provider_data to _test_context for storage operations + test_context_token = _sync_test_context_from_provider_data() + + try: + request_hash = normalize_tool_request(provider_name, tool_name, kwargs) + + if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING): + recording = _current_storage.find_recording(request_hash) + if recording: + return recording["response"]["body"] + elif _current_mode == APIRecordingMode.REPLAY: + raise RuntimeError( + f"No recorded tool result found for {provider_name}.{tool_name}\n" + f"Request: {kwargs}\n" + f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record" + ) + # If RECORD_IF_MISSING and no recording found, fall through to record + + if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING): + # Check in-memory cache first (collision detection) + global _memory_cache + if request_hash in _memory_cache: + # Return the cached response instead of making a new tool call + return _memory_cache[request_hash]["response"]["body"] + + # No cached response, make the tool call and record it + result = await original_method(self, tool_name, kwargs) + + request_data = { + "provider": provider_name, + "tool_name": tool_name, + "kwargs": kwargs, + } + response_data = {"body": result, "is_streaming": False} + + # Store the recording (both in memory and on disk) + _current_storage.store_recording(request_hash, request_data, response_data) + return result + + else: + raise AssertionError(f"Invalid mode: {_current_mode}") + finally: + # Reset test context if we set it in server mode + if test_context_token is not None: + _test_context.reset(test_context_token) + + +def _sync_test_context_from_provider_data(): + """In server mode, sync test ID from provider_data to _test_context. + + This ensures that storage operations (which read from _test_context) work correctly + in server mode where the test ID arrives via HTTP header → provider_data. + + Returns a token to reset _test_context, or None if no sync was needed. 
+ """ + stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client") + + if stack_config_type != "server": + return None + + try: + from llama_stack.core.request_headers import PROVIDER_DATA_VAR + + provider_data = PROVIDER_DATA_VAR.get() + + if provider_data and "__test_id" in provider_data: + test_id = provider_data["__test_id"] + return _test_context.set(test_id) + except ImportError: + pass + + return None + + async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs): global _current_mode, _current_storage - mode = _current_mode - storage = _current_storage - - if mode == InferenceMode.LIVE or storage is None: - if endpoint == "/v1/models": - return original_method(self, *args, **kwargs) + if _current_mode == APIRecordingMode.LIVE or _current_storage is None: + # Normal operation + if client_type == "litellm": + return await original_method(*args, **kwargs) else: return await original_method(self, *args, **kwargs) @@ -491,34 +609,30 @@ async def _patched_inference_method(original_method, self, client_type, endpoint base_url = getattr(self, "host", "http://localhost:11434") if not base_url.startswith("http"): base_url = f"http://{base_url}" + elif client_type == "litellm": + # For LiteLLM, extract base URL from kwargs if available + base_url = kwargs.get("api_base", "https://api.openai.com") else: raise ValueError(f"Unknown client type: {client_type}") url = base_url.rstrip("/") + endpoint - # Special handling for Databricks URLs to avoid leaking workspace info - # e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com - if "cloud.databricks.com" in url: - url = "__databricks__" + url.split("cloud.databricks.com")[-1] method = "POST" headers = {} body = kwargs request_hash = normalize_request(method, url, headers, body) - # Try to find existing recording for REPLAY or RECORD_IF_MISSING modes - recording = None - if mode == InferenceMode.REPLAY or mode == InferenceMode.RECORD_IF_MISSING: - # Special handling for model-list endpoints: merge all recordings with this hash + if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING): + # Special handling for model-list endpoints: return union of all responses if endpoint in ("/api/tags", "/v1/models"): - records = storage._model_list_responses(request_hash) + records = _current_storage._model_list_responses(request_hash) recording = _combine_model_list_responses(endpoint, records) else: - recording = storage.find_recording(request_hash) - + recording = _current_storage.find_recording(request_hash) if recording: response_body = recording["response"]["body"] - if recording["response"].get("is_streaming", False): + if recording["response"].get("is_streaming", False) or recording["response"].get("is_paginated", False): async def replay_stream(): for chunk in response_body: @@ -527,25 +641,41 @@ async def _patched_inference_method(original_method, self, client_type, endpoint return replay_stream() else: return response_body - elif mode == InferenceMode.REPLAY: - # REPLAY mode requires recording to exist + elif _current_mode == APIRecordingMode.REPLAY: raise RuntimeError( f"No recorded response found for request hash: {request_hash}\n" f"Request: {method} {url} {body}\n" f"Model: {body.get('model', 'unknown')}\n" f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record" ) + # If RECORD_IF_MISSING and no recording found, fall through to record - if mode == InferenceMode.RECORD or (mode == 
InferenceMode.RECORD_IF_MISSING and not recording): - if endpoint == "/v1/models": - response = original_method(self, *args, **kwargs) + if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING): + # Check in-memory cache first (collision detection) + global _memory_cache + if request_hash in _memory_cache: + # Return the cached response instead of making a new API call + cached_recording = _memory_cache[request_hash] + response_body = cached_recording["response"]["body"] + + if cached_recording["response"].get("is_streaming", False) or cached_recording["response"].get( + "is_paginated", False + ): + + async def replay_cached_stream(): + for chunk in response_body: + yield chunk + + return replay_cached_stream() + else: + return response_body + + # No cached response, make the API call and record it + if client_type == "litellm": + response = await original_method(*args, **kwargs) else: response = await original_method(self, *args, **kwargs) - # we want to store the result of the iterator, not the iterator itself - if endpoint == "/v1/models": - response = [m async for m in response] - request_data = { "method": method, "url": url, @@ -558,16 +688,20 @@ async def _patched_inference_method(original_method, self, client_type, endpoint # Determine if this is a streaming request based on request parameters is_streaming = body.get("stream", False) - if is_streaming: - # For streaming responses, we need to collect all chunks immediately before yielding + # Special case: /v1/models is a paginated endpoint that returns an async iterator + is_paginated = endpoint == "/v1/models" + + if is_streaming or is_paginated: + # For streaming/paginated responses, we need to collect all chunks immediately before yielding # This ensures the recording is saved even if the generator isn't fully consumed chunks = [] async for chunk in response: chunks.append(chunk) - # Store the recording immediately - response_data = {"body": chunks, "is_streaming": True} - storage.store_recording(request_hash, request_data, response_data) + # Store the recording immediately (both in memory and on disk) + # For paginated endpoints, mark as paginated rather than streaming + response_data = {"body": chunks, "is_streaming": is_streaming, "is_paginated": is_paginated} + _current_storage.store_recording(request_hash, request_data, response_data) # Return a generator that replays the stored chunks async def replay_recorded_stream(): @@ -577,27 +711,34 @@ async def _patched_inference_method(original_method, self, client_type, endpoint return replay_recorded_stream() else: response_data = {"body": response, "is_streaming": False} - storage.store_recording(request_hash, request_data, response_data) + # Store the response (both in memory and on disk) + _current_storage.store_recording(request_hash, request_data, response_data) return response else: - raise AssertionError(f"Invalid mode: {mode}") + raise AssertionError(f"Invalid mode: {_current_mode}") + finally: - if test_context_token: + # Reset test context if we set it in server mode + if test_context_token is not None: _test_context.reset(test_context_token) -def patch_inference_clients(): - """Install monkey patches for OpenAI client methods and Ollama AsyncClient methods.""" +def patch_api_clients(): + """Install monkey patches for inference clients and tool runtime methods.""" global _original_methods + import litellm from ollama import AsyncClient as OllamaAsyncClient from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions 
from openai.resources.completions import AsyncCompletions from openai.resources.embeddings import AsyncEmbeddings from openai.resources.models import AsyncModels - # Store original methods for both OpenAI and Ollama clients + from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl + from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin + + # Store original methods for OpenAI, Ollama, LiteLLM clients, and tool runtimes _original_methods = { "chat_completions_create": AsyncChatCompletions.create, "completions_create": AsyncCompletions.create, @@ -609,6 +750,10 @@ def patch_inference_clients(): "ollama_ps": OllamaAsyncClient.ps, "ollama_pull": OllamaAsyncClient.pull, "ollama_list": OllamaAsyncClient.list, + "litellm_acompletion": litellm.acompletion, + "litellm_atext_completion": litellm.atext_completion, + "litellm_openai_mixin_get_api_key": LiteLLMOpenAIMixin.get_api_key, + "tavily_invoke_tool": TavilySearchToolRuntimeImpl.invoke_tool, } # Create patched methods for OpenAI client @@ -629,10 +774,18 @@ def patch_inference_clients(): def patched_models_list(self, *args, **kwargs): async def _iter(): - for item in await _patched_inference_method( + result = await _patched_inference_method( _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs - ): - yield item + ) + # The result is either an async generator (streaming/paginated) or a list + # If it's an async generator, iterate through it + if hasattr(result, "__aiter__"): + async for item in result: + yield item + else: + # It's a list, yield each item + for item in result: + yield item return _iter() @@ -681,21 +834,61 @@ def patch_inference_clients(): OllamaAsyncClient.pull = patched_ollama_pull OllamaAsyncClient.list = patched_ollama_list + # Create patched methods for LiteLLM + async def patched_litellm_acompletion(*args, **kwargs): + return await _patched_inference_method( + _original_methods["litellm_acompletion"], None, "litellm", "/chat/completions", *args, **kwargs + ) -def unpatch_inference_clients(): - """Remove monkey patches and restore original OpenAI and Ollama client methods.""" - global _original_methods + async def patched_litellm_atext_completion(*args, **kwargs): + return await _patched_inference_method( + _original_methods["litellm_atext_completion"], None, "litellm", "/completions", *args, **kwargs + ) + + # Apply LiteLLM patches + litellm.acompletion = patched_litellm_acompletion + litellm.atext_completion = patched_litellm_atext_completion + + # Create patched method for LiteLLMOpenAIMixin.get_api_key + def patched_litellm_get_api_key(self): + global _current_mode + if _current_mode != APIRecordingMode.REPLAY: + return _original_methods["litellm_openai_mixin_get_api_key"](self) + else: + # For record/replay modes, return a fake API key to avoid exposing real credentials + return "fake-api-key-for-testing" + + # Apply LiteLLMOpenAIMixin patch + LiteLLMOpenAIMixin.get_api_key = patched_litellm_get_api_key + + # Create patched methods for tool runtimes + async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any]): + return await _patched_tool_invoke_method( + _original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs + ) + + # Apply tool runtime patches + TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool + + +def unpatch_api_clients(): + """Remove monkey patches and restore original client methods and tool runtimes.""" + global _original_methods, 
_memory_cache if not _original_methods: return # Import here to avoid circular imports + import litellm from ollama import AsyncClient as OllamaAsyncClient from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions from openai.resources.completions import AsyncCompletions from openai.resources.embeddings import AsyncEmbeddings from openai.resources.models import AsyncModels + from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl + from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin + # Restore OpenAI client methods AsyncChatCompletions.create = _original_methods["chat_completions_create"] AsyncCompletions.create = _original_methods["completions_create"] @@ -710,12 +903,23 @@ def unpatch_inference_clients(): OllamaAsyncClient.pull = _original_methods["ollama_pull"] OllamaAsyncClient.list = _original_methods["ollama_list"] + # Restore LiteLLM methods + litellm.acompletion = _original_methods["litellm_acompletion"] + litellm.atext_completion = _original_methods["litellm_atext_completion"] + LiteLLMOpenAIMixin.get_api_key = _original_methods["litellm_openai_mixin_get_api_key"] + + # Restore tool runtime methods + TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"] + _original_methods.clear() + # Clear memory cache to prevent memory leaks + _memory_cache.clear() + @contextmanager -def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]: - """Context manager for inference recording/replaying.""" +def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]: + """Context manager for API recording/replaying (inference and tools).""" global _current_mode, _current_storage # Store previous state @@ -729,14 +933,14 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen if storage_dir is None: raise ValueError("storage_dir is required for record, replay, and record-if-missing modes") _current_storage = ResponseStorage(Path(storage_dir)) - patch_inference_clients() + patch_api_clients() yield finally: # Restore previous state if mode in ["record", "replay", "record-if-missing"]: - unpatch_inference_clients() + unpatch_api_clients() _current_mode = prev_mode _current_storage = prev_storage diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh index 4ae73f170..9a85e3257 100755 --- a/scripts/integration-tests.sh +++ b/scripts/integration-tests.sh @@ -29,7 +29,7 @@ Options: --stack-config STRING Stack configuration to use (required) --suite STRING Test suite to run (default: 'base') --setup STRING Test setup (models, env) to use (e.g., 'ollama', 'ollama-vision', 'gpt', 'vllm') - --inference-mode STRING Inference mode: record or replay (default: replay) + --inference-mode STRING Inference mode: replay, record-if-missing or record (default: replay) --subdirs STRING Comma-separated list of test subdirectories to run (overrides suite) --pattern STRING Regex pattern to pass to pytest -k --help Show this help message @@ -102,7 +102,7 @@ if [[ -z "$STACK_CONFIG" ]]; then fi if [[ -z "$TEST_SETUP" && -n "$TEST_SUBDIRS" ]]; then - echo "Error: --test-setup is required when --test-subdirs is provided" + echo "Error: --test-setup is required when --test-subdirs is not provided" usage exit 1 fi diff --git a/tests/common/mcp.py b/tests/common/mcp.py index 357ea4d41..644becd2d 100644 --- a/tests/common/mcp.py +++ b/tests/common/mcp.py @@ 
-159,7 +159,6 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal import threading import time - import httpx import uvicorn from mcp.server.fastmcp import FastMCP from mcp.server.sse import SseServerTransport @@ -171,6 +170,11 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal server = FastMCP("FastMCP Test Server", log_level="WARNING") + # Silence verbose MCP server logs + import logging # allow-direct-logging + + logging.getLogger("mcp.server.lowlevel.server").setLevel(logging.WARNING) + tools = tools or default_tools() # Register all tools with the server @@ -234,29 +238,25 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal logger.debug(f"Starting MCP server thread on port {port}") server_thread.start() - # Polling until the server is ready - timeout = 10 + # Wait for the server thread to be running + # Note: We can't use a simple HTTP GET health check on /sse because it's an SSE endpoint + # that expects a long-lived connection, not a simple request/response + timeout = 2 start_time = time.time() server_url = f"http://localhost:{port}/sse" - logger.debug(f"Waiting for MCP server to be ready at {server_url}") + logger.debug(f"Waiting for MCP server thread to start on port {port}") while time.time() - start_time < timeout: - try: - response = httpx.get(server_url) - if response.status_code in [200, 401]: - logger.debug(f"MCP server is ready on port {port} (status: {response.status_code})") - break - except httpx.RequestError as e: - logger.debug(f"Server not ready yet, retrying... ({e})") - pass - time.sleep(0.1) + if server_thread.is_alive(): + # Give the server a moment to bind to the port + time.sleep(0.1) + logger.debug(f"MCP server is ready on port {port}") + break + time.sleep(0.05) else: # If we exit the loop due to timeout - logger.error(f"MCP server failed to start within {timeout} seconds on port {port}") - logger.error(f"Thread alive: {server_thread.is_alive()}") - if server_thread.is_alive(): - logger.error("Server thread is still running but not responding to HTTP requests") + logger.error(f"MCP server thread failed to start within {timeout} seconds on port {port}") try: yield {"server_url": server_url} diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index c0eb27b98..4896741c1 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -6,6 +6,7 @@ import inspect import itertools import os +import tempfile import textwrap import time from pathlib import Path @@ -14,6 +15,7 @@ import pytest from dotenv import load_dotenv from llama_stack.log import get_logger +from llama_stack.testing.api_recorder import patch_httpx_for_test_id from .suites import SETUP_DEFINITIONS, SUITE_DEFINITIONS @@ -35,6 +37,10 @@ def pytest_sessionstart(session): if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ: os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay" + if "SQLITE_STORE_DIR" not in os.environ: + os.environ["SQLITE_STORE_DIR"] = tempfile.mkdtemp() + + # Set test stack config type for api_recorder test isolation stack_config = session.config.getoption("--stack-config", default=None) if stack_config and stack_config.startswith("server:"): os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "server" @@ -43,8 +49,6 @@ def pytest_sessionstart(session): os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "library_client" logger.info(f"Test stack config type: library_client (stack_config={stack_config})") - from 
llama_stack.testing.inference_recorder import patch_httpx_for_test_id - patch_httpx_for_test_id() @@ -55,7 +59,7 @@ def _track_test_context(request): This fixture runs for every test and stores the test's nodeid in a contextvar that the recording system can access to determine which subdirectory to use. """ - from llama_stack.testing.inference_recorder import _test_context + from llama_stack.testing.api_recorder import _test_context # Store the test nodeid (e.g., "tests/integration/responses/test_basic.py::test_foo[params]") token = _test_context.set(request.node.nodeid) @@ -121,9 +125,13 @@ def pytest_configure(config): # Apply defaults if not provided explicitly for dest, value in setup_obj.defaults.items(): current = getattr(config.option, dest, None) - if not current: + if current is None: setattr(config.option, dest, value) + # Apply global fallback for embedding_dimension if still not set + if getattr(config.option, "embedding_dimension", None) is None: + config.option.embedding_dimension = 384 + def pytest_addoption(parser): parser.addoption( @@ -161,8 +169,8 @@ def pytest_addoption(parser): parser.addoption( "--embedding-dimension", type=int, - default=384, - help="Output dimensionality of the embedding model to use for testing. Default: 384", + default=None, + help="Output dimensionality of the embedding model to use for testing. Default: 384 (or setup-specific)", ) parser.addoption( @@ -236,7 +244,9 @@ def pytest_generate_tests(metafunc): continue params.append(fixture_name) - val = metafunc.config.getoption(option) + # Use getattr on config.option to see values set by pytest_configure fallbacks + dest = option.lstrip("-").replace("-", "_") + val = getattr(metafunc.config.option, dest, None) values = [v.strip() for v in str(val).split(",")] if val else [None] param_values[fixture_name] = values diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index 68aa2b60b..2de6b1ccb 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -183,6 +183,12 @@ def llama_stack_client(request): # would be forced to use llama_stack_client, which is not what we want. 
print("\ninstantiating llama_stack_client") start_time = time.time() + + # Patch httpx to inject test ID for server-mode test isolation + from llama_stack.testing.api_recorder import patch_httpx_for_test_id + + patch_httpx_for_test_id() + client = instantiate_llama_stack_client(request.session) print(f"llama_stack_client instantiated in {time.time() - start_time:.3f}s") return client diff --git a/tests/integration/responses/helpers.py b/tests/integration/responses/helpers.py index 7c988402f..605b64b3c 100644 --- a/tests/integration/responses/helpers.py +++ b/tests/integration/responses/helpers.py @@ -7,7 +7,7 @@ import time -def new_vector_store(openai_client, name): +def new_vector_store(openai_client, name, embedding_model, embedding_dimension): """Create a new vector store, cleaning up any existing one with the same name.""" # Ensure we don't reuse an existing vector store vector_stores = openai_client.vector_stores.list() @@ -16,7 +16,21 @@ def new_vector_store(openai_client, name): openai_client.vector_stores.delete(vector_store_id=vector_store.id) # Create a new vector store - vector_store = openai_client.vector_stores.create(name=name) + # OpenAI SDK client uses extra_body for non-standard parameters + from openai import OpenAI + + if isinstance(openai_client, OpenAI): + # OpenAI SDK client - use extra_body + vector_store = openai_client.vector_stores.create( + name=name, + extra_body={"embedding_model": embedding_model, "embedding_dimension": embedding_dimension}, + ) + else: + # LlamaStack client - direct parameter + vector_store = openai_client.vector_stores.create( + name=name, embedding_model=embedding_model, embedding_dimension=embedding_dimension + ) + return vector_store diff --git a/tests/integration/responses/test_extra_body_shields.py b/tests/integration/responses/test_extra_body_shields.py index 3dedb287a..eb41cc150 100644 --- a/tests/integration/responses/test_extra_body_shields.py +++ b/tests/integration/responses/test_extra_body_shields.py @@ -16,6 +16,7 @@ import pytest from llama_stack_client import APIStatusError +@pytest.mark.xfail(reason="Shields are not yet implemented inside responses") def test_shields_via_extra_body(compat_client, text_model_id): """Test that shields parameter is received by the server and raises NotImplementedError.""" diff --git a/tests/integration/responses/test_file_search.py b/tests/integration/responses/test_file_search.py index ba7775a0b..3fc0f001e 100644 --- a/tests/integration/responses/test_file_search.py +++ b/tests/integration/responses/test_file_search.py @@ -47,12 +47,14 @@ def test_response_text_format(compat_client, text_model_id, text_format): @pytest.fixture -def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_factory): - """Create a vector store with multiple files that have different attributes for filtering tests.""" +def vector_store_with_filtered_files(compat_client, embedding_model_id, embedding_dimension, tmp_path_factory): + # """Create a vector store with multiple files that have different attributes for filtering tests.""" if isinstance(compat_client, LlamaStackAsLibraryClient): - pytest.skip("Responses API file search is not yet supported in library client.") + pytest.skip("upload_file() is not yet supported in library client somehow?") - vector_store = new_vector_store(compat_client, "test_vector_store_with_filters") + vector_store = new_vector_store( + compat_client, "test_vector_store_with_filters", embedding_model_id, embedding_dimension + ) tmp_path = 
tmp_path_factory.mktemp("filter_test_files") # Create multiple files with different attributes diff --git a/tests/integration/responses/test_tool_responses.py b/tests/integration/responses/test_tool_responses.py index 5d6899fa6..2cff4d27d 100644 --- a/tests/integration/responses/test_tool_responses.py +++ b/tests/integration/responses/test_tool_responses.py @@ -46,11 +46,13 @@ def test_response_non_streaming_web_search(compat_client, text_model_id, case): @pytest.mark.parametrize("case", file_search_test_cases) -def test_response_non_streaming_file_search(compat_client, text_model_id, tmp_path, case): +def test_response_non_streaming_file_search( + compat_client, text_model_id, embedding_model_id, embedding_dimension, tmp_path, case +): if isinstance(compat_client, LlamaStackAsLibraryClient): pytest.skip("Responses API file search is not yet supported in library client.") - vector_store = new_vector_store(compat_client, "test_vector_store") + vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension) if case.file_content: file_name = "test_response_non_streaming_file_search.txt" @@ -101,11 +103,13 @@ def test_response_non_streaming_file_search(compat_client, text_model_id, tmp_pa assert case.expected.lower() in response.output_text.lower().strip() -def test_response_non_streaming_file_search_empty_vector_store(compat_client, text_model_id): +def test_response_non_streaming_file_search_empty_vector_store( + compat_client, text_model_id, embedding_model_id, embedding_dimension +): if isinstance(compat_client, LlamaStackAsLibraryClient): pytest.skip("Responses API file search is not yet supported in library client.") - vector_store = new_vector_store(compat_client, "test_vector_store") + vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension) # Create the response request, which should query our vector store response = compat_client.responses.create( @@ -127,12 +131,14 @@ def test_response_non_streaming_file_search_empty_vector_store(compat_client, te assert response.output_text -def test_response_sequential_file_search(compat_client, text_model_id, tmp_path): +def test_response_sequential_file_search( + compat_client, text_model_id, embedding_model_id, embedding_dimension, tmp_path +): """Test file search with sequential responses using previous_response_id.""" if isinstance(compat_client, LlamaStackAsLibraryClient): pytest.skip("Responses API file search is not yet supported in library client.") - vector_store = new_vector_store(compat_client, "test_vector_store") + vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension) # Create a test file with content file_content = "The Llama 4 Maverick model has 128 experts in its mixture of experts architecture." 
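For context, a minimal sketch of the vector-store creation pattern the updated tests above rely on. This assumes an OpenAI SDK client; the store name is illustrative, and the embedding values mirror the `gpt` setup defaults added in `suites.py` below:

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# The OpenAI SDK has a fixed signature for vector_stores.create, so the
# Llama Stack-specific embedding parameters go through extra_body; the
# LlamaStack client accepts them as direct keyword arguments instead.
vector_store = client.vector_stores.create(
    name="test_vector_store",  # illustrative name
    extra_body={
        "embedding_model": "openai/text-embedding-3-small",
        "embedding_dimension": 1536,
    },
)
```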
diff --git a/tests/integration/suites.py b/tests/integration/suites.py index e82e766e3..bc252bb08 100644 --- a/tests/integration/suites.py +++ b/tests/integration/suites.py @@ -39,7 +39,7 @@ class Setup(BaseModel): name: str description: str - defaults: dict[str, str] = Field(default_factory=dict) + defaults: dict[str, str | int] = Field(default_factory=dict) env: dict[str, str] = Field(default_factory=dict) @@ -88,6 +88,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = { defaults={ "text_model": "openai/gpt-4o", "embedding_model": "openai/text-embedding-3-small", + "embedding_dimension": 1536, }, ), "tgi": Setup( diff --git a/tests/unit/distribution/test_api_recordings.py b/tests/unit/distribution/test_api_recordings.py new file mode 100644 index 000000000..dbcf92757 --- /dev/null +++ b/tests/unit/distribution/test_api_recordings.py @@ -0,0 +1,273 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest +from openai import AsyncOpenAI + +# Import the real Pydantic response types instead of using Mocks +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIChatCompletion, + OpenAIChoice, + OpenAIEmbeddingData, + OpenAIEmbeddingsResponse, + OpenAIEmbeddingUsage, +) +from llama_stack.testing.api_recorder import ( + APIRecordingMode, + ResponseStorage, + api_recording, + normalize_request, +) + + +@pytest.fixture +def temp_storage_dir(): + """Create a temporary directory for test recordings.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +@pytest.fixture +def real_openai_chat_response(): + """Real OpenAI chat completion response using proper Pydantic objects.""" + return OpenAIChatCompletion( + id="chatcmpl-test123", + choices=[ + OpenAIChoice( + index=0, + message=OpenAIAssistantMessageParam( + role="assistant", content="Hello! I'm doing well, thank you for asking." 
+ ), + finish_reason="stop", + ) + ], + created=1234567890, + model="llama3.2:3b", + ) + + +@pytest.fixture +def real_embeddings_response(): + """Real OpenAI embeddings response using proper Pydantic objects.""" + return OpenAIEmbeddingsResponse( + object="list", + data=[ + OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0), + OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1), + ], + model="nomic-embed-text", + usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6), + ) + + +class TestInferenceRecording: + """Test the inference recording system.""" + + def test_request_normalization(self): + """Test that request normalization produces consistent hashes.""" + # Test basic normalization + hash1 = normalize_request( + "POST", + "http://localhost:11434/v1/chat/completions", + {}, + {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7}, + ) + + # Same request should produce same hash + hash2 = normalize_request( + "POST", + "http://localhost:11434/v1/chat/completions", + {}, + {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7}, + ) + + assert hash1 == hash2 + + # Different content should produce different hash + hash3 = normalize_request( + "POST", + "http://localhost:11434/v1/chat/completions", + {}, + { + "model": "llama3.2:3b", + "messages": [{"role": "user", "content": "Different message"}], + "temperature": 0.7, + }, + ) + + assert hash1 != hash3 + + def test_request_normalization_edge_cases(self): + """Test request normalization is precise about request content.""" + # Test that different whitespace produces different hashes (no normalization) + hash1 = normalize_request( + "POST", + "http://test/v1/chat/completions", + {}, + {"messages": [{"role": "user", "content": "Hello world\n\n"}]}, + ) + hash2 = normalize_request( + "POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]} + ) + assert hash1 != hash2 # Different whitespace should produce different hashes + + # Test that different float precision produces different hashes (no rounding) + hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001}) + hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7}) + assert hash3 != hash4 # Different precision should produce different hashes + + def test_response_storage(self, temp_storage_dir): + """Test the ResponseStorage class.""" + temp_storage_dir = temp_storage_dir / "test_response_storage" + storage = ResponseStorage(temp_storage_dir) + + # Test storing and retrieving a recording + request_hash = "test_hash_123" + request_data = { + "method": "POST", + "url": "http://localhost:11434/v1/chat/completions", + "endpoint": "/v1/chat/completions", + "model": "llama3.2:3b", + } + response_data = {"body": {"content": "test response"}, "is_streaming": False} + + storage.store_recording(request_hash, request_data, response_data) + + # Verify file storage and retrieval + retrieved = storage.find_recording(request_hash) + assert retrieved is not None + assert retrieved["request"]["model"] == "llama3.2:3b" + assert retrieved["response"]["body"]["content"] == "test response" + + async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response): + """Test that recording mode captures and stores responses.""" + + async def mock_create(*args, **kwargs): + return real_openai_chat_response + + temp_storage_dir 
= temp_storage_dir / "test_recording_mode" + with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create): + with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + response = await client.chat.completions.create( + model="llama3.2:3b", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=0.7, + max_tokens=50, + ) + + # Verify the response was returned correctly + assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking." + + # Verify recording was stored + storage = ResponseStorage(temp_storage_dir) + assert storage._get_test_dir().exists() + + async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response): + """Test that replay mode returns stored responses without making real calls.""" + + async def mock_create(*args, **kwargs): + return real_openai_chat_response + + temp_storage_dir = temp_storage_dir / "test_replay_mode" + # First, record a response + with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create): + with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + response = await client.chat.completions.create( + model="llama3.2:3b", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=0.7, + max_tokens=50, + ) + + # Now test replay mode - should not call the original method + with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch: + with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + response = await client.chat.completions.create( + model="llama3.2:3b", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=0.7, + max_tokens=50, + ) + + # Verify we got the recorded response + assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking." 
+ + # Verify the original method was NOT called + mock_create_patch.assert_not_called() + + async def test_replay_missing_recording(self, temp_storage_dir): + """Test that replay mode fails when no recording is found.""" + temp_storage_dir = temp_storage_dir / "test_replay_missing_recording" + with patch("openai.resources.chat.completions.AsyncCompletions.create"): + with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + with pytest.raises(RuntimeError, match="No recorded response found"): + await client.chat.completions.create( + model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}] + ) + + async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response): + """Test recording and replay of embeddings calls.""" + + async def mock_create(*args, **kwargs): + return real_embeddings_response + + temp_storage_dir = temp_storage_dir / "test_embeddings_recording" + # Record + with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create): + with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + response = await client.embeddings.create( + model="nomic-embed-text", input=["Hello world", "Test embedding"] + ) + + assert len(response.data) == 2 + + # Replay + with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch: + with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + response = await client.embeddings.create( + model="nomic-embed-text", input=["Hello world", "Test embedding"] + ) + + # Verify we got the recorded response + assert len(response.data) == 2 + assert response.data[0].embedding == [0.1, 0.2, 0.3] + + # Verify original method was not called + mock_create_patch.assert_not_called() + + async def test_live_mode(self, real_openai_chat_response): + """Test that live mode passes through to original methods.""" + + async def mock_create(*args, **kwargs): + return real_openai_chat_response + + with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create): + with api_recording(mode=APIRecordingMode.LIVE, storage_dir="foo"): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + response = await client.chat.completions.create( + model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}] + ) + + # Verify the response was returned + assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking." diff --git a/tests/unit/distribution/test_inference_recordings.py b/tests/unit/distribution/test_inference_recordings.py deleted file mode 100644 index cb6b92837..000000000 --- a/tests/unit/distribution/test_inference_recordings.py +++ /dev/null @@ -1,382 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import tempfile -from pathlib import Path -from unittest.mock import AsyncMock, Mock, patch - -import pytest -from openai import NOT_GIVEN, AsyncOpenAI -from openai.types.model import Model as OpenAIModel - -# Import the real Pydantic response types instead of using Mocks -from llama_stack.apis.inference import ( - OpenAIAssistantMessageParam, - OpenAIChatCompletion, - OpenAIChoice, - OpenAICompletion, - OpenAIEmbeddingData, - OpenAIEmbeddingsResponse, - OpenAIEmbeddingUsage, -) -from llama_stack.testing.inference_recorder import ( - InferenceMode, - ResponseStorage, - inference_recording, - normalize_request, -) - - -@pytest.fixture -def temp_storage_dir(): - """Create a temporary directory for test recordings.""" - with tempfile.TemporaryDirectory() as temp_dir: - yield Path(temp_dir) - - -@pytest.fixture -def real_openai_chat_response(): - """Real OpenAI chat completion response using proper Pydantic objects.""" - return OpenAIChatCompletion( - id="chatcmpl-test123", - choices=[ - OpenAIChoice( - index=0, - message=OpenAIAssistantMessageParam( - role="assistant", content="Hello! I'm doing well, thank you for asking." - ), - finish_reason="stop", - ) - ], - created=1234567890, - model="llama3.2:3b", - ) - - -@pytest.fixture -def real_embeddings_response(): - """Real OpenAI embeddings response using proper Pydantic objects.""" - return OpenAIEmbeddingsResponse( - object="list", - data=[ - OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0), - OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1), - ], - model="nomic-embed-text", - usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6), - ) - - -class TestInferenceRecording: - """Test the inference recording system.""" - - def test_request_normalization(self): - """Test that request normalization produces consistent hashes.""" - # Test basic normalization - hash1 = normalize_request( - "POST", - "http://localhost:11434/v1/chat/completions", - {}, - {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7}, - ) - - # Same request should produce same hash - hash2 = normalize_request( - "POST", - "http://localhost:11434/v1/chat/completions", - {}, - {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7}, - ) - - assert hash1 == hash2 - - # Different content should produce different hash - hash3 = normalize_request( - "POST", - "http://localhost:11434/v1/chat/completions", - {}, - { - "model": "llama3.2:3b", - "messages": [{"role": "user", "content": "Different message"}], - "temperature": 0.7, - }, - ) - - assert hash1 != hash3 - - def test_request_normalization_edge_cases(self): - """Test request normalization is precise about request content.""" - # Test that different whitespace produces different hashes (no normalization) - hash1 = normalize_request( - "POST", - "http://test/v1/chat/completions", - {}, - {"messages": [{"role": "user", "content": "Hello world\n\n"}]}, - ) - hash2 = normalize_request( - "POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]} - ) - assert hash1 != hash2 # Different whitespace should produce different hashes - - # Test that different float precision produces different hashes (no rounding) - hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001}) - hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7}) - assert hash3 != hash4 # 
Different precision should produce different hashes - - def test_response_storage(self, temp_storage_dir): - """Test the ResponseStorage class.""" - temp_storage_dir = temp_storage_dir / "test_response_storage" - storage = ResponseStorage(temp_storage_dir) - - # Test storing and retrieving a recording - request_hash = "test_hash_123" - request_data = { - "method": "POST", - "url": "http://localhost:11434/v1/chat/completions", - "endpoint": "/v1/chat/completions", - "model": "llama3.2:3b", - } - response_data = {"body": {"content": "test response"}, "is_streaming": False} - - storage.store_recording(request_hash, request_data, response_data) - - # Verify file storage and retrieval - retrieved = storage.find_recording(request_hash) - assert retrieved is not None - assert retrieved["request"]["model"] == "llama3.2:3b" - assert retrieved["response"]["body"]["content"] == "test response" - - async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response): - """Test that recording mode captures and stores responses.""" - temp_storage_dir = temp_storage_dir / "test_recording_mode" - with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)): - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response) - - response = await client.chat.completions.create( - model="llama3.2:3b", - messages=[{"role": "user", "content": "Hello, how are you?"}], - temperature=0.7, - max_tokens=50, - user=NOT_GIVEN, - ) - - # Verify the response was returned correctly - assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking." - client.chat.completions._post.assert_called_once() - - # Verify recording was stored - storage = ResponseStorage(temp_storage_dir) - dir = storage._get_test_dir() - assert dir.exists() - - async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response): - """Test that replay mode returns stored responses without making real calls.""" - temp_storage_dir = temp_storage_dir / "test_replay_mode" - # First, record a response - with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)): - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response) - - response = await client.chat.completions.create( - model="llama3.2:3b", - messages=[{"role": "user", "content": "Hello, how are you?"}], - temperature=0.7, - max_tokens=50, - user=NOT_GIVEN, - ) - client.chat.completions._post.assert_called_once() - - # Now test replay mode - should not call the original method - with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)): - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response) - - response = await client.chat.completions.create( - model="llama3.2:3b", - messages=[{"role": "user", "content": "Hello, how are you?"}], - temperature=0.7, - max_tokens=50, - ) - - # Verify we got the recorded response - assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking." 
- - # Verify the original method was NOT called - client.chat.completions._post.assert_not_called() - - async def test_replay_mode_models(self, temp_storage_dir): - """Test that replay mode returns stored responses without making real model listing calls.""" - - async def _async_iterator(models): - for model in models: - yield model - - models = [ - OpenAIModel(id="foo", created=1, object="model", owned_by="test"), - OpenAIModel(id="bar", created=2, object="model", owned_by="test"), - ] - - expected_ids = {m.id for m in models} - - temp_storage_dir = temp_storage_dir / "test_replay_mode_models" - - # baseline - mock works without recording - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - client.models._get_api_list = Mock(return_value=_async_iterator(models)) - assert {m.id async for m in client.models.list()} == expected_ids - client.models._get_api_list.assert_called_once() - - # record the call - with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir): - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - client.models._get_api_list = Mock(return_value=_async_iterator(models)) - assert {m.id async for m in client.models.list()} == expected_ids - client.models._get_api_list.assert_called_once() - - # replay the call - with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir): - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - client.models._get_api_list = Mock(return_value=_async_iterator(models)) - assert {m.id async for m in client.models.list()} == expected_ids - client.models._get_api_list.assert_not_called() - - async def test_replay_missing_recording(self, temp_storage_dir): - """Test that replay mode fails when no recording is found.""" - temp_storage_dir = temp_storage_dir / "test_replay_missing_recording" - with patch("openai.resources.chat.completions.AsyncCompletions.create"): - with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)): - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - - with pytest.raises(RuntimeError, match="No recorded response found"): - await client.chat.completions.create( - model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}] - ) - - async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response): - """Test recording and replay of embeddings calls.""" - - # baseline - mock works without recording - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - client.embeddings._post = AsyncMock(return_value=real_embeddings_response) - response = await client.embeddings.create( - model=real_embeddings_response.model, - input=["Hello world", "Test embedding"], - encoding_format=NOT_GIVEN, - ) - assert len(response.data) == 2 - assert response.data[0].embedding == [0.1, 0.2, 0.3] - client.embeddings._post.assert_called_once() - - temp_storage_dir = temp_storage_dir / "test_embeddings_recording" - # Record - with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)): - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - client.embeddings._post = AsyncMock(return_value=real_embeddings_response) - - response = await client.embeddings.create( - model=real_embeddings_response.model, - input=["Hello world", "Test embedding"], - encoding_format=NOT_GIVEN, - dimensions=NOT_GIVEN, - user=NOT_GIVEN, - ) - - assert len(response.data) == 2 - - # Replay - with 
inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)): - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - client.embeddings._post = AsyncMock(return_value=real_embeddings_response) - - response = await client.embeddings.create( - model=real_embeddings_response.model, - input=["Hello world", "Test embedding"], - ) - - # Verify we got the recorded response - assert len(response.data) == 2 - assert response.data[0].embedding == [0.1, 0.2, 0.3] - - # Verify original method was not called - client.embeddings._post.assert_not_called() - - async def test_completions_recording(self, temp_storage_dir): - real_completions_response = OpenAICompletion( - id="test_completion", - object="text_completion", - created=1234567890, - model="llama3.2:3b", - choices=[ - { - "text": "Hello! I'm doing well, thank you for asking.", - "index": 0, - "logprobs": None, - "finish_reason": "stop", - } - ], - ) - - temp_storage_dir = temp_storage_dir / "test_completions_recording" - - # baseline - mock works without recording - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - client.completions._post = AsyncMock(return_value=real_completions_response) - response = await client.completions.create( - model=real_completions_response.model, - prompt="Hello, how are you?", - temperature=0.7, - max_tokens=50, - user=NOT_GIVEN, - ) - assert response.choices[0].text == real_completions_response.choices[0].text - client.completions._post.assert_called_once() - - # Record - with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)): - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - client.completions._post = AsyncMock(return_value=real_completions_response) - - response = await client.completions.create( - model=real_completions_response.model, - prompt="Hello, how are you?", - temperature=0.7, - max_tokens=50, - user=NOT_GIVEN, - ) - - assert response.choices[0].text == real_completions_response.choices[0].text - client.completions._post.assert_called_once() - - # Replay - with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)): - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - client.completions._post = AsyncMock(return_value=real_completions_response) - response = await client.completions.create( - model=real_completions_response.model, - prompt="Hello, how are you?", - temperature=0.7, - max_tokens=50, - ) - assert response.choices[0].text == real_completions_response.choices[0].text - client.completions._post.assert_not_called() - - async def test_live_mode(self, real_openai_chat_response): - """Test that live mode passes through to original methods.""" - - async def mock_create(*args, **kwargs): - return real_openai_chat_response - - with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create): - with inference_recording(mode=InferenceMode.LIVE, storage_dir="foo"): - client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") - - response = await client.chat.completions.create( - model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}] - ) - - # Verify the response was returned - assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
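
For context, a minimal sketch (not part of the patch) of the storage primitives that the new unit tests exercise: `normalize_request` hashes a request, and `ResponseStorage` persists and looks up recordings by that hash. The directory path is a placeholder, and using the normalized hash as the storage key is an assumption drawn from how the tests above use these helpers.

```python
# Sketch only: exercising the api_recorder storage primitives directly,
# mirroring test_request_normalization and test_response_storage above.
from pathlib import Path

from llama_stack.testing.api_recorder import ResponseStorage, normalize_request

# Identical requests normalize to the same hash; any change in the body changes it.
request_hash = normalize_request(
    "POST",
    "http://localhost:11434/v1/chat/completions",
    {},
    {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}]},
)

storage = ResponseStorage(Path("./recordings"))  # placeholder directory
storage.store_recording(
    request_hash,
    {
        "method": "POST",
        "url": "http://localhost:11434/v1/chat/completions",
        "endpoint": "/v1/chat/completions",
        "model": "llama3.2:3b",
    },
    {"body": {"content": "example response"}, "is_streaming": False},
)

recording = storage.find_recording(request_hash)
assert recording is not None
assert recording["response"]["body"]["content"] == "example response"
```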