chore: remove recordable mock (#2088)

# What does this PR do?
We've disabled it for a while given that this hasn't worked as well as
expected given the frequent changes of llama_stack_client and how this
requires both repos to be in sync.

## Test Plan

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
ehhuang 2025-05-05 10:08:55 -07:00 committed by GitHub
parent a5d151e912
commit 4597145011
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 36 additions and 57965 deletions

View file

@ -56,8 +56,8 @@ def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> d
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def agent_config(llama_stack_client_with_mocked_inference, text_model_id): def agent_config(llama_stack_client, text_model_id):
available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()] available_shields = [shield.identifier for shield in llama_stack_client.shields.list()]
available_shields = available_shields[:1] available_shields = available_shields[:1]
agent_config = dict( agent_config = dict(
model=text_model_id, model=text_model_id,
@ -77,8 +77,8 @@ def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
return agent_config return agent_config
def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config): def test_agent_simple(llama_stack_client, agent_config):
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config) agent = Agent(llama_stack_client, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}") session_id = agent.create_session(f"test-session-{uuid4()}")
simple_hello = agent.create_turn( simple_hello = agent.create_turn(
@ -179,7 +179,7 @@ def test_agent_name(llama_stack_client, text_model_id):
assert "hello" in agent_logs[0]["output"].lower() assert "hello" in agent_logs[0]["output"].lower()
def test_tool_config(llama_stack_client_with_mocked_inference, agent_config): def test_tool_config(agent_config):
common_params = dict( common_params = dict(
model="meta-llama/Llama-3.2-3B-Instruct", model="meta-llama/Llama-3.2-3B-Instruct",
instructions="You are a helpful assistant", instructions="You are a helpful assistant",
@ -235,7 +235,7 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
Server__AgentConfig(**agent_config) Server__AgentConfig(**agent_config)
def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config): def test_builtin_tool_web_search(llama_stack_client, agent_config):
agent_config = { agent_config = {
**agent_config, **agent_config,
"instructions": "You are a helpful assistant that can use web search to answer questions.", "instructions": "You are a helpful assistant that can use web search to answer questions.",
@ -243,7 +243,7 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
"builtin::websearch", "builtin::websearch",
], ],
} }
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config) agent = Agent(llama_stack_client, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}") session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn( response = agent.create_turn(
@ -266,14 +266,14 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
assert found_tool_execution assert found_tool_execution
def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config): def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent_config = { agent_config = {
**agent_config, **agent_config,
"tools": [ "tools": [
"builtin::code_interpreter", "builtin::code_interpreter",
], ],
} }
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config) agent = Agent(llama_stack_client, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}") session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn( response = agent.create_turn(
@ -296,7 +296,7 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, a
# server, this means the _server_ must have `bwrap` available. If you are using library client, then # server, this means the _server_ must have `bwrap` available. If you are using library client, then
# you must have `bwrap` available in test's environment. # you must have `bwrap` available in test's environment.
@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack") @pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config): def test_code_interpreter_for_attachments(llama_stack_client, agent_config):
agent_config = { agent_config = {
**agent_config, **agent_config,
"tools": [ "tools": [
@ -304,7 +304,7 @@ def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inferen
], ],
} }
codex_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config) codex_agent = Agent(llama_stack_client, **agent_config)
session_id = codex_agent.create_session(f"test-session-{uuid4()}") session_id = codex_agent.create_session(f"test-session-{uuid4()}")
inflation_doc = Document( inflation_doc = Document(
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv", content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@ -332,14 +332,14 @@ def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inferen
assert "Tool:code_interpreter" in logs_str assert "Tool:code_interpreter" in logs_str
def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config): def test_custom_tool(llama_stack_client, agent_config):
client_tool = get_boiling_point client_tool = get_boiling_point
agent_config = { agent_config = {
**agent_config, **agent_config,
"tools": [client_tool], "tools": [client_tool],
} }
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config) agent = Agent(llama_stack_client, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}") session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn( response = agent.create_turn(
@ -358,7 +358,7 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
assert "get_boiling_point" in logs_str assert "get_boiling_point" in logs_str
def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, agent_config): def test_custom_tool_infinite_loop(llama_stack_client, agent_config):
client_tool = get_boiling_point client_tool = get_boiling_point
agent_config = { agent_config = {
**agent_config, **agent_config,
@ -367,7 +367,7 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, age
"max_infer_iters": 5, "max_infer_iters": 5,
} }
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config) agent = Agent(llama_stack_client, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}") session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn( response = agent.create_turn(
@ -385,25 +385,21 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, age
assert num_tool_calls <= 5 assert num_tool_calls <= 5
def test_tool_choice_required(llama_stack_client_with_mocked_inference, agent_config): def test_tool_choice_required(llama_stack_client, agent_config):
tool_execution_steps = run_agent_with_tool_choice( tool_execution_steps = run_agent_with_tool_choice(llama_stack_client, agent_config, "required")
llama_stack_client_with_mocked_inference, agent_config, "required"
)
assert len(tool_execution_steps) > 0 assert len(tool_execution_steps) > 0
def test_tool_choice_none(llama_stack_client_with_mocked_inference, agent_config): def test_tool_choice_none(llama_stack_client, agent_config):
tool_execution_steps = run_agent_with_tool_choice(llama_stack_client_with_mocked_inference, agent_config, "none") tool_execution_steps = run_agent_with_tool_choice(llama_stack_client, agent_config, "none")
assert len(tool_execution_steps) == 0 assert len(tool_execution_steps) == 0
def test_tool_choice_get_boiling_point(llama_stack_client_with_mocked_inference, agent_config): def test_tool_choice_get_boiling_point(llama_stack_client, agent_config):
if "llama" not in agent_config["model"].lower(): if "llama" not in agent_config["model"].lower():
pytest.xfail("NotImplemented for non-llama models") pytest.xfail("NotImplemented for non-llama models")
tool_execution_steps = run_agent_with_tool_choice( tool_execution_steps = run_agent_with_tool_choice(llama_stack_client, agent_config, "get_boiling_point")
llama_stack_client_with_mocked_inference, agent_config, "get_boiling_point"
)
assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point" assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
@ -435,7 +431,7 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"]) @pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name): def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"] urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [ documents = [
Document( Document(
@ -447,12 +443,12 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
for i, url in enumerate(urls) for i, url in enumerate(urls)
] ]
vector_db_id = f"test-vector-db-{uuid4()}" vector_db_id = f"test-vector-db-{uuid4()}"
llama_stack_client_with_mocked_inference.vector_dbs.register( llama_stack_client.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384, embedding_dimension=384,
) )
llama_stack_client_with_mocked_inference.tool_runtime.rag_tool.insert( llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents, documents=documents,
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
# small chunks help to get specific info out of the docs # small chunks help to get specific info out of the docs
@ -469,7 +465,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
) )
], ],
} }
rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config) rag_agent = Agent(llama_stack_client, **agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}") session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [ user_prompts = [
( (
@ -494,7 +490,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
assert expected_kw in response.output_message.content.lower() assert expected_kw in response.output_message.content.lower()
def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config): def test_rag_agent_with_attachments(llama_stack_client, agent_config):
urls = ["llama3.rst", "lora_finetune.rst"] urls = ["llama3.rst", "lora_finetune.rst"]
documents = [ documents = [
# passign as url # passign as url
@ -517,7 +513,7 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
metadata={}, metadata={},
), ),
] ]
rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config) rag_agent = Agent(llama_stack_client, **agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}") session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [ user_prompts = [
( (
@ -553,7 +549,7 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack") @pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_config): def test_rag_and_code_agent(llama_stack_client, agent_config):
if "llama-4" in agent_config["model"].lower(): if "llama-4" in agent_config["model"].lower():
pytest.xfail("Not working for llama4") pytest.xfail("Not working for llama4")
@ -578,12 +574,12 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
) )
) )
vector_db_id = f"test-vector-db-{uuid4()}" vector_db_id = f"test-vector-db-{uuid4()}"
llama_stack_client_with_mocked_inference.vector_dbs.register( llama_stack_client.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384, embedding_dimension=384,
) )
llama_stack_client_with_mocked_inference.tool_runtime.rag_tool.insert( llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents, documents=documents,
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
chunk_size_in_tokens=128, chunk_size_in_tokens=128,
@ -598,7 +594,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
"builtin::code_interpreter", "builtin::code_interpreter",
], ],
} }
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config) agent = Agent(llama_stack_client, **agent_config)
user_prompts = [ user_prompts = [
( (
"when was Perplexity the company founded?", "when was Perplexity the company founded?",
@ -632,7 +628,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
"client_tools", "client_tools",
[(get_boiling_point, False), (get_boiling_point_with_metadata, True)], [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
) )
def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config, client_tools): def test_create_turn_response(llama_stack_client, agent_config, client_tools):
client_tool, expects_metadata = client_tools client_tool, expects_metadata = client_tools
agent_config = { agent_config = {
**agent_config, **agent_config,
@ -641,7 +637,7 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_co
"tools": [client_tool], "tools": [client_tool],
} }
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config) agent = Agent(llama_stack_client, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}") session_id = agent.create_session(f"test-session-{uuid4()}")
input_prompt = f"Call {client_tools[0].__name__} tool and answer What is the boiling point of polyjuice?" input_prompt = f"Call {client_tools[0].__name__} tool and answer What is the boiling point of polyjuice?"
@ -677,7 +673,7 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_co
last_step_completed_at = step.completed_at last_step_completed_at = step.completed_at
def test_multi_tool_calls(llama_stack_client_with_mocked_inference, agent_config): def test_multi_tool_calls(llama_stack_client, agent_config):
if "gpt" not in agent_config["model"]: if "gpt" not in agent_config["model"]:
pytest.xfail("Only tested on GPT models") pytest.xfail("Only tested on GPT models")
@ -686,7 +682,7 @@ def test_multi_tool_calls(llama_stack_client_with_mocked_inference, agent_config
"tools": [get_boiling_point], "tools": [get_boiling_point],
} }
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config) agent = Agent(llama_stack_client, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}") session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn( response = agent.create_turn(

View file

@ -4,12 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import copy
import inspect import inspect
import logging
import os import os
import tempfile import tempfile
from pathlib import Path
import pytest import pytest
import yaml import yaml
@ -17,12 +14,9 @@ from llama_stack_client import LlamaStackClient
from openai import OpenAI from openai import OpenAI
from llama_stack import LlamaStackAsLibraryClient from llama_stack import LlamaStackAsLibraryClient
from llama_stack.apis.datatypes import Api
from llama_stack.distribution.stack import run_config_from_adhoc_config_spec from llama_stack.distribution.stack import run_config_from_adhoc_config_spec
from llama_stack.env import get_env_or_fail from llama_stack.env import get_env_or_fail
from .recordable_mock import RecordableMock
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def provider_data(): def provider_data():
@ -46,63 +40,6 @@ def provider_data():
return provider_data return provider_data
@pytest.fixture(scope="session")
def llama_stack_client_with_mocked_inference(llama_stack_client, request):
"""
Returns a client with mocked inference APIs and tool runtime APIs that use recorded responses by default.
If --record-responses is passed, it will call the real APIs and record the responses.
"""
# TODO: will rework this to be more stable
return llama_stack_client
if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
logging.warning(
"llama_stack_client_with_mocked_inference is not supported for this client, returning original client without mocking"
)
return llama_stack_client
record_responses = request.config.getoption("--record-responses")
cache_dir = Path(__file__).parent / "recorded_responses"
# Create a shallow copy of the client to avoid modifying the original
client = copy.copy(llama_stack_client)
# Get the inference API used by the agents implementation
agents_impl = client.async_client.impls[Api.agents]
original_inference = agents_impl.inference_api
# Create a new inference object with the same attributes
inference_mock = copy.copy(original_inference)
# Replace the methods with recordable mocks
inference_mock.chat_completion = RecordableMock(
original_inference.chat_completion, cache_dir, "chat_completion", record=record_responses
)
inference_mock.completion = RecordableMock(
original_inference.completion, cache_dir, "text_completion", record=record_responses
)
inference_mock.embeddings = RecordableMock(
original_inference.embeddings, cache_dir, "embeddings", record=record_responses
)
# Replace the inference API in the agents implementation
agents_impl.inference_api = inference_mock
original_tool_runtime_api = agents_impl.tool_runtime_api
tool_runtime_mock = copy.copy(original_tool_runtime_api)
# Replace the methods with recordable mocks
tool_runtime_mock.invoke_tool = RecordableMock(
original_tool_runtime_api.invoke_tool, cache_dir, "invoke_tool", record=record_responses
)
agents_impl.tool_runtime_api = tool_runtime_mock
# Also update the client.inference for consistency
client.inference = inference_mock
return client
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def inference_provider_type(llama_stack_client): def inference_provider_type(llama_stack_client):
providers = llama_stack_client.providers.list() providers = llama_stack_client.providers.list()
@ -177,7 +114,7 @@ def skip_if_no_model(request):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def llama_stack_client(request, provider_data, text_model_id): def llama_stack_client(request, provider_data):
config = request.config.getoption("--stack-config") config = request.config.getoption("--stack-config")
if not config: if not config:
config = get_env_or_fail("LLAMA_STACK_CONFIG") config = get_env_or_fail("LLAMA_STACK_CONFIG")

View file

@ -1,221 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import json
import os
import re
from datetime import datetime
from enum import Enum
from pathlib import Path
class RecordableMock:
"""A mock that can record and replay API responses."""
def __init__(self, real_func, cache_dir, func_name, record=False):
self.real_func = real_func
self.json_path = Path(cache_dir) / f"{func_name}.json"
self.record = record
self.cache = {}
# Load existing cache if available and not recording
if self.json_path.exists():
try:
with open(self.json_path) as f:
self.cache = json.load(f)
except Exception as e:
print(f"Error loading cache from {self.json_path}: {e}")
raise
async def __call__(self, *args, **kwargs):
"""
Returns a coroutine that when awaited returns the result or an async generator,
matching the behavior of the original function.
"""
# Create a cache key from the arguments
key = self._create_cache_key(args, kwargs)
if self.record:
# In record mode, always call the real function
real_result = self.real_func(*args, **kwargs)
# If it's a coroutine, we need to create a wrapper coroutine
if hasattr(real_result, "__await__"):
# Define a coroutine function that will record the result
async def record_coroutine():
try:
# Await the real coroutine
result = await real_result
# Check if the result is an async generator
if hasattr(result, "__aiter__"):
# It's an async generator, so we need to record its chunks
chunks = []
# Create and return a new async generator that records chunks
async def recording_generator():
nonlocal chunks
async for chunk in result:
chunks.append(chunk)
yield chunk
# After all chunks are yielded, save to cache
self.cache[key] = {"type": "generator", "chunks": chunks}
self._save_cache()
return recording_generator()
else:
# It's a regular result, save it to cache
self.cache[key] = {"type": "value", "value": result}
self._save_cache()
return result
except Exception as e:
print(f"Error in recording mode: {e}")
raise
return await record_coroutine()
else:
# It's already an async generator, so we need to record its chunks
async def record_generator():
chunks = []
async for chunk in real_result:
chunks.append(chunk)
yield chunk
# After all chunks are yielded, save to cache
self.cache[key] = {"type": "generator", "chunks": chunks}
self._save_cache()
return record_generator()
elif key not in self.cache:
# In replay mode, if the key is not in the cache, throw an error
raise KeyError(
f"No cached response found for key: {key}\nRun with --record-responses to record this response."
)
else:
# In replay mode with a cached response
cached_data = self.cache[key]
# Check if it's a value or chunks
if cached_data.get("type") == "value":
# It's a regular value
return self._reconstruct_object(cached_data["value"])
else:
# It's chunks from an async generator
async def replay_generator():
for chunk in cached_data["chunks"]:
yield self._reconstruct_object(chunk)
return replay_generator()
def _create_cache_key(self, args, kwargs):
"""Create a hashable key from the function arguments, ignoring auto-generated IDs."""
# Convert to JSON strings with sorted keys
key = json.dumps((args, kwargs), sort_keys=True, default=self._json_default)
# Post-process the key with regex to replace IDs with placeholders
# Replace UUIDs and similar patterns
key = re.sub(r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", "<UUID>", key)
# Replace temporary file paths created by tempfile.mkdtemp()
key = re.sub(r"/var/folders/[^,'\"\s]+", "<TEMP_FILE>", key)
# Replace /tmp/ paths which are also commonly used for temporary files
key = re.sub(r"/tmp/[^,'\"\s]+", "<TEMP_FILE>", key)
return key
def _save_cache(self):
"""Save the cache to disk in JSON format."""
os.makedirs(self.json_path.parent, exist_ok=True)
# Write the JSON file with pretty formatting
try:
with open(self.json_path, "w") as f:
json.dump(self.cache, f, indent=2, sort_keys=True, default=self._json_default)
# write another empty line at the end of the file to make pre-commit happy
f.write("\n")
except Exception as e:
print(f"Error saving JSON cache: {e}")
def _json_default(self, obj):
"""Default function for JSON serialization of objects."""
if isinstance(obj, datetime):
return {
"__datetime__": obj.isoformat(),
"__module__": obj.__class__.__module__,
"__class__": obj.__class__.__name__,
}
if isinstance(obj, Enum):
return {
"__enum__": obj.__class__.__name__,
"value": obj.value,
"__module__": obj.__class__.__module__,
}
# Handle Pydantic models
if hasattr(obj, "model_dump"):
model_data = obj.model_dump()
return {
"__pydantic__": obj.__class__.__name__,
"__module__": obj.__class__.__module__,
"data": model_data,
}
def _reconstruct_object(self, data):
"""Reconstruct an object from its JSON representation."""
if isinstance(data, dict):
# Check if this is a serialized datetime
if "__datetime__" in data:
try:
module_name = data.get("__module__", "datetime")
class_name = data.get("__class__", "datetime")
# Try to import the specific datetime class
module = importlib.import_module(module_name)
dt_class = getattr(module, class_name)
# Parse the ISO format string
dt = dt_class.fromisoformat(data["__datetime__"])
return dt
except (ImportError, AttributeError, ValueError) as e:
print(f"Error reconstructing datetime: {e}")
return data
# Check if this is a serialized enum
elif "__enum__" in data:
try:
module_name = data.get("__module__", "builtins")
enum_class = self._import_class(module_name, data["__enum__"])
return enum_class(data["value"])
except (ImportError, AttributeError) as e:
print(f"Error reconstructing enum: {e}")
return data
# Check if this is a serialized Pydantic model
elif "__pydantic__" in data:
try:
module_name = data.get("__module__", "builtins")
model_class = self._import_class(module_name, data["__pydantic__"])
return model_class(**self._reconstruct_object(data["data"]))
except (ImportError, AttributeError) as e:
print(f"Error reconstructing Pydantic model: {e}")
return data
# Regular dictionary
return {k: self._reconstruct_object(v) for k, v in data.items()}
# Handle lists
elif isinstance(data, list):
return [self._reconstruct_object(item) for item in data]
# Return primitive types as is
return data
def _import_class(self, module_name, class_name):
"""Import a class from a module."""
module = __import__(module_name, fromlist=[class_name])
return getattr(module, class_name)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long