diff --git a/llama_stack/testing/__init__.py b/llama_stack/testing/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/testing/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py new file mode 100644 index 000000000..31607a312 --- /dev/null +++ b/llama_stack/testing/inference_recorder.py @@ -0,0 +1,609 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from __future__ import annotations + +import hashlib +import json +import os +import sqlite3 +import uuid +from collections.abc import Generator +from contextlib import contextmanager +from pathlib import Path +from typing import Any, cast + +# Global state for the recording system +_current_mode: str | None = None +_current_storage: ResponseStorage | None = None +_original_methods: dict[str, Any] = {} + + +def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str: + """Create a normalized hash of the request for consistent matching.""" + # Extract just the endpoint path + from urllib.parse import urlparse + + parsed = urlparse(url) + endpoint = parsed.path + + # Create normalized request dict + normalized: dict[str, Any] = { + "method": method.upper(), + "endpoint": endpoint, + } + + # Normalize body parameters + if body: + # Handle model parameter + if "model" in body: + normalized["model"] = body["model"] + + # Handle messages (normalize whitespace) + if "messages" in body: + normalized_messages = [] + for msg in body["messages"]: + normalized_msg = dict(msg) + if "content" in normalized_msg and isinstance(normalized_msg["content"], str): + # Normalize whitespace + normalized_msg["content"] = " ".join(normalized_msg["content"].split()) + normalized_messages.append(normalized_msg) + normalized["messages"] = normalized_messages + + # Handle other parameters (sort for consistency) + other_params = {} + for key, value in body.items(): + if key not in ["model", "messages"]: + if isinstance(value, float): + # Round floats to 6 decimal places + other_params[key] = round(value, 6) + else: + other_params[key] = value + + if other_params: + # Sort dictionary keys for consistent hashing + normalized["parameters"] = dict(sorted(other_params.items())) + + # Create hash + normalized_json = json.dumps(normalized, sort_keys=True) + return hashlib.sha256(normalized_json.encode()).hexdigest() + + +def get_current_test_id() -> str: + """Extract test ID from pytest context or fall back to environment/generated ID.""" + # Try to get from pytest context + try: + import _pytest.fixtures + + if hasattr(_pytest.fixtures, "_current_request") and _pytest.fixtures._current_request: + request = _pytest.fixtures._current_request + if hasattr(request, "node"): + # Use the test node ID as our test identifier + node_id: str = request.node.nodeid + # Clean up the node ID to be filesystem-safe + test_id = node_id.replace("/", "_").replace("::", "_").replace(".py", "") + return test_id + except AttributeError: + pass + + # Fall back to environment-based or generated ID + return os.environ.get("LLAMA_STACK_TEST_ID", f"test_{uuid.uuid4().hex[:8]}") + + +def 
get_inference_mode() -> str: + """Get the inference recording mode from environment variables.""" + return os.environ.get("LLAMA_STACK_INFERENCE_MODE", "live").lower() + + +def setup_inference_recording(): + """Convenience function to set up inference recording based on environment variables.""" + mode = get_inference_mode() + + if mode not in ["live", "record", "replay"]: + raise ValueError(f"Invalid LLAMA_STACK_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'") + + if mode == "live": + # Return a no-op context manager for live mode + @contextmanager + def live_mode(): + yield + + return live_mode() + + test_id = get_current_test_id() + storage_dir = os.environ.get("LLAMA_STACK_RECORDING_DIR", str(Path.home() / ".llama" / "recordings")) + + return inference_recording(mode=mode, test_id=test_id, storage_dir=storage_dir) + + +def _serialize_response(response: Any) -> Any: + """Serialize OpenAI response objects to JSON-compatible format.""" + if hasattr(response, "model_dump"): + return response.model_dump() + elif hasattr(response, "__dict__"): + return dict(response.__dict__) + else: + return response + + +def _deserialize_response(data: dict[str, Any]) -> dict[str, Any]: + """Deserialize response data back to a dict format.""" + # For simplicity, just return the dict - this preserves all the data + # The original response structure is sufficient for replaying + return data + + +class ResponseStorage: + """Handles SQLite index + JSON file storage/retrieval for inference recordings.""" + + def __init__(self, base_dir: Path, test_id: str): + self.base_dir = base_dir + self.test_id = test_id + self.test_dir = base_dir / test_id + self.responses_dir = self.test_dir / "responses" + self.db_path = self.test_dir / "index.sqlite" + + self._ensure_directories() + self._init_database() + + def _ensure_directories(self): + self.test_dir.mkdir(parents=True, exist_ok=True) + self.responses_dir.mkdir(exist_ok=True) + + def _init_database(self): + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS recordings ( + request_hash TEXT PRIMARY KEY, + response_file TEXT, + endpoint TEXT, + model TEXT, + timestamp TEXT, + is_streaming BOOLEAN + ) + """) + + def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]): + """Store a request/response pair.""" + # Generate unique response filename + response_file = f"{request_hash[:12]}.json" + response_path = self.responses_dir / response_file + + # Serialize response body if needed + serialized_response = dict(response) + if "body" in serialized_response: + if isinstance(serialized_response["body"], list): + # Handle streaming responses (list of chunks) + serialized_response["body"] = [_serialize_response(chunk) for chunk in serialized_response["body"]] + else: + # Handle single response + serialized_response["body"] = _serialize_response(serialized_response["body"]) + + # Save response to JSON file + with open(response_path, "w") as f: + json.dump({"request": request, "response": serialized_response}, f, indent=2) + + # Update SQLite index + with sqlite3.connect(self.db_path) as conn: + conn.execute( + """ + INSERT OR REPLACE INTO recordings + (request_hash, response_file, endpoint, model, timestamp, is_streaming) + VALUES (?, ?, ?, ?, datetime('now'), ?) 
+ """, + ( + request_hash, + response_file, + request.get("endpoint", ""), + request.get("model", ""), + response.get("is_streaming", False), + ), + ) + + def find_recording(self, request_hash: str) -> dict[str, Any] | None: + """Find a recorded response by request hash.""" + with sqlite3.connect(self.db_path) as conn: + result = conn.execute( + "SELECT response_file FROM recordings WHERE request_hash = ?", (request_hash,) + ).fetchone() + + if not result: + return None + + response_file = result[0] + response_path = self.responses_dir / response_file + + if not response_path.exists(): + return None + + with open(response_path) as f: + data = json.load(f) + + # Deserialize response body if needed + if "response" in data and "body" in data["response"]: + if isinstance(data["response"]["body"], list): + # Handle streaming responses + data["response"]["body"] = [_deserialize_response(chunk) for chunk in data["response"]["body"]] + else: + # Handle single response + data["response"]["body"] = _deserialize_response(data["response"]["body"]) + + return cast(dict[str, Any], data) + + +async def _patched_create_method(original_method, self, **kwargs): + """Patched version of OpenAI client create methods.""" + global _current_mode, _current_storage + + if _current_mode == "live" or _current_storage is None: + # Normal operation + return await original_method(self, **kwargs) + + # Get base URL from the client + base_url = str(self._client.base_url) + + # Determine endpoint based on the method's module/class path + method_str = str(original_method) + if "chat.completions" in method_str: + endpoint = "/v1/chat/completions" + elif "embeddings" in method_str: + endpoint = "/v1/embeddings" + elif "completions" in method_str: + endpoint = "/v1/completions" + else: + # Fallback - try to guess from the self object + if hasattr(self, "_resource") and hasattr(self._resource, "_resource"): + resource_name = getattr(self._resource._resource, "_resource", "unknown") + if "chat" in str(resource_name): + endpoint = "/v1/chat/completions" + elif "embeddings" in str(resource_name): + endpoint = "/v1/embeddings" + else: + endpoint = "/v1/completions" + else: + endpoint = "/v1/completions" + + url = base_url.rstrip("/") + endpoint + + # Normalize request for matching + method = "POST" + headers = {} + body = kwargs + + request_hash = normalize_request(method, url, headers, body) + + if _current_mode == "replay": + # Try to find recorded response + recording = _current_storage.find_recording(request_hash) + if recording: + # Return recorded response + response_body = recording["response"]["body"] + + # Handle streaming responses + if recording["response"].get("is_streaming", False): + # For streaming, we need to return an async iterator + async def replay_stream(): + for chunk in response_body: + yield chunk + + return replay_stream() + else: + return response_body + else: + raise RuntimeError( + f"No recorded response found for request hash: {request_hash}\n" + f"Endpoint: {endpoint}\n" + f"Model: {body.get('model', 'unknown')}\n" + f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record" + ) + + elif _current_mode == "record": + # Make real request and record it + response = await original_method(self, **kwargs) + + # Store the recording + request_data = { + "method": method, + "url": url, + "headers": headers, + "body": body, + "endpoint": endpoint, + "model": body.get("model", ""), + } + + # Determine if this is a streaming request based on request parameters + is_streaming = body.get("stream", False) 
+ + if is_streaming: + # For streaming responses, we need to collect all chunks immediately before yielding + # This ensures the recording is saved even if the generator isn't fully consumed + chunks = [] + async for chunk in response: + chunks.append(chunk) + + # Store the recording immediately + response_data = {"body": chunks, "is_streaming": True} + _current_storage.store_recording(request_hash, request_data, response_data) + + # Return a generator that replays the stored chunks + async def replay_recorded_stream(): + for chunk in chunks: + yield chunk + + return replay_recorded_stream() + else: + response_data = {"body": response, "is_streaming": False} + _current_storage.store_recording(request_hash, request_data, response_data) + return response + + else: + return await original_method(self, **kwargs) + + +async def _patched_ollama_method(original_method, self, method_name, **kwargs): + """Patched version of Ollama AsyncClient methods.""" + global _current_mode, _current_storage + + if _current_mode == "live" or _current_storage is None: + # Normal operation + return await original_method(self, **kwargs) + + # Get base URL from the client (Ollama client uses host attribute) + base_url = getattr(self, "host", "http://localhost:11434") + if not base_url.startswith("http"): + base_url = f"http://{base_url}" + + # Determine endpoint based on method name + if method_name == "generate": + endpoint = "/api/generate" + elif method_name == "chat": + endpoint = "/api/chat" + elif method_name == "embed": + endpoint = "/api/embeddings" + else: + endpoint = f"/api/{method_name}" + + url = base_url.rstrip("/") + endpoint + + # Normalize request for matching + method = "POST" + headers = {} + body = kwargs + + request_hash = normalize_request(method, url, headers, body) + + if _current_mode == "replay": + # Try to find recorded response + recording = _current_storage.find_recording(request_hash) + if recording: + # Return recorded response + response_body = recording["response"]["body"] + + # Handle streaming responses for Ollama + if recording["response"].get("is_streaming", False): + # For streaming, we need to return an async iterator + async def replay_ollama_stream(): + for chunk in response_body: + yield chunk + + return replay_ollama_stream() + else: + return response_body + else: + raise RuntimeError( + f"No recorded response found for request hash: {request_hash}\n" + f"Endpoint: {endpoint}\n" + f"Model: {body.get('model', 'unknown')}\n" + f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record" + ) + + elif _current_mode == "record": + # Make real request and record it + response = await original_method(self, **kwargs) + + # Store the recording + request_data = { + "method": method, + "url": url, + "headers": headers, + "body": body, + "endpoint": endpoint, + "model": body.get("model", ""), + } + + # Determine if this is a streaming request based on request parameters + is_streaming = body.get("stream", False) + + if is_streaming: + # For streaming responses, we need to collect all chunks immediately before yielding + # This ensures the recording is saved even if the generator isn't fully consumed + chunks = [] + async for chunk in response: + chunks.append(chunk) + + # Store the recording immediately + response_data = {"body": chunks, "is_streaming": True} + _current_storage.store_recording(request_hash, request_data, response_data) + + # Return a generator that replays the stored chunks + async def replay_recorded_stream(): + for chunk in chunks: + yield chunk + + return 
replay_recorded_stream() + else: + response_data = {"body": response, "is_streaming": False} + _current_storage.store_recording(request_hash, request_data, response_data) + return response + + else: + return await original_method(self, **kwargs) + + +def patch_inference_clients(): + """Install monkey patches for OpenAI client methods and Ollama AsyncClient methods.""" + global _original_methods + + # Import here to avoid circular imports + from openai import AsyncOpenAI + from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions + from openai.resources.completions import AsyncCompletions + from openai.resources.embeddings import AsyncEmbeddings + + # Also import Ollama AsyncClient + try: + from ollama import AsyncClient as OllamaAsyncClient + except ImportError: + ollama_async_client = None + else: + ollama_async_client = OllamaAsyncClient + + # Store original methods for both OpenAI and Ollama clients + _original_methods = { + "chat_completions_create": AsyncChatCompletions.create, + "completions_create": AsyncCompletions.create, + "embeddings_create": AsyncEmbeddings.create, + } + + # Add Ollama client methods if available + if ollama_async_client: + _original_methods.update( + { + "ollama_generate": ollama_async_client.generate, + "ollama_chat": ollama_async_client.chat, + "ollama_embed": ollama_async_client.embed, + } + ) + + # Create patched methods for OpenAI client + async def patched_chat_completions_create(self, **kwargs): + return await _patched_create_method(_original_methods["chat_completions_create"], self, **kwargs) + + async def patched_completions_create(self, **kwargs): + return await _patched_create_method(_original_methods["completions_create"], self, **kwargs) + + async def patched_embeddings_create(self, **kwargs): + return await _patched_create_method(_original_methods["embeddings_create"], self, **kwargs) + + # Apply OpenAI patches + AsyncChatCompletions.create = patched_chat_completions_create + AsyncCompletions.create = patched_completions_create + AsyncEmbeddings.create = patched_embeddings_create + + # Create patched methods for Ollama client + if ollama_async_client: + + async def patched_ollama_generate(self, **kwargs): + return await _patched_ollama_method(_original_methods["ollama_generate"], self, "generate", **kwargs) + + async def patched_ollama_chat(self, **kwargs): + return await _patched_ollama_method(_original_methods["ollama_chat"], self, "chat", **kwargs) + + async def patched_ollama_embed(self, **kwargs): + return await _patched_ollama_method(_original_methods["ollama_embed"], self, "embed", **kwargs) + + # Apply Ollama patches + ollama_async_client.generate = patched_ollama_generate + ollama_async_client.chat = patched_ollama_chat + ollama_async_client.embed = patched_ollama_embed + + # Also try to patch the AsyncOpenAI __init__ to trace client creation + original_openai_init = AsyncOpenAI.__init__ + + def patched_openai_init(self, *args, **kwargs): + result = original_openai_init(self, *args, **kwargs) + + # After client is created, try to re-patch its methods + if hasattr(self, "chat") and hasattr(self.chat, "completions"): + original_chat_create = self.chat.completions.create + + async def instance_patched_chat_create(**kwargs): + return await _patched_create_method(original_chat_create, self.chat.completions, **kwargs) + + self.chat.completions.create = instance_patched_chat_create + + return result + + AsyncOpenAI.__init__ = patched_openai_init + + +def unpatch_inference_clients(): + """Remove monkey patches and 
restore original OpenAI and Ollama client methods.""" + global _original_methods + + if not _original_methods: + return + + # Import here to avoid circular imports + from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions + from openai.resources.completions import AsyncCompletions + from openai.resources.embeddings import AsyncEmbeddings + + # Restore OpenAI client methods + if "chat_completions_create" in _original_methods: + AsyncChatCompletions.create = _original_methods["chat_completions_create"] + + if "completions_create" in _original_methods: + AsyncCompletions.create = _original_methods["completions_create"] + + if "embeddings_create" in _original_methods: + AsyncEmbeddings.create = _original_methods["embeddings_create"] + + # Restore Ollama client methods if they were patched + try: + from ollama import AsyncClient as OllamaAsyncClient + + if "ollama_generate" in _original_methods: + OllamaAsyncClient.generate = _original_methods["ollama_generate"] + + if "ollama_chat" in _original_methods: + OllamaAsyncClient.chat = _original_methods["ollama_chat"] + + if "ollama_embed" in _original_methods: + OllamaAsyncClient.embed = _original_methods["ollama_embed"] + + except ImportError: + pass + + _original_methods.clear() + + +@contextmanager +def inference_recording( + mode: str = "live", test_id: str | None = None, storage_dir: str | Path | None = None +) -> Generator[None, None, None]: + """Context manager for inference recording/replaying.""" + global _current_mode, _current_storage + + # Set defaults + if storage_dir is None: + storage_dir_path = Path.home() / ".llama" / "recordings" + else: + storage_dir_path = Path(storage_dir) + + if test_id is None: + test_id = f"test_{uuid.uuid4().hex[:8]}" + + # Store previous state + prev_mode = _current_mode + prev_storage = _current_storage + + try: + _current_mode = mode + + if mode in ["record", "replay"]: + _current_storage = ResponseStorage(storage_dir_path, test_id) + patch_inference_clients() + + yield + + finally: + # Restore previous state + if mode in ["record", "replay"]: + unpatch_inference_clients() + + _current_mode = prev_mode + _current_storage = prev_storage diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index f6b5b3026..4cffa9333 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -185,70 +185,95 @@ def llama_stack_client(request, provider_data): if not config: raise ValueError("You must specify either --stack-config or LLAMA_STACK_CONFIG") - # Handle server: format or server:: - if config.startswith("server:"): - parts = config.split(":") - config_name = parts[1] - port = int(parts[2]) if len(parts) > 2 else int(os.environ.get("LLAMA_STACK_PORT", DEFAULT_PORT)) - base_url = f"http://localhost:{port}" + # Set up inference recording if enabled + inference_mode = os.environ.get("LLAMA_STACK_INFERENCE_MODE", "live").lower() + recording_context = None - # Check if port is available - if is_port_available(port): - print(f"Starting llama stack server with config '{config_name}' on port {port}...") + if inference_mode in ["record", "replay"]: + from llama_stack.testing.inference_recorder import setup_inference_recording - # Start server - server_process = start_llama_stack_server(config_name) + recording_context = setup_inference_recording() + recording_context.__enter__() + print(f"Inference recording enabled: mode={inference_mode}") - # Wait for server to be ready - if not wait_for_server_ready(base_url, 
timeout=120, process=server_process): - print("Server failed to start within timeout") - server_process.terminate() - raise RuntimeError( - f"Server failed to start within timeout. Check that config '{config_name}' exists and is valid. " - f"See server.log for details." - ) - - print(f"Server is ready at {base_url}") - - # Store process for potential cleanup (pytest will handle termination at session end) - request.session._llama_stack_server_process = server_process - else: - print(f"Port {port} is already in use, assuming server is already running...") - - return LlamaStackClient( - base_url=base_url, - provider_data=provider_data, - timeout=int(os.environ.get("LLAMA_STACK_CLIENT_TIMEOUT", "30")), - ) - - # check if this looks like a URL using proper URL parsing try: - parsed_url = urlparse(config) - if parsed_url.scheme and parsed_url.netloc: - return LlamaStackClient( - base_url=config, + # Handle server: format or server:: + if config.startswith("server:"): + parts = config.split(":") + config_name = parts[1] + port = int(parts[2]) if len(parts) > 2 else int(os.environ.get("LLAMA_STACK_PORT", DEFAULT_PORT)) + base_url = f"http://localhost:{port}" + + # Check if port is available + if is_port_available(port): + print(f"Starting llama stack server with config '{config_name}' on port {port}...") + + # Start server + server_process = start_llama_stack_server(config_name) + + # Wait for server to be ready + if not wait_for_server_ready(base_url, timeout=120, process=server_process): + print("Server failed to start within timeout") + server_process.terminate() + raise RuntimeError( + f"Server failed to start within timeout. Check that config '{config_name}' exists and is valid. " + f"See server.log for details." + ) + + print(f"Server is ready at {base_url}") + + # Store process for potential cleanup (pytest will handle termination at session end) + request.session._llama_stack_server_process = server_process + else: + print(f"Port {port} is already in use, assuming server is already running...") + + client = LlamaStackClient( + base_url=base_url, provider_data=provider_data, + timeout=int(os.environ.get("LLAMA_STACK_CLIENT_TIMEOUT", "30")), ) + else: + # check if this looks like a URL using proper URL parsing + try: + parsed_url = urlparse(config) + if parsed_url.scheme and parsed_url.netloc: + client = LlamaStackClient( + base_url=config, + provider_data=provider_data, + ) + else: + raise ValueError("Not a URL") + except Exception as e: + # If URL parsing fails, treat as library config + if "=" in config: + run_config = run_config_from_adhoc_config_spec(config) + run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") + with open(run_config_file.name, "w") as f: + yaml.dump(run_config.model_dump(), f) + config = run_config_file.name + + client = LlamaStackAsLibraryClient( + config, + provider_data=provider_data, + skip_logger_removal=True, + ) + if not client.initialize(): + raise RuntimeError("Initialization failed") from e + + # Store recording context for cleanup + if recording_context: + request.session._inference_recording_context = recording_context + + return client + except Exception: - # If URL parsing fails, treat as non-URL config - pass - - if "=" in config: - run_config = run_config_from_adhoc_config_spec(config) - run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") - with open(run_config_file.name, "w") as f: - yaml.dump(run_config.model_dump(), f) - config = run_config_file.name - - client = LlamaStackAsLibraryClient( - config, - 
provider_data=provider_data, - skip_logger_removal=True, - ) - if not client.initialize(): - raise RuntimeError("Initialization failed") - - return client + # Clean up recording context on error + if recording_context: + try: + recording_context.__exit__(None, None, None) + except Exception as cleanup_error: + print(f"Warning: Error cleaning up recording context: {cleanup_error}") + raise @pytest.fixture(scope="session") @@ -264,9 +289,20 @@ def compat_client(request): @pytest.fixture(scope="session", autouse=True) def cleanup_server_process(request): - """Cleanup server process at the end of the test session.""" + """Cleanup server process and inference recording at the end of the test session.""" yield # Run tests + # Clean up inference recording context + if hasattr(request.session, "_inference_recording_context"): + recording_context = request.session._inference_recording_context + if recording_context: + try: + print("Cleaning up inference recording context...") + recording_context.__exit__(None, None, None) + except Exception as e: + print(f"Error during inference recording cleanup: {e}") + + # Clean up server process if hasattr(request.session, "_llama_stack_server_process"): server_process = request.session._llama_stack_server_process if server_process: diff --git a/tests/integration/test_inference_recordings.py b/tests/integration/test_inference_recordings.py new file mode 100644 index 000000000..a27059959 --- /dev/null +++ b/tests/integration/test_inference_recordings.py @@ -0,0 +1,289 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import sqlite3 +import tempfile +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest +from openai import AsyncOpenAI + +from llama_stack.testing.inference_recorder import ( + ResponseStorage, + inference_recording, + normalize_request, +) + + +@pytest.fixture +def temp_storage_dir(): + """Create a temporary directory for test recordings.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +@pytest.fixture +def mock_openai_response(): + """Mock OpenAI response object.""" + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = "Hello! I'm doing well, thank you for asking." + mock_response.model_dump.return_value = { + "id": "chatcmpl-test123", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Hello! 
I'm doing well, thank you for asking."}, + "finish_reason": "stop", + } + ], + "model": "llama3.2:3b", + "usage": {"prompt_tokens": 10, "completion_tokens": 15, "total_tokens": 25}, + } + + return mock_response + + +@pytest.fixture +def mock_embeddings_response(): + """Mock OpenAI embeddings response object.""" + mock_response = Mock() + mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3]), Mock(embedding=[0.4, 0.5, 0.6])] + mock_response.model_dump.return_value = { + "object": "list", + "data": [ + {"object": "embedding", "embedding": [0.1, 0.2, 0.3], "index": 0}, + {"object": "embedding", "embedding": [0.4, 0.5, 0.6], "index": 1}, + ], + "model": "nomic-embed-text", + "usage": {"prompt_tokens": 6, "total_tokens": 6}, + } + + return mock_response + + +class TestInferenceRecording: + """Test the inference recording system.""" + + def test_request_normalization(self): + """Test that request normalization produces consistent hashes.""" + # Test basic normalization + hash1 = normalize_request( + "POST", + "http://localhost:11434/v1/chat/completions", + {}, + {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7}, + ) + + # Same request should produce same hash + hash2 = normalize_request( + "POST", + "http://localhost:11434/v1/chat/completions", + {}, + {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7}, + ) + + assert hash1 == hash2 + + # Different content should produce different hash + hash3 = normalize_request( + "POST", + "http://localhost:11434/v1/chat/completions", + {}, + { + "model": "llama3.2:3b", + "messages": [{"role": "user", "content": "Different message"}], + "temperature": 0.7, + }, + ) + + assert hash1 != hash3 + + def test_request_normalization_edge_cases(self): + """Test request normalization handles edge cases correctly.""" + # Test whitespace normalization + hash1 = normalize_request( + "POST", + "http://test/v1/chat/completions", + {}, + {"messages": [{"role": "user", "content": "Hello world\n\n"}]}, + ) + hash2 = normalize_request( + "POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]} + ) + assert hash1 == hash2 + + # Test float precision normalization + hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001}) + hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7}) + assert hash3 == hash4 + + def test_response_storage(self, temp_storage_dir): + """Test the ResponseStorage class.""" + storage = ResponseStorage(temp_storage_dir, "test_storage") + + # Test directory creation + assert storage.test_dir.exists() + assert storage.responses_dir.exists() + assert storage.db_path.exists() + + # Test storing and retrieving a recording + request_hash = "test_hash_123" + request_data = { + "method": "POST", + "url": "http://localhost:11434/v1/chat/completions", + "endpoint": "/v1/chat/completions", + "model": "llama3.2:3b", + } + response_data = {"body": {"content": "test response"}, "is_streaming": False} + + storage.store_recording(request_hash, request_data, response_data) + + # Verify SQLite record + with sqlite3.connect(storage.db_path) as conn: + result = conn.execute("SELECT * FROM recordings WHERE request_hash = ?", (request_hash,)).fetchone() + + assert result is not None + assert result[0] == request_hash # request_hash + assert result[2] == "/v1/chat/completions" # endpoint + assert result[3] == "llama3.2:3b" # model + + # Verify file 
storage and retrieval + retrieved = storage.find_recording(request_hash) + assert retrieved is not None + assert retrieved["request"]["model"] == "llama3.2:3b" + assert retrieved["response"]["body"]["content"] == "test response" + + async def test_recording_mode(self, temp_storage_dir, mock_openai_response): + """Test that recording mode captures and stores responses.""" + test_id = "test_recording_mode" + + async def mock_create(*args, **kwargs): + return mock_openai_response + + with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create): + with inference_recording(mode="record", test_id=test_id, storage_dir=str(temp_storage_dir)): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + response = await client.chat.completions.create( + model="llama3.2:3b", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=0.7, + max_tokens=50, + ) + + # Verify the response was returned correctly + assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking." + + # Verify recording was stored + storage = ResponseStorage(temp_storage_dir, test_id) + with sqlite3.connect(storage.db_path) as conn: + recordings = conn.execute("SELECT COUNT(*) FROM recordings").fetchone()[0] + + assert recordings == 1 + + async def test_replay_mode(self, temp_storage_dir, mock_openai_response): + """Test that replay mode returns stored responses without making real calls.""" + test_id = "test_replay_mode" + + async def mock_create(*args, **kwargs): + return mock_openai_response + + # First, record a response + with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create): + with inference_recording(mode="record", test_id=test_id, storage_dir=str(temp_storage_dir)): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + await client.chat.completions.create( + model="llama3.2:3b", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=0.7, + max_tokens=50, + ) + + # Now test replay mode - should not call the original method + with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch: + with inference_recording(mode="replay", test_id=test_id, storage_dir=str(temp_storage_dir)): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + response = await client.chat.completions.create( + model="llama3.2:3b", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=0.7, + max_tokens=50, + ) + + # Verify we got the recorded response + assert response["choices"][0]["message"]["content"] == "Hello! I'm doing well, thank you for asking." 
+ + # Verify the original method was NOT called + mock_create_patch.assert_not_called() + + async def test_replay_missing_recording(self, temp_storage_dir): + """Test that replay mode fails when no recording is found.""" + test_id = "test_missing_recording" + + with patch("openai.resources.chat.completions.AsyncCompletions.create"): + with inference_recording(mode="replay", test_id=test_id, storage_dir=str(temp_storage_dir)): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + with pytest.raises(RuntimeError, match="No recorded response found"): + await client.chat.completions.create( + model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}] + ) + + async def test_embeddings_recording(self, temp_storage_dir, mock_embeddings_response): + """Test recording and replay of embeddings calls.""" + test_id = "test_embeddings" + + async def mock_create(*args, **kwargs): + return mock_embeddings_response + + # Record + with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create): + with inference_recording(mode="record", test_id=test_id, storage_dir=str(temp_storage_dir)): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + response = await client.embeddings.create( + model="nomic-embed-text", input=["Hello world", "Test embedding"] + ) + + assert len(response.data) == 2 + + # Replay + with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch: + with inference_recording(mode="replay", test_id=test_id, storage_dir=str(temp_storage_dir)): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + response = await client.embeddings.create( + model="nomic-embed-text", input=["Hello world", "Test embedding"] + ) + + # Verify we got the recorded response + assert len(response["data"]) == 2 + assert response["data"][0]["embedding"] == [0.1, 0.2, 0.3] + + # Verify original method was not called + mock_create_patch.assert_not_called() + + async def test_live_mode(self, mock_openai_response): + """Test that live mode passes through to original methods.""" + + async def mock_create(*args, **kwargs): + return mock_openai_response + + with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create): + with inference_recording(mode="live"): + client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test") + + response = await client.chat.completions.create( + model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}] + ) + + # Verify the response was returned + assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
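
Usage note (not part of the patch): below is a minimal sketch of how the recorder added in llama_stack/testing/inference_recorder.py can be driven directly, assuming an OpenAI-compatible endpoint is reachable at http://localhost:11434/v1 (for example a local Ollama instance) and that a model named llama3.2:3b is available; both are assumptions for illustration, not requirements of the patch. Run it once with mode="record" to populate ~/.llama/recordings/<test_id>/ (SQLite index plus JSON response files), then again with mode="replay" to serve the stored response without touching the network.

import asyncio

from openai import AsyncOpenAI

from llama_stack.testing.inference_recorder import inference_recording


async def main(mode: str) -> None:
    # "record" captures the live response under ~/.llama/recordings/demo_recording/,
    # "replay" serves it back from the stored files, and "live" is a pass-through.
    with inference_recording(mode=mode, test_id="demo_recording"):
        # Assumed local OpenAI-compatible endpoint; the API key is not checked by Ollama.
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="not-needed")
        response = await client.chat.completions.create(
            model="llama3.2:3b",  # assumed to be available locally
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            temperature=0.7,
        )
        # In record/live mode this is the client's response object; in replay mode the
        # recorder returns the stored JSON dict, so access fields accordingly.
        print(response)


if __name__ == "__main__":
    asyncio.run(main("record"))  # change to "replay" on the second run

The pytest integration does the equivalent automatically: when LLAMA_STACK_INFERENCE_MODE is set to record or replay, the llama_stack_client fixture in tests/integration/fixtures/common.py calls setup_inference_recording(), which derives the test_id from the pytest node id (or LLAMA_STACK_TEST_ID) and the storage location from LLAMA_STACK_RECORDING_DIR, defaulting to ~/.llama/recordings.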