mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 20:40:00 +00:00
feat(tests): introduce inference record/replay to increase test reliability
Implements a recording and replay system for inference API calls that removes the dependency on live inference providers during testing. The system treats inference as deterministic: real API responses are recorded once and replayed in subsequent test runs. It covers the OpenAI async client (which should handle most inference requests) as well as the Ollama AsyncClient. Storage is hybrid: a SQLite index for fast lookups plus JSON files for easy greppability and debugging. As expected, tests become much faster (more than 3x for the inference suite alone).

Record:

```bash
LLAMA_STACK_INFERENCE_MODE=record uv run pytest -s -v tests/integration/inference \
  --stack-config=starter \
  -k "not( builtin_tool or safety_with_image or code_interpreter or test_rag )" \
  --text-model="ollama/llama3.2:3b-instruct-fp16" \
  --embedding-model=sentence-transformers/all-MiniLM-L6-v2
```

Replay:

```bash
LLAMA_STACK_INFERENCE_MODE=replay LLAMA_STACK_TEST_ID=<test_id> uv run pytest -s -v tests/integration/inference \
  --stack-config=starter \
  -k "not( builtin_tool or safety_with_image or code_interpreter or test_rag )" \
  --text-model="ollama/llama3.2:3b-instruct-fp16" \
  --embedding-model=sentence-transformers/all-MiniLM-L6-v2
```

Environment variables:

- `LLAMA_STACK_INFERENCE_MODE`: `live` (default), `record`, or `replay`
- `LLAMA_STACK_TEST_ID`: recording identifier, used as the recording subdirectory name (auto-generated if not set)
- `LLAMA_STACK_RECORDING_DIR`: storage location (default: `~/.llama/recordings/`)
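For reference, the `inference_recording` context manager added in `llama_stack/testing/inference_recorder.py` can also be driven directly, without the environment variables. A minimal sketch, mirroring the usage in the new integration test; the endpoint URL, API key, model name, prompt, and storage path below are illustrative only:

```python
# Minimal sketch of using the inference recorder directly (values are placeholders).
import asyncio

from openai import AsyncOpenAI

from llama_stack.testing.inference_recorder import inference_recording


async def main() -> None:
    # First pass: hit the real provider and persist the response under a stable test_id.
    with inference_recording(mode="record", test_id="demo", storage_dir="/tmp/recordings"):
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        await client.chat.completions.create(
            model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
        )

    # Second pass: served from the SQLite index + JSON files, no provider needed.
    with inference_recording(mode="replay", test_id="demo", storage_dir="/tmp/recordings"):
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        replayed = await client.chat.completions.create(
            model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
        )
        # In replay mode the patched client returns the stored dict body.
        print(replayed["choices"][0]["message"]["content"])


if __name__ == "__main__":
    asyncio.run(main())
```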
This commit is contained in:
parent
fee365b71e
commit
e59c13f2b8
4 changed files with 997 additions and 58 deletions
5
llama_stack/testing/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
609
llama_stack/testing/inference_recorder.py
Normal file
@@ -0,0 +1,609 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from __future__ import annotations

import hashlib
import json
import os
import sqlite3
import uuid
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Any, cast

# Global state for the recording system
_current_mode: str | None = None
_current_storage: ResponseStorage | None = None
_original_methods: dict[str, Any] = {}


def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
    """Create a normalized hash of the request for consistent matching."""
    # Extract just the endpoint path
    from urllib.parse import urlparse

    parsed = urlparse(url)
    endpoint = parsed.path

    # Create normalized request dict
    normalized: dict[str, Any] = {
        "method": method.upper(),
        "endpoint": endpoint,
    }

    # Normalize body parameters
    if body:
        # Handle model parameter
        if "model" in body:
            normalized["model"] = body["model"]

        # Handle messages (normalize whitespace)
        if "messages" in body:
            normalized_messages = []
            for msg in body["messages"]:
                normalized_msg = dict(msg)
                if "content" in normalized_msg and isinstance(normalized_msg["content"], str):
                    # Normalize whitespace
                    normalized_msg["content"] = " ".join(normalized_msg["content"].split())
                normalized_messages.append(normalized_msg)
            normalized["messages"] = normalized_messages

        # Handle other parameters (sort for consistency)
        other_params = {}
        for key, value in body.items():
            if key not in ["model", "messages"]:
                if isinstance(value, float):
                    # Round floats to 6 decimal places
                    other_params[key] = round(value, 6)
                else:
                    other_params[key] = value

        if other_params:
            # Sort dictionary keys for consistent hashing
            normalized["parameters"] = dict(sorted(other_params.items()))

    # Create hash
    normalized_json = json.dumps(normalized, sort_keys=True)
    return hashlib.sha256(normalized_json.encode()).hexdigest()

def get_current_test_id() -> str:
    """Extract test ID from pytest context or fall back to environment/generated ID."""
    # Try to get from pytest context
    try:
        import _pytest.fixtures

        if hasattr(_pytest.fixtures, "_current_request") and _pytest.fixtures._current_request:
            request = _pytest.fixtures._current_request
            if hasattr(request, "node"):
                # Use the test node ID as our test identifier
                node_id: str = request.node.nodeid
                # Clean up the node ID to be filesystem-safe
                test_id = node_id.replace("/", "_").replace("::", "_").replace(".py", "")
                return test_id
    except AttributeError:
        pass

    # Fall back to environment-based or generated ID
    return os.environ.get("LLAMA_STACK_TEST_ID", f"test_{uuid.uuid4().hex[:8]}")


def get_inference_mode() -> str:
    """Get the inference recording mode from environment variables."""
    return os.environ.get("LLAMA_STACK_INFERENCE_MODE", "live").lower()


def setup_inference_recording():
    """Convenience function to set up inference recording based on environment variables."""
    mode = get_inference_mode()

    if mode not in ["live", "record", "replay"]:
        raise ValueError(f"Invalid LLAMA_STACK_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")

    if mode == "live":
        # Return a no-op context manager for live mode
        @contextmanager
        def live_mode():
            yield

        return live_mode()

    test_id = get_current_test_id()
    storage_dir = os.environ.get("LLAMA_STACK_RECORDING_DIR", str(Path.home() / ".llama" / "recordings"))

    return inference_recording(mode=mode, test_id=test_id, storage_dir=storage_dir)


def _serialize_response(response: Any) -> Any:
    """Serialize OpenAI response objects to JSON-compatible format."""
    if hasattr(response, "model_dump"):
        return response.model_dump()
    elif hasattr(response, "__dict__"):
        return dict(response.__dict__)
    else:
        return response


def _deserialize_response(data: dict[str, Any]) -> dict[str, Any]:
    """Deserialize response data back to a dict format."""
    # For simplicity, just return the dict - this preserves all the data
    # The original response structure is sufficient for replaying
    return data

class ResponseStorage:
    """Handles SQLite index + JSON file storage/retrieval for inference recordings."""

    def __init__(self, base_dir: Path, test_id: str):
        self.base_dir = base_dir
        self.test_id = test_id
        self.test_dir = base_dir / test_id
        self.responses_dir = self.test_dir / "responses"
        self.db_path = self.test_dir / "index.sqlite"

        self._ensure_directories()
        self._init_database()

    def _ensure_directories(self):
        self.test_dir.mkdir(parents=True, exist_ok=True)
        self.responses_dir.mkdir(exist_ok=True)

    def _init_database(self):
        with sqlite3.connect(self.db_path) as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS recordings (
                    request_hash TEXT PRIMARY KEY,
                    response_file TEXT,
                    endpoint TEXT,
                    model TEXT,
                    timestamp TEXT,
                    is_streaming BOOLEAN
                )
            """)

    def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
        """Store a request/response pair."""
        # Generate unique response filename
        response_file = f"{request_hash[:12]}.json"
        response_path = self.responses_dir / response_file

        # Serialize response body if needed
        serialized_response = dict(response)
        if "body" in serialized_response:
            if isinstance(serialized_response["body"], list):
                # Handle streaming responses (list of chunks)
                serialized_response["body"] = [_serialize_response(chunk) for chunk in serialized_response["body"]]
            else:
                # Handle single response
                serialized_response["body"] = _serialize_response(serialized_response["body"])

        # Save response to JSON file
        with open(response_path, "w") as f:
            json.dump({"request": request, "response": serialized_response}, f, indent=2)

        # Update SQLite index
        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO recordings
                (request_hash, response_file, endpoint, model, timestamp, is_streaming)
                VALUES (?, ?, ?, ?, datetime('now'), ?)
                """,
                (
                    request_hash,
                    response_file,
                    request.get("endpoint", ""),
                    request.get("model", ""),
                    response.get("is_streaming", False),
                ),
            )

    def find_recording(self, request_hash: str) -> dict[str, Any] | None:
        """Find a recorded response by request hash."""
        with sqlite3.connect(self.db_path) as conn:
            result = conn.execute(
                "SELECT response_file FROM recordings WHERE request_hash = ?", (request_hash,)
            ).fetchone()

        if not result:
            return None

        response_file = result[0]
        response_path = self.responses_dir / response_file

        if not response_path.exists():
            return None

        with open(response_path) as f:
            data = json.load(f)

        # Deserialize response body if needed
        if "response" in data and "body" in data["response"]:
            if isinstance(data["response"]["body"], list):
                # Handle streaming responses
                data["response"]["body"] = [_deserialize_response(chunk) for chunk in data["response"]["body"]]
            else:
                # Handle single response
                data["response"]["body"] = _deserialize_response(data["response"]["body"])

        return cast(dict[str, Any], data)

async def _patched_create_method(original_method, self, **kwargs):
    """Patched version of OpenAI client create methods."""
    global _current_mode, _current_storage

    if _current_mode == "live" or _current_storage is None:
        # Normal operation
        return await original_method(self, **kwargs)

    # Get base URL from the client
    base_url = str(self._client.base_url)

    # Determine endpoint based on the method's module/class path
    method_str = str(original_method)
    if "chat.completions" in method_str:
        endpoint = "/v1/chat/completions"
    elif "embeddings" in method_str:
        endpoint = "/v1/embeddings"
    elif "completions" in method_str:
        endpoint = "/v1/completions"
    else:
        # Fallback - try to guess from the self object
        if hasattr(self, "_resource") and hasattr(self._resource, "_resource"):
            resource_name = getattr(self._resource._resource, "_resource", "unknown")
            if "chat" in str(resource_name):
                endpoint = "/v1/chat/completions"
            elif "embeddings" in str(resource_name):
                endpoint = "/v1/embeddings"
            else:
                endpoint = "/v1/completions"
        else:
            endpoint = "/v1/completions"

    url = base_url.rstrip("/") + endpoint

    # Normalize request for matching
    method = "POST"
    headers = {}
    body = kwargs

    request_hash = normalize_request(method, url, headers, body)

    if _current_mode == "replay":
        # Try to find recorded response
        recording = _current_storage.find_recording(request_hash)
        if recording:
            # Return recorded response
            response_body = recording["response"]["body"]

            # Handle streaming responses
            if recording["response"].get("is_streaming", False):
                # For streaming, we need to return an async iterator
                async def replay_stream():
                    for chunk in response_body:
                        yield chunk

                return replay_stream()
            else:
                return response_body
        else:
            raise RuntimeError(
                f"No recorded response found for request hash: {request_hash}\n"
                f"Endpoint: {endpoint}\n"
                f"Model: {body.get('model', 'unknown')}\n"
                f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
            )

    elif _current_mode == "record":
        # Make real request and record it
        response = await original_method(self, **kwargs)

        # Store the recording
        request_data = {
            "method": method,
            "url": url,
            "headers": headers,
            "body": body,
            "endpoint": endpoint,
            "model": body.get("model", ""),
        }

        # Determine if this is a streaming request based on request parameters
        is_streaming = body.get("stream", False)

        if is_streaming:
            # For streaming responses, we need to collect all chunks immediately before yielding
            # This ensures the recording is saved even if the generator isn't fully consumed
            chunks = []
            async for chunk in response:
                chunks.append(chunk)

            # Store the recording immediately
            response_data = {"body": chunks, "is_streaming": True}
            _current_storage.store_recording(request_hash, request_data, response_data)

            # Return a generator that replays the stored chunks
            async def replay_recorded_stream():
                for chunk in chunks:
                    yield chunk

            return replay_recorded_stream()
        else:
            response_data = {"body": response, "is_streaming": False}
            _current_storage.store_recording(request_hash, request_data, response_data)
            return response

    else:
        return await original_method(self, **kwargs)

async def _patched_ollama_method(original_method, self, method_name, **kwargs):
    """Patched version of Ollama AsyncClient methods."""
    global _current_mode, _current_storage

    if _current_mode == "live" or _current_storage is None:
        # Normal operation
        return await original_method(self, **kwargs)

    # Get base URL from the client (Ollama client uses host attribute)
    base_url = getattr(self, "host", "http://localhost:11434")
    if not base_url.startswith("http"):
        base_url = f"http://{base_url}"

    # Determine endpoint based on method name
    if method_name == "generate":
        endpoint = "/api/generate"
    elif method_name == "chat":
        endpoint = "/api/chat"
    elif method_name == "embed":
        endpoint = "/api/embeddings"
    else:
        endpoint = f"/api/{method_name}"

    url = base_url.rstrip("/") + endpoint

    # Normalize request for matching
    method = "POST"
    headers = {}
    body = kwargs

    request_hash = normalize_request(method, url, headers, body)

    if _current_mode == "replay":
        # Try to find recorded response
        recording = _current_storage.find_recording(request_hash)
        if recording:
            # Return recorded response
            response_body = recording["response"]["body"]

            # Handle streaming responses for Ollama
            if recording["response"].get("is_streaming", False):
                # For streaming, we need to return an async iterator
                async def replay_ollama_stream():
                    for chunk in response_body:
                        yield chunk

                return replay_ollama_stream()
            else:
                return response_body
        else:
            raise RuntimeError(
                f"No recorded response found for request hash: {request_hash}\n"
                f"Endpoint: {endpoint}\n"
                f"Model: {body.get('model', 'unknown')}\n"
                f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
            )

    elif _current_mode == "record":
        # Make real request and record it
        response = await original_method(self, **kwargs)

        # Store the recording
        request_data = {
            "method": method,
            "url": url,
            "headers": headers,
            "body": body,
            "endpoint": endpoint,
            "model": body.get("model", ""),
        }

        # Determine if this is a streaming request based on request parameters
        is_streaming = body.get("stream", False)

        if is_streaming:
            # For streaming responses, we need to collect all chunks immediately before yielding
            # This ensures the recording is saved even if the generator isn't fully consumed
            chunks = []
            async for chunk in response:
                chunks.append(chunk)

            # Store the recording immediately
            response_data = {"body": chunks, "is_streaming": True}
            _current_storage.store_recording(request_hash, request_data, response_data)

            # Return a generator that replays the stored chunks
            async def replay_recorded_stream():
                for chunk in chunks:
                    yield chunk

            return replay_recorded_stream()
        else:
            response_data = {"body": response, "is_streaming": False}
            _current_storage.store_recording(request_hash, request_data, response_data)
            return response

    else:
        return await original_method(self, **kwargs)

def patch_inference_clients():
    """Install monkey patches for OpenAI client methods and Ollama AsyncClient methods."""
    global _original_methods

    # Import here to avoid circular imports
    from openai import AsyncOpenAI
    from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
    from openai.resources.completions import AsyncCompletions
    from openai.resources.embeddings import AsyncEmbeddings

    # Also import Ollama AsyncClient
    try:
        from ollama import AsyncClient as OllamaAsyncClient
    except ImportError:
        ollama_async_client = None
    else:
        ollama_async_client = OllamaAsyncClient

    # Store original methods for both OpenAI and Ollama clients
    _original_methods = {
        "chat_completions_create": AsyncChatCompletions.create,
        "completions_create": AsyncCompletions.create,
        "embeddings_create": AsyncEmbeddings.create,
    }

    # Add Ollama client methods if available
    if ollama_async_client:
        _original_methods.update(
            {
                "ollama_generate": ollama_async_client.generate,
                "ollama_chat": ollama_async_client.chat,
                "ollama_embed": ollama_async_client.embed,
            }
        )

    # Create patched methods for OpenAI client
    async def patched_chat_completions_create(self, **kwargs):
        return await _patched_create_method(_original_methods["chat_completions_create"], self, **kwargs)

    async def patched_completions_create(self, **kwargs):
        return await _patched_create_method(_original_methods["completions_create"], self, **kwargs)

    async def patched_embeddings_create(self, **kwargs):
        return await _patched_create_method(_original_methods["embeddings_create"], self, **kwargs)

    # Apply OpenAI patches
    AsyncChatCompletions.create = patched_chat_completions_create
    AsyncCompletions.create = patched_completions_create
    AsyncEmbeddings.create = patched_embeddings_create

    # Create patched methods for Ollama client
    if ollama_async_client:

        async def patched_ollama_generate(self, **kwargs):
            return await _patched_ollama_method(_original_methods["ollama_generate"], self, "generate", **kwargs)

        async def patched_ollama_chat(self, **kwargs):
            return await _patched_ollama_method(_original_methods["ollama_chat"], self, "chat", **kwargs)

        async def patched_ollama_embed(self, **kwargs):
            return await _patched_ollama_method(_original_methods["ollama_embed"], self, "embed", **kwargs)

        # Apply Ollama patches
        ollama_async_client.generate = patched_ollama_generate
        ollama_async_client.chat = patched_ollama_chat
        ollama_async_client.embed = patched_ollama_embed

    # Also try to patch the AsyncOpenAI __init__ to trace client creation
    original_openai_init = AsyncOpenAI.__init__

    def patched_openai_init(self, *args, **kwargs):
        result = original_openai_init(self, *args, **kwargs)

        # After client is created, try to re-patch its methods
        if hasattr(self, "chat") and hasattr(self.chat, "completions"):
            original_chat_create = self.chat.completions.create

            async def instance_patched_chat_create(**kwargs):
                return await _patched_create_method(original_chat_create, self.chat.completions, **kwargs)

            self.chat.completions.create = instance_patched_chat_create

        return result

    AsyncOpenAI.__init__ = patched_openai_init

def unpatch_inference_clients():
    """Remove monkey patches and restore original OpenAI and Ollama client methods."""
    global _original_methods

    if not _original_methods:
        return

    # Import here to avoid circular imports
    from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
    from openai.resources.completions import AsyncCompletions
    from openai.resources.embeddings import AsyncEmbeddings

    # Restore OpenAI client methods
    if "chat_completions_create" in _original_methods:
        AsyncChatCompletions.create = _original_methods["chat_completions_create"]

    if "completions_create" in _original_methods:
        AsyncCompletions.create = _original_methods["completions_create"]

    if "embeddings_create" in _original_methods:
        AsyncEmbeddings.create = _original_methods["embeddings_create"]

    # Restore Ollama client methods if they were patched
    try:
        from ollama import AsyncClient as OllamaAsyncClient

        if "ollama_generate" in _original_methods:
            OllamaAsyncClient.generate = _original_methods["ollama_generate"]

        if "ollama_chat" in _original_methods:
            OllamaAsyncClient.chat = _original_methods["ollama_chat"]

        if "ollama_embed" in _original_methods:
            OllamaAsyncClient.embed = _original_methods["ollama_embed"]

    except ImportError:
        pass

    _original_methods.clear()

@contextmanager
def inference_recording(
    mode: str = "live", test_id: str | None = None, storage_dir: str | Path | None = None
) -> Generator[None, None, None]:
    """Context manager for inference recording/replaying."""
    global _current_mode, _current_storage

    # Set defaults
    if storage_dir is None:
        storage_dir_path = Path.home() / ".llama" / "recordings"
    else:
        storage_dir_path = Path(storage_dir)

    if test_id is None:
        test_id = f"test_{uuid.uuid4().hex[:8]}"

    # Store previous state
    prev_mode = _current_mode
    prev_storage = _current_storage

    try:
        _current_mode = mode

        if mode in ["record", "replay"]:
            _current_storage = ResponseStorage(storage_dir_path, test_id)
            patch_inference_clients()

        yield

    finally:
        # Restore previous state
        if mode in ["record", "replay"]:
            unpatch_inference_clients()

        _current_mode = prev_mode
        _current_storage = prev_storage

@@ -185,70 +185,95 @@ def llama_stack_client(request, provider_data):
    if not config:
        raise ValueError("You must specify either --stack-config or LLAMA_STACK_CONFIG")

    # Set up inference recording if enabled
    inference_mode = os.environ.get("LLAMA_STACK_INFERENCE_MODE", "live").lower()
    recording_context = None

    if inference_mode in ["record", "replay"]:
        from llama_stack.testing.inference_recorder import setup_inference_recording

        recording_context = setup_inference_recording()
        recording_context.__enter__()
        print(f"Inference recording enabled: mode={inference_mode}")

    try:
        # Handle server:<config_name> format or server:<config_name>:<port>
        if config.startswith("server:"):
            parts = config.split(":")
            config_name = parts[1]
            port = int(parts[2]) if len(parts) > 2 else int(os.environ.get("LLAMA_STACK_PORT", DEFAULT_PORT))
            base_url = f"http://localhost:{port}"

            # Check if port is available
            if is_port_available(port):
                print(f"Starting llama stack server with config '{config_name}' on port {port}...")

                # Start server
                server_process = start_llama_stack_server(config_name)

                # Wait for server to be ready
                if not wait_for_server_ready(base_url, timeout=120, process=server_process):
                    print("Server failed to start within timeout")
                    server_process.terminate()
                    raise RuntimeError(
                        f"Server failed to start within timeout. Check that config '{config_name}' exists and is valid. "
                        f"See server.log for details."
                    )

                print(f"Server is ready at {base_url}")

                # Store process for potential cleanup (pytest will handle termination at session end)
                request.session._llama_stack_server_process = server_process
            else:
                print(f"Port {port} is already in use, assuming server is already running...")

            client = LlamaStackClient(
                base_url=base_url,
                provider_data=provider_data,
                timeout=int(os.environ.get("LLAMA_STACK_CLIENT_TIMEOUT", "30")),
            )
        else:
            # check if this looks like a URL using proper URL parsing
            try:
                parsed_url = urlparse(config)
                if parsed_url.scheme and parsed_url.netloc:
                    client = LlamaStackClient(
                        base_url=config,
                        provider_data=provider_data,
                    )
                else:
                    raise ValueError("Not a URL")
            except Exception as e:
                # If URL parsing fails, treat as library config
                if "=" in config:
                    run_config = run_config_from_adhoc_config_spec(config)
                    run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
                    with open(run_config_file.name, "w") as f:
                        yaml.dump(run_config.model_dump(), f)
                    config = run_config_file.name

                client = LlamaStackAsLibraryClient(
                    config,
                    provider_data=provider_data,
                    skip_logger_removal=True,
                )
                if not client.initialize():
                    raise RuntimeError("Initialization failed") from e

        # Store recording context for cleanup
        if recording_context:
            request.session._inference_recording_context = recording_context

        return client

    except Exception:
        # Clean up recording context on error
        if recording_context:
            try:
                recording_context.__exit__(None, None, None)
            except Exception as cleanup_error:
                print(f"Warning: Error cleaning up recording context: {cleanup_error}")
        raise


@pytest.fixture(scope="session")
@@ -264,9 +289,20 @@ def compat_client(request):
@pytest.fixture(scope="session", autouse=True)
def cleanup_server_process(request):
    """Cleanup server process and inference recording at the end of the test session."""
    yield  # Run tests

    # Clean up inference recording context
    if hasattr(request.session, "_inference_recording_context"):
        recording_context = request.session._inference_recording_context
        if recording_context:
            try:
                print("Cleaning up inference recording context...")
                recording_context.__exit__(None, None, None)
            except Exception as e:
                print(f"Error during inference recording cleanup: {e}")

    # Clean up server process
    if hasattr(request.session, "_llama_stack_server_process"):
        server_process = request.session._llama_stack_server_process
        if server_process:
289
tests/integration/test_inference_recordings.py
Normal file
@@ -0,0 +1,289 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import sqlite3
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch

import pytest
from openai import AsyncOpenAI

from llama_stack.testing.inference_recorder import (
    ResponseStorage,
    inference_recording,
    normalize_request,
)


@pytest.fixture
def temp_storage_dir():
    """Create a temporary directory for test recordings."""
    with tempfile.TemporaryDirectory() as temp_dir:
        yield Path(temp_dir)


@pytest.fixture
def mock_openai_response():
    """Mock OpenAI response object."""
    mock_response = Mock()
    mock_response.choices = [Mock()]
    mock_response.choices[0].message.content = "Hello! I'm doing well, thank you for asking."
    mock_response.model_dump.return_value = {
        "id": "chatcmpl-test123",
        "object": "chat.completion",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "Hello! I'm doing well, thank you for asking."},
                "finish_reason": "stop",
            }
        ],
        "model": "llama3.2:3b",
        "usage": {"prompt_tokens": 10, "completion_tokens": 15, "total_tokens": 25},
    }

    return mock_response


@pytest.fixture
def mock_embeddings_response():
    """Mock OpenAI embeddings response object."""
    mock_response = Mock()
    mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3]), Mock(embedding=[0.4, 0.5, 0.6])]
    mock_response.model_dump.return_value = {
        "object": "list",
        "data": [
            {"object": "embedding", "embedding": [0.1, 0.2, 0.3], "index": 0},
            {"object": "embedding", "embedding": [0.4, 0.5, 0.6], "index": 1},
        ],
        "model": "nomic-embed-text",
        "usage": {"prompt_tokens": 6, "total_tokens": 6},
    }

    return mock_response

class TestInferenceRecording:
    """Test the inference recording system."""

    def test_request_normalization(self):
        """Test that request normalization produces consistent hashes."""
        # Test basic normalization
        hash1 = normalize_request(
            "POST",
            "http://localhost:11434/v1/chat/completions",
            {},
            {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
        )

        # Same request should produce same hash
        hash2 = normalize_request(
            "POST",
            "http://localhost:11434/v1/chat/completions",
            {},
            {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
        )

        assert hash1 == hash2

        # Different content should produce different hash
        hash3 = normalize_request(
            "POST",
            "http://localhost:11434/v1/chat/completions",
            {},
            {
                "model": "llama3.2:3b",
                "messages": [{"role": "user", "content": "Different message"}],
                "temperature": 0.7,
            },
        )

        assert hash1 != hash3

    def test_request_normalization_edge_cases(self):
        """Test request normalization handles edge cases correctly."""
        # Test whitespace normalization
        hash1 = normalize_request(
            "POST",
            "http://test/v1/chat/completions",
            {},
            {"messages": [{"role": "user", "content": "Hello world\n\n"}]},
        )
        hash2 = normalize_request(
            "POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
        )
        assert hash1 == hash2

        # Test float precision normalization
        hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
        hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
        assert hash3 == hash4

    def test_response_storage(self, temp_storage_dir):
        """Test the ResponseStorage class."""
        storage = ResponseStorage(temp_storage_dir, "test_storage")

        # Test directory creation
        assert storage.test_dir.exists()
        assert storage.responses_dir.exists()
        assert storage.db_path.exists()

        # Test storing and retrieving a recording
        request_hash = "test_hash_123"
        request_data = {
            "method": "POST",
            "url": "http://localhost:11434/v1/chat/completions",
            "endpoint": "/v1/chat/completions",
            "model": "llama3.2:3b",
        }
        response_data = {"body": {"content": "test response"}, "is_streaming": False}

        storage.store_recording(request_hash, request_data, response_data)

        # Verify SQLite record
        with sqlite3.connect(storage.db_path) as conn:
            result = conn.execute("SELECT * FROM recordings WHERE request_hash = ?", (request_hash,)).fetchone()

        assert result is not None
        assert result[0] == request_hash  # request_hash
        assert result[2] == "/v1/chat/completions"  # endpoint
        assert result[3] == "llama3.2:3b"  # model

        # Verify file storage and retrieval
        retrieved = storage.find_recording(request_hash)
        assert retrieved is not None
        assert retrieved["request"]["model"] == "llama3.2:3b"
        assert retrieved["response"]["body"]["content"] == "test response"

    async def test_recording_mode(self, temp_storage_dir, mock_openai_response):
        """Test that recording mode captures and stores responses."""
        test_id = "test_recording_mode"

        async def mock_create(*args, **kwargs):
            return mock_openai_response

        with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
            with inference_recording(mode="record", test_id=test_id, storage_dir=str(temp_storage_dir)):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                response = await client.chat.completions.create(
                    model="llama3.2:3b",
                    messages=[{"role": "user", "content": "Hello, how are you?"}],
                    temperature=0.7,
                    max_tokens=50,
                )

                # Verify the response was returned correctly
                assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."

        # Verify recording was stored
        storage = ResponseStorage(temp_storage_dir, test_id)
        with sqlite3.connect(storage.db_path) as conn:
            recordings = conn.execute("SELECT COUNT(*) FROM recordings").fetchone()[0]

        assert recordings == 1

    async def test_replay_mode(self, temp_storage_dir, mock_openai_response):
        """Test that replay mode returns stored responses without making real calls."""
        test_id = "test_replay_mode"

        async def mock_create(*args, **kwargs):
            return mock_openai_response

        # First, record a response
        with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
            with inference_recording(mode="record", test_id=test_id, storage_dir=str(temp_storage_dir)):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                await client.chat.completions.create(
                    model="llama3.2:3b",
                    messages=[{"role": "user", "content": "Hello, how are you?"}],
                    temperature=0.7,
                    max_tokens=50,
                )

        # Now test replay mode - should not call the original method
        with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
            with inference_recording(mode="replay", test_id=test_id, storage_dir=str(temp_storage_dir)):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                response = await client.chat.completions.create(
                    model="llama3.2:3b",
                    messages=[{"role": "user", "content": "Hello, how are you?"}],
                    temperature=0.7,
                    max_tokens=50,
                )

                # Verify we got the recorded response
                assert response["choices"][0]["message"]["content"] == "Hello! I'm doing well, thank you for asking."

                # Verify the original method was NOT called
                mock_create_patch.assert_not_called()

    async def test_replay_missing_recording(self, temp_storage_dir):
        """Test that replay mode fails when no recording is found."""
        test_id = "test_missing_recording"

        with patch("openai.resources.chat.completions.AsyncCompletions.create"):
            with inference_recording(mode="replay", test_id=test_id, storage_dir=str(temp_storage_dir)):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                with pytest.raises(RuntimeError, match="No recorded response found"):
                    await client.chat.completions.create(
                        model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
                    )

    async def test_embeddings_recording(self, temp_storage_dir, mock_embeddings_response):
        """Test recording and replay of embeddings calls."""
        test_id = "test_embeddings"

        async def mock_create(*args, **kwargs):
            return mock_embeddings_response

        # Record
        with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
            with inference_recording(mode="record", test_id=test_id, storage_dir=str(temp_storage_dir)):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                response = await client.embeddings.create(
                    model="nomic-embed-text", input=["Hello world", "Test embedding"]
                )

                assert len(response.data) == 2

        # Replay
        with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
            with inference_recording(mode="replay", test_id=test_id, storage_dir=str(temp_storage_dir)):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                response = await client.embeddings.create(
                    model="nomic-embed-text", input=["Hello world", "Test embedding"]
                )

                # Verify we got the recorded response
                assert len(response["data"]) == 2
                assert response["data"][0]["embedding"] == [0.1, 0.2, 0.3]

                # Verify original method was not called
                mock_create_patch.assert_not_called()

    async def test_live_mode(self, mock_openai_response):
        """Test that live mode passes through to original methods."""

        async def mock_create(*args, **kwargs):
            return mock_openai_response

        with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
            with inference_recording(mode="live"):
                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                response = await client.chat.completions.create(
                    model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
                )

                # Verify the response was returned
                assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."