test: recordable mocks use json only (#1443)

# Summary:
Removes the use of pickle for recorded mock responses; the cache is now stored as JSON only, which is human-readable and diff-friendly.

# Test Plan:
Run the following with `--record-responses` first, then a second time
without it:

LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/integration/agents/test_agents.py \
  --safety-shield meta-llama/Llama-Guard-3-8B \
  --text-model meta-llama/Llama-3.1-8B-Instruct
ehhuang authored on 2025-03-06 14:46:29 -08:00 (committed by GitHub)
commit 3d71e5a036, parent 564977c646
6 changed files with 23792 additions and 26435 deletions


@@ -59,7 +59,7 @@ def llama_stack_client_with_mocked_inference(llama_stack_client, request):
         return llama_stack_client

     record_responses = request.config.getoption("--record-responses")
-    cache_dir = Path(__file__).parent / "fixtures" / "recorded_responses"
+    cache_dir = Path(__file__).parent / "recorded_responses"

     # Create a shallow copy of the client to avoid modifying the original
     client = copy.copy(llama_stack_client)


@@ -3,10 +3,12 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import importlib
 import json
 import os
-import pickle
 import re
+from datetime import datetime
+from enum import Enum
 from pathlib import Path
@@ -15,18 +17,18 @@ class RecordableMock:
     def __init__(self, real_func, cache_dir, func_name, record=False):
         self.real_func = real_func
-        self.pickle_path = Path(cache_dir) / f"{func_name}.pickle"
         self.json_path = Path(cache_dir) / f"{func_name}.json"
         self.record = record
         self.cache = {}

         # Load existing cache if available and not recording
-        if self.pickle_path.exists():
+        if self.json_path.exists():
             try:
-                with open(self.pickle_path, "rb") as f:
-                    self.cache = pickle.load(f)
+                with open(self.json_path, "r") as f:
+                    self.cache = json.load(f)
             except Exception as e:
-                print(f"Error loading cache from {self.pickle_path}: {e}")
+                print(f"Error loading cache from {self.json_path}: {e}")
+                raise

     async def __call__(self, *args, **kwargs):
         """
@@ -98,23 +100,19 @@ class RecordableMock:
             # Check if it's a value or chunks
             if cached_data.get("type") == "value":
                 # It's a regular value
-                return cached_data["value"]
+                return self._reconstruct_object(cached_data["value"])
             else:
                 # It's chunks from an async generator
                 async def replay_generator():
                     for chunk in cached_data["chunks"]:
-                        yield chunk
+                        yield self._reconstruct_object(chunk)

                 return replay_generator()

     def _create_cache_key(self, args, kwargs):
         """Create a hashable key from the function arguments, ignoring auto-generated IDs."""
-        # Convert args and kwargs to a string representation directly
-        args_str = str(args)
-        kwargs_str = str(sorted([(k, kwargs[k]) for k in kwargs]))
-
-        # Combine into a single key
-        key = f"{args_str}_{kwargs_str}"
+        # Convert to JSON strings with sorted keys
+        key = json.dumps((args, kwargs), sort_keys=True, default=self._json_default)

         # Post-process the key with regex to replace IDs with placeholders
         # Replace UUIDs and similar patterns
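The point of the new key scheme: `json.dumps(..., sort_keys=True)` yields one canonical key per logical call, and the `default=` hook keeps non-JSON types such as enums out of unstable `repr()` territory. A minimal sketch, using a simplified stand-in for `_json_default`:

```python
import json
from enum import Enum

class Role(Enum):
    USER = "user"

def _json_default(obj):
    # Simplified stand-in for RecordableMock._json_default.
    if isinstance(obj, Enum):
        return {"__enum__": obj.__class__.__name__, "value": obj.value}
    raise TypeError(f"not JSON serializable: {obj!r}")

# The same logical call, with kwargs supplied in different orders, produces
# identical keys, so recordings replay regardless of argument ordering.
k1 = json.dumps(((), {"role": Role.USER, "stream": True}), sort_keys=True, default=_json_default)
k2 = json.dumps(((), {"stream": True, "role": Role.USER}), sort_keys=True, default=_json_default)
assert k1 == k2
```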
@@ -126,83 +124,95 @@ class RecordableMock:
         return key

     def _save_cache(self):
-        """Save the cache to disk in both pickle and JSON formats."""
-        os.makedirs(self.pickle_path.parent, exist_ok=True)
+        """Save the cache to disk in JSON format."""
+        os.makedirs(self.json_path.parent, exist_ok=True)

-        # Save as pickle for exact object preservation
-        with open(self.pickle_path, "wb") as f:
-            pickle.dump(self.cache, f)
-
-        # Also save as JSON for human readability and diffing
-        try:
-            # Create a simplified version of the cache for JSON
-            json_cache = {}
-            for key, value in self.cache.items():
-                if value.get("type") == "generator":
-                    # For generators, create a simplified representation of each chunk
-                    chunks = []
-                    for chunk in value["chunks"]:
-                        chunk_dict = self._object_to_json_safe_dict(chunk)
-                        chunks.append(chunk_dict)
-                    json_cache[key] = {"type": "generator", "chunks": chunks}
-                else:
-                    # For values, create a simplified representation
-                    val = value["value"]
-                    val_dict = self._object_to_json_safe_dict(val)
-                    json_cache[key] = {"type": "value", "value": val_dict}
-
-            # Write the JSON file with pretty formatting
+        # Write the JSON file with pretty formatting
+        try:
             with open(self.json_path, "w") as f:
-                json.dump(json_cache, f, indent=2, sort_keys=True)
+                json.dump(self.cache, f, indent=2, sort_keys=True, default=self._json_default)
+                # write another empty line at the end of the file to make pre-commit happy
+                f.write("\n")
         except Exception as e:
             print(f"Error saving JSON cache: {e}")

-    def _object_to_json_safe_dict(self, obj):
-        """Convert an object to a JSON-safe dictionary."""
-        # Handle enum types
-        if hasattr(obj, "value") and hasattr(obj.__class__, "__members__"):
-            return {"__enum__": obj.__class__.__name__, "value": obj.value}
-
-        # Handle Pydantic models
-        if hasattr(obj, "model_dump"):
-            return self._process_dict(obj.model_dump())
-        elif hasattr(obj, "dict"):
-            return self._process_dict(obj.dict())
-
-        # Handle regular objects with __dict__
-        try:
-            return self._process_dict(vars(obj))
-        except Exception as e:
-            print(f"Error converting object to JSON-safe dict: {e}")
-            # If we can't get a dict, convert to string
-            return str(obj)
-
-    def _process_dict(self, d):
-        """Process a dictionary to make all values JSON-safe."""
-        if not isinstance(d, dict):
-            return d
-
-        result = {}
-        for k, v in d.items():
-            if isinstance(v, dict):
-                result[k] = self._process_dict(v)
-            elif isinstance(v, list):
-                result[k] = [
-                    self._process_dict(item)
-                    if isinstance(item, dict)
-                    else self._object_to_json_safe_dict(item)
-                    if hasattr(item, "__dict__")
-                    else item
-                    for item in v
-                ]
-            elif hasattr(v, "value") and hasattr(v.__class__, "__members__"):
-                # Handle enum
-                result[k] = {"__enum__": v.__class__.__name__, "value": v.value}
-            elif hasattr(v, "__dict__"):
-                # Handle nested objects
-                result[k] = self._object_to_json_safe_dict(v)
-            else:
-                # Basic types
-                result[k] = v
-
-        return result
+    def _json_default(self, obj):
+        """Default function for JSON serialization of objects."""
+        if isinstance(obj, datetime):
+            return {
+                "__datetime__": obj.isoformat(),
+                "__module__": obj.__class__.__module__,
+                "__class__": obj.__class__.__name__,
+            }
+
+        if isinstance(obj, Enum):
+            return {
+                "__enum__": obj.__class__.__name__,
+                "value": obj.value,
+                "__module__": obj.__class__.__module__,
+            }
+
+        # Handle Pydantic models
+        if hasattr(obj, "model_dump"):
+            model_data = obj.model_dump()
+            return {
+                "__pydantic__": obj.__class__.__name__,
+                "__module__": obj.__class__.__module__,
+                "data": model_data,
+            }
+
+    def _reconstruct_object(self, data):
+        """Reconstruct an object from its JSON representation."""
+        if isinstance(data, dict):
+            # Check if this is a serialized datetime
+            if "__datetime__" in data:
+                try:
+                    module_name = data.get("__module__", "datetime")
+                    class_name = data.get("__class__", "datetime")
+
+                    # Try to import the specific datetime class
+                    module = importlib.import_module(module_name)
+                    dt_class = getattr(module, class_name)
+
+                    # Parse the ISO format string
+                    dt = dt_class.fromisoformat(data["__datetime__"])
+                    return dt
+                except (ImportError, AttributeError, ValueError) as e:
+                    print(f"Error reconstructing datetime: {e}")
+                    return data
+
+            # Check if this is a serialized enum
+            elif "__enum__" in data:
+                try:
+                    module_name = data.get("__module__", "builtins")
+                    enum_class = self._import_class(module_name, data["__enum__"])
+                    return enum_class(data["value"])
+                except (ImportError, AttributeError) as e:
+                    print(f"Error reconstructing enum: {e}")
+                    return data
+
+            # Check if this is a serialized Pydantic model
+            elif "__pydantic__" in data:
+                try:
+                    module_name = data.get("__module__", "builtins")
+                    model_class = self._import_class(module_name, data["__pydantic__"])
+                    return model_class(**self._reconstruct_object(data["data"]))
+                except (ImportError, AttributeError) as e:
+                    print(f"Error reconstructing Pydantic model: {e}")
+                    return data
+
+            # Regular dictionary
+            return {k: self._reconstruct_object(v) for k, v in data.items()}
+
+        # Handle lists
+        elif isinstance(data, list):
+            return [self._reconstruct_object(item) for item in data]
+
+        # Return primitive types as is
+        return data
+
+    def _import_class(self, module_name, class_name):
+        """Import a class from a module."""
+        module = __import__(module_name, fromlist=[class_name])
+        return getattr(module, class_name)
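Taken together, `_json_default` and `_reconstruct_object` form a tagged round trip: objects leave as `__enum__`/`__datetime__`/`__pydantic__` dicts and come back by re-importing and re-instantiating the original classes. A self-contained sketch of the same scheme, simplified to the enum and datetime cases (names here are illustrative):

```python
import importlib
import json
from datetime import datetime
from enum import Enum

class Status(Enum):
    OK = "ok"

def json_default(obj):
    # Mirrors _json_default for the enum and datetime cases.
    if isinstance(obj, datetime):
        return {
            "__datetime__": obj.isoformat(),
            "__module__": obj.__class__.__module__,
            "__class__": obj.__class__.__name__,
        }
    if isinstance(obj, Enum):
        return {
            "__enum__": obj.__class__.__name__,
            "value": obj.value,
            "__module__": obj.__class__.__module__,
        }
    raise TypeError(f"not JSON serializable: {obj!r}")

def reconstruct(data):
    # Mirrors _reconstruct_object: walk the structure and rebuild tagged dicts.
    if isinstance(data, dict):
        if "__datetime__" in data:
            cls = getattr(importlib.import_module(data["__module__"]), data["__class__"])
            return cls.fromisoformat(data["__datetime__"])
        if "__enum__" in data:
            cls = getattr(importlib.import_module(data["__module__"]), data["__enum__"])
            return cls(data["value"])
        return {k: reconstruct(v) for k, v in data.items()}
    if isinstance(data, list):
        return [reconstruct(item) for item in data]
    return data

payload = {"status": Status.OK, "at": datetime(2025, 3, 6, 14, 46, 29)}
restored = reconstruct(json.loads(json.dumps(payload, default=json_default)))
assert restored == payload
```

One consequence worth noting: replay needs the recorded classes to still be importable under their saved `__module__` paths; when they are not, the `except` branches in the real implementation fall back to returning the raw tagged dict.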

Two file diffs suppressed because one or more lines are too long.