Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-29 03:14:19 +00:00
test: recordable mocks use json only (#1443)
# Summary:
Removes the use of pickle.

# Test Plan:
Run the following with `--record-responses` first, then another time without:

LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/integration/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct
This commit is contained in:
parent 564977c646
commit 3d71e5a036

6 changed files with 23792 additions and 26435 deletions
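For orientation, each entry in the new JSON cache files is keyed by the serialized call arguments and stores either a single value or the recorded stream chunks; objects that are not natively JSON-serializable are written as tagged dicts, as defined in the diff below. A rough sketch of the shape of one entry (names and values invented for illustration):

# Shape of a recorded entry (illustrative only; field values are made up)
example_entry = {
    "type": "value",  # streaming responses are stored as {"type": "generator", "chunks": [...]}
    "value": {
        "__pydantic__": "CompletionMessage",       # hypothetical Pydantic class name
        "__module__": "llama_stack_client.types",  # hypothetical module path
        "data": {"role": "assistant", "content": "Hello!"},
    },
}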
@@ -59,7 +59,7 @@ def llama_stack_client_with_mocked_inference(llama_stack_client, request):
         return llama_stack_client
 
     record_responses = request.config.getoption("--record-responses")
-    cache_dir = Path(__file__).parent / "fixtures" / "recorded_responses"
+    cache_dir = Path(__file__).parent / "recorded_responses"
 
     # Create a shallow copy of the client to avoid modifying the original
     client = copy.copy(llama_stack_client)
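To make the flow concrete: the fixture copies the client and, for each call it wants to record or replay, swaps the real method for a RecordableMock built with the constructor shown in the diff below. The method path used here is an assumption for illustration, not taken from this diff:

# Hedged sketch: wrap one hypothetical client method with a RecordableMock.
# `client.inference.chat_completion` is an assumed attribute path.
mock = RecordableMock(
    real_func=client.inference.chat_completion,  # the real coroutine to wrap (invoked when recording)
    cache_dir=cache_dir,                         # the recorded_responses directory above
    func_name="chat_completion",                 # cache file becomes chat_completion.json
    record=record_responses,                     # driven by the --record-responses flag
)
client.inference.chat_completion = mock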
@@ -3,10 +3,12 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import importlib
 import json
 import os
-import pickle
 import re
+from datetime import datetime
+from enum import Enum
 from pathlib import Path
 
 
@@ -15,18 +17,18 @@ class RecordableMock:
 
     def __init__(self, real_func, cache_dir, func_name, record=False):
        self.real_func = real_func
-        self.pickle_path = Path(cache_dir) / f"{func_name}.pickle"
         self.json_path = Path(cache_dir) / f"{func_name}.json"
         self.record = record
         self.cache = {}
 
         # Load existing cache if available and not recording
-        if self.pickle_path.exists():
+        if self.json_path.exists():
             try:
-                with open(self.pickle_path, "rb") as f:
-                    self.cache = pickle.load(f)
+                with open(self.json_path, "r") as f:
+                    self.cache = json.load(f)
             except Exception as e:
-                print(f"Error loading cache from {self.pickle_path}: {e}")
+                print(f"Error loading cache from {self.json_path}: {e}")
+                raise
 
     async def __call__(self, *args, **kwargs):
         """
@@ -98,23 +100,19 @@
             # Check if it's a value or chunks
             if cached_data.get("type") == "value":
                 # It's a regular value
-                return cached_data["value"]
+                return self._reconstruct_object(cached_data["value"])
             else:
                 # It's chunks from an async generator
                 async def replay_generator():
                     for chunk in cached_data["chunks"]:
-                        yield chunk
+                        yield self._reconstruct_object(chunk)
 
                 return replay_generator()
 
     def _create_cache_key(self, args, kwargs):
         """Create a hashable key from the function arguments, ignoring auto-generated IDs."""
-        # Convert args and kwargs to a string representation directly
-        args_str = str(args)
-        kwargs_str = str(sorted([(k, kwargs[k]) for k in kwargs]))
-
-        # Combine into a single key
-        key = f"{args_str}_{kwargs_str}"
+        # Convert to JSON strings with sorted keys
+        key = json.dumps((args, kwargs), sort_keys=True, default=self._json_default)
 
         # Post-process the key with regex to replace IDs with placeholders
         # Replace UUIDs and similar patterns
@@ -126,83 +124,95 @@
         return key
 
     def _save_cache(self):
-        """Save the cache to disk in both pickle and JSON formats."""
-        os.makedirs(self.pickle_path.parent, exist_ok=True)
+        """Save the cache to disk in JSON format."""
+        os.makedirs(self.json_path.parent, exist_ok=True)
 
-        # Save as pickle for exact object preservation
-        with open(self.pickle_path, "wb") as f:
-            pickle.dump(self.cache, f)
-
-        # Also save as JSON for human readability and diffing
+        # Write the JSON file with pretty formatting
         try:
-            # Create a simplified version of the cache for JSON
-            json_cache = {}
-            for key, value in self.cache.items():
-                if value.get("type") == "generator":
-                    # For generators, create a simplified representation of each chunk
-                    chunks = []
-                    for chunk in value["chunks"]:
-                        chunk_dict = self._object_to_json_safe_dict(chunk)
-                        chunks.append(chunk_dict)
-                    json_cache[key] = {"type": "generator", "chunks": chunks}
-                else:
-                    # For values, create a simplified representation
-                    val = value["value"]
-                    val_dict = self._object_to_json_safe_dict(val)
-                    json_cache[key] = {"type": "value", "value": val_dict}
-
-            # Write the JSON file with pretty formatting
             with open(self.json_path, "w") as f:
-                json.dump(json_cache, f, indent=2, sort_keys=True)
+                json.dump(self.cache, f, indent=2, sort_keys=True, default=self._json_default)
+                # write another empty line at the end of the file to make pre-commit happy
+                f.write("\n")
         except Exception as e:
             print(f"Error saving JSON cache: {e}")
 
-    def _object_to_json_safe_dict(self, obj):
-        """Convert an object to a JSON-safe dictionary."""
-        # Handle enum types
-        if hasattr(obj, "value") and hasattr(obj.__class__, "__members__"):
-            return {"__enum__": obj.__class__.__name__, "value": obj.value}
+    def _json_default(self, obj):
+        """Default function for JSON serialization of objects."""
+        if isinstance(obj, datetime):
+            return {
+                "__datetime__": obj.isoformat(),
+                "__module__": obj.__class__.__module__,
+                "__class__": obj.__class__.__name__,
+            }
+
+        if isinstance(obj, Enum):
+            return {
+                "__enum__": obj.__class__.__name__,
+                "value": obj.value,
+                "__module__": obj.__class__.__module__,
+            }
 
         # Handle Pydantic models
         if hasattr(obj, "model_dump"):
-            return self._process_dict(obj.model_dump())
-        elif hasattr(obj, "dict"):
-            return self._process_dict(obj.dict())
+            model_data = obj.model_dump()
+            return {
+                "__pydantic__": obj.__class__.__name__,
+                "__module__": obj.__class__.__module__,
+                "data": model_data,
+            }
 
-        # Handle regular objects with __dict__
-        try:
-            return self._process_dict(vars(obj))
-        except Exception as e:
-            print(f"Error converting object to JSON-safe dict: {e}")
-            # If we can't get a dict, convert to string
-            return str(obj)
+    def _reconstruct_object(self, data):
+        """Reconstruct an object from its JSON representation."""
+        if isinstance(data, dict):
+            # Check if this is a serialized datetime
+            if "__datetime__" in data:
+                try:
+                    module_name = data.get("__module__", "datetime")
+                    class_name = data.get("__class__", "datetime")
 
-    def _process_dict(self, d):
-        """Process a dictionary to make all values JSON-safe."""
-        if not isinstance(d, dict):
-            return d
+                    # Try to import the specific datetime class
+                    module = importlib.import_module(module_name)
+                    dt_class = getattr(module, class_name)
 
-        result = {}
-        for k, v in d.items():
-            if isinstance(v, dict):
-                result[k] = self._process_dict(v)
-            elif isinstance(v, list):
-                result[k] = [
-                    self._process_dict(item)
-                    if isinstance(item, dict)
-                    else self._object_to_json_safe_dict(item)
-                    if hasattr(item, "__dict__")
-                    else item
-                    for item in v
-                ]
-            elif hasattr(v, "value") and hasattr(v.__class__, "__members__"):
-                # Handle enum
-                result[k] = {"__enum__": v.__class__.__name__, "value": v.value}
-            elif hasattr(v, "__dict__"):
-                # Handle nested objects
-                result[k] = self._object_to_json_safe_dict(v)
-            else:
-                # Basic types
-                result[k] = v
+                    # Parse the ISO format string
+                    dt = dt_class.fromisoformat(data["__datetime__"])
+                    return dt
+                except (ImportError, AttributeError, ValueError) as e:
+                    print(f"Error reconstructing datetime: {e}")
+                    return data
 
-        return result
+            # Check if this is a serialized enum
+            elif "__enum__" in data:
+                try:
+                    module_name = data.get("__module__", "builtins")
+                    enum_class = self._import_class(module_name, data["__enum__"])
+                    return enum_class(data["value"])
+                except (ImportError, AttributeError) as e:
+                    print(f"Error reconstructing enum: {e}")
+                    return data
+
+            # Check if this is a serialized Pydantic model
+            elif "__pydantic__" in data:
+                try:
+                    module_name = data.get("__module__", "builtins")
+                    model_class = self._import_class(module_name, data["__pydantic__"])
+                    return model_class(**self._reconstruct_object(data["data"]))
+                except (ImportError, AttributeError) as e:
+                    print(f"Error reconstructing Pydantic model: {e}")
+                    return data
+
+            # Regular dictionary
+            return {k: self._reconstruct_object(v) for k, v in data.items()}
+
+        # Handle lists
+        elif isinstance(data, list):
+            return [self._reconstruct_object(item) for item in data]
+
+        # Return primitive types as is
+        return data
+
+    def _import_class(self, module_name, class_name):
+        """Import a class from a module."""
+        module = __import__(module_name, fromlist=[class_name])
+        return getattr(module, class_name)
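The pattern the new code relies on, shown here as a standalone sketch rather than the project's actual API: a `default=` hook tags non-JSON types on the way out, and a recursive walk rebuilds them from those tags on the way back in. The Color enum and field names below are invented for the example.

import json
from datetime import datetime
from enum import Enum


class Color(Enum):
    RED = "red"


def to_jsonable(obj):
    # Mirror of the tagging approach: emit a marker dict instead of failing.
    if isinstance(obj, datetime):
        return {"__datetime__": obj.isoformat()}
    if isinstance(obj, Enum):
        return {"__enum__": type(obj).__name__, "value": obj.value}
    raise TypeError(f"not JSON serializable: {type(obj)!r}")


def rebuild(data):
    # Walk the decoded JSON and replace marker dicts with real objects.
    if isinstance(data, dict):
        if "__datetime__" in data:
            return datetime.fromisoformat(data["__datetime__"])
        if data.get("__enum__") == "Color":
            return Color(data["value"])
        return {k: rebuild(v) for k, v in data.items()}
    if isinstance(data, list):
        return [rebuild(item) for item in data]
    return data


payload = {"when": datetime(2025, 3, 6, 12, 0), "color": Color.RED}
text = json.dumps(payload, default=to_jsonable, sort_keys=True)
assert rebuild(json.loads(text)) == payload

Because every recorded object is reduced to tagged JSON like this, the recordings stay human-readable and diffable, and unpickling untrusted files is no longer needed.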
File diff suppressed because one or more lines are too long
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.