more involved id mappings

This commit is contained in:
Ashwin Bharambe 2025-10-08 18:35:45 -07:00
parent ecea7ccdb8
commit 39aa17f975
15 changed files with 16151 additions and 4 deletions

View file

@ -9,6 +9,7 @@ from __future__ import annotations # for forward references
import hashlib
import json
import os
import re
from collections.abc import Generator
from contextlib import contextmanager
from enum import StrEnum
@ -134,6 +135,194 @@ def _normalize_file_ids(obj: Any) -> Any:
return obj
def _get_current_test_id_mappings() -> dict[str, str]:
    """Return the active test's ID mappings, inverted for denormalization.

    ``_id_normalizers`` stores, per test and per ID type,
    ``original_id -> normalized_id``. This returns the flattened inverse,
    ``normalized_id -> original_id`` (e.g. "file-1" -> "file-abc123...").
    An empty dict is returned when no test is active or the active test has
    no registered mappings.
    """
    active_test = _test_context.get()
    if active_test and active_test in _id_normalizers:
        # Flatten every per-type table; a later entry wins on a duplicate key.
        return {
            normalized: original
            for per_type in _id_normalizers[active_test].values()
            for original, normalized in per_type.items()
        }
    return {}
def _chunk_text_content(chunk: Any) -> tuple[str | None, bool]:
"""Return (content, has_structured_fields) for OpenAI chat completion chunks."""
choices = getattr(chunk, "choices", None)
if not choices:
return None, False
delta = choices[0].delta
content = getattr(delta, "content", None)
if not content:
return None, False
has_structured = bool(getattr(delta, "tool_calls", None) or getattr(delta, "function_call", None))
return content, has_structured
def _chunk_with_content(chunk: Any, content: str) -> Any:
"""Return a copy of the chunk with delta.content replaced by the provided string."""
choices = getattr(chunk, "choices", None)
if not choices:
return chunk
updated_choices = []
for choice in choices:
delta = choice.delta
if getattr(delta, "content", None) is not None:
new_delta = delta.model_copy(update={"content": content})
updated_choices.append(choice.model_copy(update={"delta": new_delta}))
else:
updated_choices.append(choice)
return chunk.model_copy(update={"choices": updated_choices})
def _ends_with_partial_identifier(text: str) -> bool:
"""Return True if text ends in an incomplete file identifier."""
match = re.search(r"(?:<\|)?file-[A-Za-z0-9_-]*$", text)
if not match:
return False
token = match.group()
enclosed = token.startswith("<|")
if enclosed and not token.endswith("|>"):
return True
if enclosed:
core = token[2:-2] if token.endswith("|>") else token[2:]
else:
core = token
suffix = core[len("file-"):]
if len(suffix) < 16:
return True
if not re.fullmatch(r"[A-Za-z0-9_-]+", suffix):
return True
return False
def _has_safe_boundary(text: str) -> bool:
if not text:
return False
last_char = text[-1]
if last_char.isspace():
return True
return last_char in ".,?!;:)]}>\"'"
def _coalesce_streaming_chunks(chunks: list[Any]) -> list[Any]:
"""Merge adjacent text chunks to avoid breaking identifiers across boundaries."""
result: list[Any] = []
pending_chunk: Any | None = None
pending_content = ""
for chunk in chunks:
content, has_structured = _chunk_text_content(chunk)
if content is None or has_structured:
if pending_chunk is not None:
result.append(_chunk_with_content(pending_chunk, pending_content))
pending_chunk = None
pending_content = ""
result.append(chunk)
continue
if pending_chunk is None:
pending_chunk = chunk
pending_content = content
else:
pending_content += content
if (not _ends_with_partial_identifier(pending_content)) and _has_safe_boundary(pending_content):
result.append(_chunk_with_content(pending_chunk, pending_content))
pending_chunk = None
pending_content = ""
if pending_chunk is not None:
result.append(_chunk_with_content(pending_chunk, pending_content))
return result
def _denormalize_response(obj: Any, recorded_mapping: dict[str, str], current_mapping: dict[str, str]) -> Any:
"""Replace recorded IDs with current runtime IDs in a response object.
Args:
obj: The response object to denormalize
recorded_mapping: normalized_id -> recorded_original_id (from recording file)
current_mapping: normalized_id -> current_original_id (from current runtime)
Returns:
The response with all recorded IDs replaced with current IDs
"""
import re
# Build reverse mapping: recorded_original_id -> current_original_id
# via the normalized intermediary
id_translation = {}
for normalized_id, recorded_id in recorded_mapping.items():
if normalized_id in current_mapping:
current_id = current_mapping[normalized_id]
id_translation[recorded_id] = current_id
if isinstance(obj, dict):
result = {}
for k, v in obj.items():
# Translate document_id fields
if k == "document_id" and isinstance(v, str):
result[k] = id_translation.get(v, v)
# Translate vector database/store IDs
elif k in ("vector_db_id", "vector_store_id", "bank_id") and isinstance(v, str):
# Replace any recorded IDs in the value
translated = v
for recorded_id, current_id in id_translation.items():
translated = translated.replace(recorded_id, current_id)
result[k] = translated
else:
result[k] = _denormalize_response(v, recorded_mapping, current_mapping)
return result
elif isinstance(obj, list):
return [_denormalize_response(item, recorded_mapping, current_mapping) for item in obj]
elif hasattr(obj, "model_dump"):
# Handle Pydantic/BaseModel instances by denormalizing their dict form
data = obj.model_dump(mode="python")
denormalized = _denormalize_response(data, recorded_mapping, current_mapping)
cls = obj.__class__
try:
return cls.model_validate(denormalized)
except Exception:
try:
return cls.model_construct(**denormalized)
except Exception:
return denormalized
elif isinstance(obj, str):
# Replace file-<uuid> patterns in strings (like in text content and citations)
translated = obj
for recorded_id, current_id in id_translation.items():
# Handle both bare file IDs and citation format <|file-...|>
translated = translated.replace(recorded_id, current_id)
return translated
else:
return obj
def normalize_inference_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
"""Create a normalized hash of the request for consistent matching.
@ -438,6 +627,7 @@ class ResponseStorage:
"test_id": _test_context.get(),
"request": request,
"response": serialized_response,
"id_normalization_mapping": _get_current_test_id_mappings(),
},
f,
indent=2,
@ -569,7 +759,7 @@ async def _patched_tool_invoke_method(
original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
):
"""Patched version of tool runtime invoke_tool method for recording/replay."""
global _current_mode, _current_storage
global _current_mode, _current_storage, _id_normalizers
if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
# Normal operation
@ -618,7 +808,7 @@ async def _patched_tool_invoke_method(
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
global _current_mode, _current_storage
global _current_mode, _current_storage, _id_normalizers
mode = _current_mode
storage = _current_storage
@ -671,6 +861,48 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
if recording:
response_body = recording["response"]["body"]
if recording["response"].get("is_streaming", False) and isinstance(response_body, list):
response_body = _coalesce_streaming_chunks(response_body)
recording["response"]["body"] = response_body
# Denormalize the response: replace recorded IDs with current runtime IDs
recorded_mapping = recording.get("id_normalization_mapping", {})
current_mapping = _get_current_test_id_mappings()
if recorded_mapping or current_mapping:
import sys
print(f"\n=== DENORM DEBUG ===", file=sys.stderr)
print(f"Recorded mapping: {recorded_mapping}", file=sys.stderr)
print(f"Current mapping: {current_mapping}", file=sys.stderr)
print(f"_id_normalizers before: {_id_normalizers.get(_test_context.get(), {})}", file=sys.stderr)
response_body = _denormalize_response(response_body, recorded_mapping, current_mapping)
# Update _id_normalizers to register the current IDs with their normalized values
# This ensures that if these IDs appear in future requests, they get the same
# normalized value as they had in the recording
test_id = _test_context.get()
if test_id and recorded_mapping:
for normalized_id, recorded_id in recorded_mapping.items():
if normalized_id in current_mapping:
current_id = current_mapping[normalized_id]
# Extract ID type from normalized_id (e.g., "file-1" -> "file")
id_type = normalized_id.rsplit("-", 1)[0]
# Ensure structures exist
if test_id not in _id_normalizers:
_id_normalizers[test_id] = {}
if id_type not in _id_normalizers[test_id]:
_id_normalizers[test_id][id_type] = {}
# Register: current_id -> normalized_id
_id_normalizers[test_id][id_type][current_id] = normalized_id
print(f"Registered {current_id} -> {normalized_id}", file=sys.stderr)
print(f"_id_normalizers after: {_id_normalizers.get(_test_context.get(), {})}", file=sys.stderr)
print(f"=== END DENORM DEBUG ===\n", file=sys.stderr)
if recording["response"].get("is_streaming", False):
async def replay_stream():
@ -714,9 +946,11 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
if is_streaming:
# For streaming responses, we need to collect all chunks immediately before yielding
# This ensures the recording is saved even if the generator isn't fully consumed
chunks = []
raw_chunks: list[Any] = []
async for chunk in response:
chunks.append(chunk)
raw_chunks.append(chunk)
chunks = _coalesce_streaming_chunks(raw_chunks)
# Store the recording immediately
response_data = {"body": chunks, "is_streaming": True}