This commit is contained in:
Ashwin Bharambe 2025-10-03 17:17:13 -07:00
parent 4e512aec8a
commit 2bb8055646
2 changed files with 0 additions and 65 deletions

View file

@ -80,31 +80,6 @@ def setup_inference_recording():
return inference_recording(mode=mode, storage_dir=storage_dir)
def _normalize_tool_call_ids(obj: Any, request_hash: str, counter: dict[str, int]) -> Any:
"""Recursively normalize tool call IDs in an object structure."""
if isinstance(obj, dict):
# Normalize tool_calls array
if "tool_calls" in obj and isinstance(obj["tool_calls"], list):
for tool_call in obj["tool_calls"]:
if isinstance(tool_call, dict) and "id" in tool_call:
# Generate deterministic tool call ID
tool_call["id"] = f"toolcall-{request_hash[:8]}-{counter['count']}"
counter["count"] += 1
# Recurse into nested structures
_normalize_tool_call_ids(tool_call, request_hash, counter)
# Recurse into all dict values
for key, value in obj.items():
if key != "tool_calls": # Already handled above
obj[key] = _normalize_tool_call_ids(value, request_hash, counter)
elif isinstance(obj, list):
# Recurse into list items
return [_normalize_tool_call_ids(item, request_hash, counter) for item in obj]
return obj
def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[str, Any]:
"""Normalize fields that change between recordings but don't affect functionality.
@ -134,10 +109,6 @@ def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[st
if "eval_duration" in data and data["eval_duration"] is not None:
data["eval_duration"] = 0
# Normalize tool call IDs in delta/choices (for streaming responses)
counter = {"count": 0}
_normalize_tool_call_ids(data, request_hash, counter)
return data

View file

@ -21,38 +21,6 @@ import json
from pathlib import Path
def normalize_tool_call_ids(obj, request_hash: str, counter: dict) -> None:
    """Recursively normalize tool call IDs in an object structure, in place.

    Two kinds of fields are rewritten to ``toolcall-<request_hash[:8]>-<n>``:

    * ``id`` of every dict inside a ``tool_calls`` list, numbered by
      ``counter["count"]``;
    * ``tool_call_id`` on any dict. If its original value matches a tool call
      ID already rewritten during this traversal, the same normalized ID is
      reused so the reference stays intact. Otherwise it falls back to a
      positional number taken from ``counter["tool_response_count"]``.

    Args:
        obj: Structure to walk; dicts and lists are traversed and mutated,
            other values are ignored.
        request_hash: Hash whose first eight characters go into each new ID.
        counter: Mutable state shared across the recursion. Must contain
            ``"count"`` and ``"tool_response_count"``. An ``"id_map"`` entry
            (original ID -> normalized ID) is added on first use.
    """
    if isinstance(obj, list):
        # Recurse into list items
        for item in obj:
            normalize_tool_call_ids(item, request_hash, counter)
        return

    if not isinstance(obj, dict):
        return

    # Kept inside ``counter`` so every recursive call sees the same mapping
    # without changing the function signature.
    id_map = counter.setdefault("id_map", {})

    # Normalize tool_calls array
    tool_calls = obj.get("tool_calls")
    if isinstance(tool_calls, list):
        for tool_call in tool_calls:
            if isinstance(tool_call, dict) and "id" in tool_call:
                original_id = tool_call["id"]
                # Generate deterministic tool call ID
                normalized_id = f"toolcall-{request_hash[:8]}-{counter['count']}"
                counter["count"] += 1
                tool_call["id"] = normalized_id
                # Only string IDs are remembered: the value may be None or
                # otherwise unusable as a lookup key. First occurrence wins.
                if isinstance(original_id, str):
                    id_map.setdefault(original_id, normalized_id)
            # Recurse into nested structures
            normalize_tool_call_ids(tool_call, request_hash, counter)

    # Normalize tool_call_id field (used in tool response messages)
    if "tool_call_id" in obj:
        original_ref = obj["tool_call_id"]
        # The positional ID is always computed and the counter always
        # advanced, so references with no known tool call get exactly the
        # number they got before the mapping existed.
        positional_id = f"toolcall-{request_hash[:8]}-{counter['tool_response_count']}"
        counter["tool_response_count"] += 1
        if isinstance(original_ref, str):
            obj["tool_call_id"] = id_map.get(original_ref, positional_id)
        else:
            obj["tool_call_id"] = positional_id

    # Recurse into all dict values
    for key, value in obj.items():
        if key not in ("tool_calls", "tool_call_id"):  # Already handled above
            normalize_tool_call_ids(value, request_hash, counter)
def normalize_response_data(data: dict, request_hash: str) -> dict:
"""Normalize fields that change between recordings but don't affect functionality."""
# Only normalize ID for completion/chat responses, not for model objects
@ -79,10 +47,6 @@ def normalize_response_data(data: dict, request_hash: str) -> dict:
if "eval_duration" in data and data["eval_duration"] is not None:
data["eval_duration"] = 0
# Normalize tool call IDs
counter = {"count": 0, "tool_response_count": 0}
normalize_tool_call_ids(data, request_hash, counter)
return data