test: recordable mocks use json only (#1443)

# Summary:
Removes the use of pickle for recording and replaying mock responses; recordings are now stored as JSON only.

# Test Plan:
Run the following with `--record-responses` first, then a second time without it.

LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/integration/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct
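
For context, `--record-responses` is a project-specific pytest flag that the mocked-inference fixture reads via `request.config.getoption` (visible in the first hunk below). A minimal sketch of how such a flag is typically registered in a `conftest.py`; this wiring is an illustration and is not part of this diff:

```python
# Hypothetical conftest.py snippet; the option name matches the test plan above,
# everything else is illustrative.
def pytest_addoption(parser):
    parser.addoption(
        "--record-responses",
        action="store_true",
        default=False,
        help="Record real inference responses to JSON instead of replaying cached ones",
    )
```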
ehhuang authored and GitHub committed on 2025-03-06 14:46:29 -08:00
parent 564977c646
commit 3d71e5a036
6 changed files with 23792 additions and 26435 deletions


@@ -59,7 +59,7 @@ def llama_stack_client_with_mocked_inference(llama_stack_client, request):
         return llama_stack_client

     record_responses = request.config.getoption("--record-responses")
-    cache_dir = Path(__file__).parent / "fixtures" / "recorded_responses"
+    cache_dir = Path(__file__).parent / "recorded_responses"

     # Create a shallow copy of the client to avoid modifying the original
     client = copy.copy(llama_stack_client)

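The hunk above is from the `llama_stack_client_with_mocked_inference` fixture: recorded responses now live in a `recorded_responses` directory next to the fixture code, and the client is shallow-copied before its methods are patched. A rough sketch of how that patching might look, where the import path and the wrapped `inference.chat_completion` attribute are assumptions for illustration rather than the fixture's actual code:

```python
import copy
from pathlib import Path

from recordable_mock import RecordableMock  # import path is illustrative


def with_recordable_mocks(llama_stack_client, cache_dir: Path, record: bool):
    # Shallow copy so the original client object is left untouched, as in the diff above.
    client = copy.copy(llama_stack_client)

    # Wrap one API call: on a --record-responses run it hits the real endpoint and saves
    # JSON; on later runs it replays the cached response (attribute name is an assumption).
    client.inference.chat_completion = RecordableMock(
        llama_stack_client.inference.chat_completion,
        cache_dir,
        "chat_completion",
        record=record,
    )
    return client
```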

@@ -3,10 +3,12 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import importlib
 import json
 import os
-import pickle
 import re
+from datetime import datetime
+from enum import Enum
 from pathlib import Path
@@ -15,18 +17,18 @@ class RecordableMock:
     def __init__(self, real_func, cache_dir, func_name, record=False):
         self.real_func = real_func
-        self.pickle_path = Path(cache_dir) / f"{func_name}.pickle"
+        self.json_path = Path(cache_dir) / f"{func_name}.json"
         self.record = record
         self.cache = {}

         # Load existing cache if available and not recording
-        if self.pickle_path.exists():
+        if self.json_path.exists():
             try:
-                with open(self.pickle_path, "rb") as f:
-                    self.cache = pickle.load(f)
+                with open(self.json_path, "r") as f:
+                    self.cache = json.load(f)
             except Exception as e:
-                print(f"Error loading cache from {self.pickle_path}: {e}")
+                print(f"Error loading cache from {self.json_path}: {e}")
                 raise

     async def __call__(self, *args, **kwargs):
         """
@@ -98,23 +100,19 @@ class RecordableMock:
         # Check if it's a value or chunks
         if cached_data.get("type") == "value":
             # It's a regular value
-            return cached_data["value"]
+            return self._reconstruct_object(cached_data["value"])
         else:
             # It's chunks from an async generator
             async def replay_generator():
                 for chunk in cached_data["chunks"]:
-                    yield chunk
+                    yield self._reconstruct_object(chunk)

             return replay_generator()

     def _create_cache_key(self, args, kwargs):
         """Create a hashable key from the function arguments, ignoring auto-generated IDs."""
-        # Convert args and kwargs to a string representation directly
-        args_str = str(args)
-        kwargs_str = str(sorted([(k, kwargs[k]) for k in kwargs]))
-
-        # Combine into a single key
-        key = f"{args_str}_{kwargs_str}"
+        # Convert to JSON strings with sorted keys
+        key = json.dumps((args, kwargs), sort_keys=True, default=self._json_default)

         # Post-process the key with regex to replace IDs with placeholders
         # Replace UUIDs and similar patterns
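
In the rewritten `_create_cache_key`, `json.dumps(..., sort_keys=True, default=self._json_default)` makes the key deterministic regardless of keyword-argument order, and regex substitutions (just past the end of this hunk) mask auto-generated IDs so re-recorded sessions still hit the same cache entries. A standalone sketch of the idea, with `default=str` and a generic UUID pattern standing in for the real implementation:

```python
import json
import re


def make_key(args, kwargs):
    # Sorted keys make the serialization deterministic regardless of kwarg order.
    key = json.dumps((args, kwargs), sort_keys=True, default=str)
    # Mask UUID-like identifiers so keys stay stable across runs
    # (illustrative pattern; the real regexes in the file differ).
    return re.sub(
        r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", "<UUID>", key
    )


a = make_key(("chat",), {"session_id": "123e4567-e89b-12d3-a456-426614174000", "temperature": 0})
b = make_key(("chat",), {"temperature": 0, "session_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"})
assert a == b  # same call shape -> same cache key, despite kwarg order and different IDs
```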
@@ -126,83 +124,95 @@ class RecordableMock:
         return key

     def _save_cache(self):
-        """Save the cache to disk in both pickle and JSON formats."""
-        os.makedirs(self.pickle_path.parent, exist_ok=True)
+        """Save the cache to disk in JSON format."""
+        os.makedirs(self.json_path.parent, exist_ok=True)

-        # Save as pickle for exact object preservation
-        with open(self.pickle_path, "wb") as f:
-            pickle.dump(self.cache, f)
-
-        # Also save as JSON for human readability and diffing
+        # Write the JSON file with pretty formatting
         try:
-            # Create a simplified version of the cache for JSON
-            json_cache = {}
-            for key, value in self.cache.items():
-                if value.get("type") == "generator":
-                    # For generators, create a simplified representation of each chunk
-                    chunks = []
-                    for chunk in value["chunks"]:
-                        chunk_dict = self._object_to_json_safe_dict(chunk)
-                        chunks.append(chunk_dict)
-                    json_cache[key] = {"type": "generator", "chunks": chunks}
-                else:
-                    # For values, create a simplified representation
-                    val = value["value"]
-                    val_dict = self._object_to_json_safe_dict(val)
-                    json_cache[key] = {"type": "value", "value": val_dict}
-
-            # Write the JSON file with pretty formatting
             with open(self.json_path, "w") as f:
-                json.dump(json_cache, f, indent=2, sort_keys=True)
+                json.dump(self.cache, f, indent=2, sort_keys=True, default=self._json_default)
+                # write another empty line at the end of the file to make pre-commit happy
+                f.write("\n")
         except Exception as e:
             print(f"Error saving JSON cache: {e}")

-    def _object_to_json_safe_dict(self, obj):
-        """Convert an object to a JSON-safe dictionary."""
-        # Handle enum types
-        if hasattr(obj, "value") and hasattr(obj.__class__, "__members__"):
-            return {"__enum__": obj.__class__.__name__, "value": obj.value}
+    def _json_default(self, obj):
+        """Default function for JSON serialization of objects."""
+        if isinstance(obj, datetime):
+            return {
+                "__datetime__": obj.isoformat(),
+                "__module__": obj.__class__.__module__,
+                "__class__": obj.__class__.__name__,
+            }
+        if isinstance(obj, Enum):
+            return {
+                "__enum__": obj.__class__.__name__,
+                "value": obj.value,
+                "__module__": obj.__class__.__module__,
+            }

         # Handle Pydantic models
         if hasattr(obj, "model_dump"):
-            return self._process_dict(obj.model_dump())
-        elif hasattr(obj, "dict"):
-            return self._process_dict(obj.dict())
+            model_data = obj.model_dump()
+            return {
+                "__pydantic__": obj.__class__.__name__,
+                "__module__": obj.__class__.__module__,
+                "data": model_data,
+            }

-        # Handle regular objects with __dict__
-        try:
-            return self._process_dict(vars(obj))
-        except Exception as e:
-            print(f"Error converting object to JSON-safe dict: {e}")
-            # If we can't get a dict, convert to string
-            return str(obj)
+    def _reconstruct_object(self, data):
+        """Reconstruct an object from its JSON representation."""
+        if isinstance(data, dict):
+            # Check if this is a serialized datetime
+            if "__datetime__" in data:
+                try:
+                    module_name = data.get("__module__", "datetime")
+                    class_name = data.get("__class__", "datetime")

-    def _process_dict(self, d):
-        """Process a dictionary to make all values JSON-safe."""
-        if not isinstance(d, dict):
-            return d
+                    # Try to import the specific datetime class
+                    module = importlib.import_module(module_name)
+                    dt_class = getattr(module, class_name)

-        result = {}
-        for k, v in d.items():
-            if isinstance(v, dict):
-                result[k] = self._process_dict(v)
-            elif isinstance(v, list):
-                result[k] = [
-                    self._process_dict(item)
-                    if isinstance(item, dict)
-                    else self._object_to_json_safe_dict(item)
-                    if hasattr(item, "__dict__")
-                    else item
-                    for item in v
-                ]
-            elif hasattr(v, "value") and hasattr(v.__class__, "__members__"):
-                # Handle enum
-                result[k] = {"__enum__": v.__class__.__name__, "value": v.value}
-            elif hasattr(v, "__dict__"):
-                # Handle nested objects
-                result[k] = self._object_to_json_safe_dict(v)
-            else:
-                # Basic types
-                result[k] = v
+                    # Parse the ISO format string
+                    dt = dt_class.fromisoformat(data["__datetime__"])
+                    return dt
+                except (ImportError, AttributeError, ValueError) as e:
+                    print(f"Error reconstructing datetime: {e}")
+                    return data

-        return result
+            # Check if this is a serialized enum
+            elif "__enum__" in data:
+                try:
+                    module_name = data.get("__module__", "builtins")
+                    enum_class = self._import_class(module_name, data["__enum__"])
+                    return enum_class(data["value"])
+                except (ImportError, AttributeError) as e:
+                    print(f"Error reconstructing enum: {e}")
+                    return data
+
+            # Check if this is a serialized Pydantic model
+            elif "__pydantic__" in data:
+                try:
+                    module_name = data.get("__module__", "builtins")
+                    model_class = self._import_class(module_name, data["__pydantic__"])
+                    return model_class(**self._reconstruct_object(data["data"]))
+                except (ImportError, AttributeError) as e:
+                    print(f"Error reconstructing Pydantic model: {e}")
+                    return data
+
+            # Regular dictionary
+            return {k: self._reconstruct_object(v) for k, v in data.items()}
+
+        # Handle lists
+        elif isinstance(data, list):
+            return [self._reconstruct_object(item) for item in data]
+
+        # Return primitive types as is
+        return data
+
+    def _import_class(self, module_name, class_name):
+        """Import a class from a module."""
+        module = __import__(module_name, fromlist=[class_name])
+        return getattr(module, class_name)
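
Taken together, `_json_default` and `_reconstruct_object` are what let the cache drop pickle: values that plain JSON cannot hold are written with type tags (`__datetime__`, `__enum__`, `__pydantic__`) and rebuilt from those tags on replay. A condensed, self-contained sketch of the same round trip, using a small lookup table in place of the importlib-based class resolution above:

```python
import json
from datetime import datetime
from enum import Enum


class Color(Enum):
    RED = "red"


ENUMS = {"Color": Color}  # stand-in for the _import_class/importlib resolution


def to_jsonable(obj):
    # Tag values that plain JSON cannot represent, mirroring _json_default.
    if isinstance(obj, datetime):
        return {"__datetime__": obj.isoformat()}
    if isinstance(obj, Enum):
        return {"__enum__": obj.__class__.__name__, "value": obj.value}
    raise TypeError(f"cannot serialize {type(obj)!r}")


def reconstruct(data):
    # Rebuild tagged values on the way out, mirroring _reconstruct_object.
    if isinstance(data, dict):
        if "__datetime__" in data:
            return datetime.fromisoformat(data["__datetime__"])
        if "__enum__" in data:
            return ENUMS[data["__enum__"]](data["value"])
        return {k: reconstruct(v) for k, v in data.items()}
    if isinstance(data, list):
        return [reconstruct(item) for item in data]
    return data


recorded = json.dumps({"at": datetime(2025, 3, 6, 14, 46), "color": Color.RED}, default=to_jsonable)
replayed = reconstruct(json.loads(recorded))
assert replayed == {"at": datetime(2025, 3, 6, 14, 46), "color": Color.RED}
```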

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long