mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 20:12:33 +00:00
simpler ID generation to remove denormalization
This commit is contained in:
parent
39aa17f975
commit
b47bf340db
5 changed files with 190 additions and 202 deletions
44
llama_stack/core/id_generation.py
Normal file
44
llama_stack/core/id_generation.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
IdFactory = Callable[[], str]
|
||||
IdOverride = Callable[[str, IdFactory], str]
|
||||
|
||||
_id_override: IdOverride | None = None
|
||||
|
||||
|
||||
def generate_object_id(kind: str, factory: IdFactory) -> str:
|
||||
"""Generate an identifier for the given kind using the provided factory.
|
||||
|
||||
Allows tests to override ID generation deterministically by installing an
|
||||
override callback via :func:`set_id_override`.
|
||||
"""
|
||||
|
||||
override = _id_override
|
||||
if override is not None:
|
||||
return override(kind, factory)
|
||||
return factory()
|
||||
|
||||
|
||||
def set_id_override(override: IdOverride) -> IdOverride | None:
|
||||
"""Install an override used to generate deterministic identifiers."""
|
||||
|
||||
global _id_override
|
||||
|
||||
previous = _id_override
|
||||
_id_override = override
|
||||
return previous
|
||||
|
||||
|
||||
def reset_id_override(token: IdOverride | None) -> None:
|
||||
"""Restore the previous override returned by :func:`set_id_override`."""
|
||||
|
||||
global _id_override
|
||||
_id_override = token
|
||||
|
|
@ -22,6 +22,7 @@ from llama_stack.apis.files import (
|
|||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
|
|
@ -65,7 +66,7 @@ class LocalfsFilesImpl(Files):
|
|||
|
||||
def _generate_file_id(self) -> str:
|
||||
"""Generate a unique file ID for OpenAI API."""
|
||||
return f"file-{uuid.uuid4().hex}"
|
||||
return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
|
||||
|
||||
def _get_file_path(self, file_id: str) -> Path:
|
||||
"""Get the filesystem path for a file ID."""
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.files import (
|
|||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
|
|
@ -198,7 +199,7 @@ class S3FilesImpl(Files):
|
|||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
|
||||
) -> OpenAIFileObject:
|
||||
file_id = f"file-{uuid.uuid4().hex}"
|
||||
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
|
||||
|
||||
filename = getattr(file, "filename", None) or "uploaded_file"
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
|
|
@ -352,7 +353,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
"""Creates a vector store."""
|
||||
created_at = int(time.time())
|
||||
# Derive the canonical vector_db_id (allow override, else generate)
|
||||
vector_db_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
|
||||
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
|
||||
|
||||
if provider_id is None:
|
||||
raise ValueError("Provider ID is required")
|
||||
|
|
@ -986,7 +987,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
|
||||
|
||||
created_at = int(time.time())
|
||||
batch_id = f"batch_{uuid.uuid4()}"
|
||||
batch_id = generate_object_id("vector_store_file_batch", lambda: f"batch_{uuid.uuid4()}")
|
||||
# File batches expire after 7 days
|
||||
expires_at = created_at + (7 * 24 * 60 * 60)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import hashlib
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
|
@ -18,6 +18,7 @@ from typing import Any, Literal, cast
|
|||
|
||||
from openai import NOT_GIVEN, OpenAI
|
||||
|
||||
from llama_stack.core.id_generation import reset_id_override, set_id_override
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="testing")
|
||||
|
|
@ -30,9 +31,8 @@ _current_mode: str | None = None
|
|||
_current_storage: ResponseStorage | None = None
|
||||
_original_methods: dict[str, Any] = {}
|
||||
|
||||
# ID normalization state: maps test_id -> (id_type -> {original_id: normalized_id})
|
||||
_id_normalizers: dict[str, dict[str, dict[str, str]]] = {}
|
||||
_id_counters: dict[str, dict[str, int]] = {} # test_id -> (id_type -> counter)
|
||||
# Per-test deterministic ID counters (test_id -> id_kind -> counter)
|
||||
_id_counters: dict[str, dict[str, int]] = {}
|
||||
|
||||
# Test context uses ContextVar since it changes per-test and needs async isolation
|
||||
from contextvars import ContextVar
|
||||
|
|
@ -56,104 +56,82 @@ class APIRecordingMode(StrEnum):
|
|||
RECORD_IF_MISSING = "record-if-missing"
|
||||
|
||||
|
||||
def _get_normalized_id(original_id: str, id_type: str) -> str:
|
||||
"""Get a normalized ID using a test-specific counter.
|
||||
_ID_KIND_PREFIXES: dict[str, str] = {
|
||||
"file": "file-",
|
||||
"vector_store": "vs_",
|
||||
"vector_store_file_batch": "batch_",
|
||||
}
|
||||
|
||||
Each unique ID within a test gets assigned a sequential number (file-1, file-2, uuid-1, etc).
|
||||
This ensures consistency across requests within the same test while keeping IDs human-readable.
|
||||
"""
|
||||
global _id_normalizers, _id_counters
|
||||
|
||||
def _allocate_test_scoped_id(kind: str) -> str | None:
|
||||
"""Return the next deterministic ID for the given kind within the current test."""
|
||||
|
||||
global _id_counters
|
||||
|
||||
test_id = _test_context.get()
|
||||
if not test_id:
|
||||
# No test context, return original ID
|
||||
return original_id
|
||||
prefix = _ID_KIND_PREFIXES.get(kind)
|
||||
|
||||
# Initialize structures for this test if needed
|
||||
if test_id not in _id_normalizers:
|
||||
_id_normalizers[test_id] = {}
|
||||
_id_counters[test_id] = {}
|
||||
if prefix is None:
|
||||
return None
|
||||
|
||||
if id_type not in _id_normalizers[test_id]:
|
||||
_id_normalizers[test_id][id_type] = {}
|
||||
_id_counters[test_id][id_type] = 0
|
||||
key = test_id or "__global__"
|
||||
|
||||
# Check if we've seen this ID before
|
||||
if original_id in _id_normalizers[test_id][id_type]:
|
||||
return _id_normalizers[test_id][id_type][original_id]
|
||||
if key not in _id_counters:
|
||||
_id_counters[key] = {}
|
||||
|
||||
# New ID - assign next counter value
|
||||
_id_counters[test_id][id_type] += 1
|
||||
counter = _id_counters[test_id][id_type]
|
||||
normalized_id = f"{id_type}-{counter}"
|
||||
counter = _id_counters[key].get(kind, 0) + 1
|
||||
_id_counters[key][kind] = counter
|
||||
|
||||
# Store mapping
|
||||
_id_normalizers[test_id][id_type][original_id] = normalized_id
|
||||
return normalized_id
|
||||
return f"{prefix}{counter}"
|
||||
|
||||
|
||||
def _normalize_file_ids(obj: Any) -> Any:
|
||||
"""Recursively replace file IDs and vector store IDs with test-specific normalized values.
|
||||
class _IdCanonicalizer:
|
||||
PATTERN = re.compile(r"(file-[A-Za-z0-9_-]+|vs_[A-Za-z0-9_-]+|batch_[A-Za-z0-9_-]+)")
|
||||
|
||||
Each unique file ID or UUID gets a unique normalized value using a sequential counter
|
||||
within each test (file-1, file-2, uuid-1, etc).
|
||||
"""
|
||||
import re
|
||||
def __init__(self) -> None:
|
||||
self._mappings: dict[str, dict[str, str]] = {kind: {} for kind in _ID_KIND_PREFIXES}
|
||||
self._counters: dict[str, int] = dict.fromkeys(_ID_KIND_PREFIXES, 0)
|
||||
|
||||
if isinstance(obj, dict):
|
||||
result = {}
|
||||
for k, v in obj.items():
|
||||
# Normalize file IDs in document_id fields
|
||||
if k == "document_id" and isinstance(v, str) and v.startswith("file-"):
|
||||
result[k] = _get_normalized_id(v, "file")
|
||||
# Normalize vector database/store IDs with UUID patterns
|
||||
elif k in ("vector_db_id", "vector_store_id", "bank_id") and isinstance(v, str):
|
||||
# Replace UUIDs in the ID deterministically
|
||||
def replace_uuid(match):
|
||||
uuid_val = match.group(0)
|
||||
return _get_normalized_id(uuid_val, "uuid")
|
||||
|
||||
normalized = re.sub(
|
||||
r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}",
|
||||
replace_uuid,
|
||||
v,
|
||||
)
|
||||
result[k] = normalized
|
||||
else:
|
||||
result[k] = _normalize_file_ids(v)
|
||||
return result
|
||||
elif isinstance(obj, list):
|
||||
return [_normalize_file_ids(item) for item in obj]
|
||||
elif isinstance(obj, str):
|
||||
# Replace file-<uuid> patterns in strings (like in text content)
|
||||
def replace_file_id(match):
|
||||
file_id = match.group(0)
|
||||
return _get_normalized_id(file_id, "file")
|
||||
|
||||
return re.sub(r"file-[a-f0-9]{32}", replace_file_id, obj)
|
||||
else:
|
||||
def canonicalize(self, obj: Any) -> Any:
|
||||
if isinstance(obj, dict):
|
||||
return {k: self._canonicalize_value(k, v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [self.canonicalize(item) for item in obj]
|
||||
if isinstance(obj, str):
|
||||
return self._canonicalize_string(obj)
|
||||
return obj
|
||||
|
||||
def _canonicalize_value(self, key: str, value: Any) -> Any:
|
||||
if key in {"vector_db_id", "vector_store_id", "bank_id"} and isinstance(value, str):
|
||||
return self._canonicalize_string(value)
|
||||
if key == "document_id" and isinstance(value, str) and value.startswith("file-"):
|
||||
return self._canonicalize_string(value)
|
||||
return self.canonicalize(value)
|
||||
|
||||
def _get_current_test_id_mappings() -> dict[str, str]:
|
||||
"""Get the current test's ID normalization mappings.
|
||||
def _canonicalize_string(self, value: str) -> str:
|
||||
def replace(match: re.Match[str]) -> str:
|
||||
token = match.group(0)
|
||||
if token.startswith("file-"):
|
||||
return self._mapped_value("file", token)
|
||||
if token.startswith("vs_"):
|
||||
return self._mapped_value("vector_store", token)
|
||||
if token.startswith("batch_"):
|
||||
return self._mapped_value("vector_store_file_batch", token)
|
||||
return token
|
||||
|
||||
Returns a dict mapping normalized_id -> original_id (e.g., "file-1" -> "file-abc123...")
|
||||
This is the inverse of what's stored in _id_normalizers.
|
||||
"""
|
||||
global _id_normalizers
|
||||
return self.PATTERN.sub(replace, value)
|
||||
|
||||
test_id = _test_context.get()
|
||||
if not test_id or test_id not in _id_normalizers:
|
||||
return {}
|
||||
def _mapped_value(self, kind: str, original: str) -> str:
|
||||
mapping = self._mappings[kind]
|
||||
if original not in mapping:
|
||||
self._counters[kind] += 1
|
||||
mapping[original] = f"{_ID_KIND_PREFIXES[kind]}{self._counters[kind]}"
|
||||
return mapping[original]
|
||||
|
||||
# Invert the mapping: normalized_id -> original_id
|
||||
result = {}
|
||||
for id_type, mappings in _id_normalizers[test_id].items():
|
||||
for original_id, normalized_id in mappings.items():
|
||||
result[normalized_id] = original_id
|
||||
|
||||
return result
|
||||
def _canonicalize_for_hashing(obj: Any) -> Any:
|
||||
canonicalizer = _IdCanonicalizer()
|
||||
return canonicalizer.canonicalize(obj)
|
||||
|
||||
|
||||
def _chunk_text_content(chunk: Any) -> tuple[str | None, bool]:
|
||||
|
|
@ -205,7 +183,7 @@ def _ends_with_partial_identifier(text: str) -> bool:
|
|||
else:
|
||||
core = token
|
||||
|
||||
suffix = core[len("file-"):]
|
||||
suffix = core[len("file-") :]
|
||||
if len(suffix) < 16:
|
||||
return True
|
||||
if not re.fullmatch(r"[A-Za-z0-9_-]+", suffix):
|
||||
|
|
@ -260,67 +238,11 @@ def _coalesce_streaming_chunks(chunks: list[Any]) -> list[Any]:
|
|||
return result
|
||||
|
||||
|
||||
def _denormalize_response(obj: Any, recorded_mapping: dict[str, str], current_mapping: dict[str, str]) -> Any:
|
||||
"""Replace recorded IDs with current runtime IDs in a response object.
|
||||
|
||||
Args:
|
||||
obj: The response object to denormalize
|
||||
recorded_mapping: normalized_id -> recorded_original_id (from recording file)
|
||||
current_mapping: normalized_id -> current_original_id (from current runtime)
|
||||
|
||||
Returns:
|
||||
The response with all recorded IDs replaced with current IDs
|
||||
"""
|
||||
import re
|
||||
|
||||
# Build reverse mapping: recorded_original_id -> current_original_id
|
||||
# via the normalized intermediary
|
||||
id_translation = {}
|
||||
for normalized_id, recorded_id in recorded_mapping.items():
|
||||
if normalized_id in current_mapping:
|
||||
current_id = current_mapping[normalized_id]
|
||||
id_translation[recorded_id] = current_id
|
||||
|
||||
if isinstance(obj, dict):
|
||||
result = {}
|
||||
for k, v in obj.items():
|
||||
# Translate document_id fields
|
||||
if k == "document_id" and isinstance(v, str):
|
||||
result[k] = id_translation.get(v, v)
|
||||
# Translate vector database/store IDs
|
||||
elif k in ("vector_db_id", "vector_store_id", "bank_id") and isinstance(v, str):
|
||||
# Replace any recorded IDs in the value
|
||||
translated = v
|
||||
for recorded_id, current_id in id_translation.items():
|
||||
translated = translated.replace(recorded_id, current_id)
|
||||
result[k] = translated
|
||||
else:
|
||||
result[k] = _denormalize_response(v, recorded_mapping, current_mapping)
|
||||
return result
|
||||
elif isinstance(obj, list):
|
||||
return [_denormalize_response(item, recorded_mapping, current_mapping) for item in obj]
|
||||
elif hasattr(obj, "model_dump"):
|
||||
# Handle Pydantic/BaseModel instances by denormalizing their dict form
|
||||
data = obj.model_dump(mode="python")
|
||||
denormalized = _denormalize_response(data, recorded_mapping, current_mapping)
|
||||
|
||||
cls = obj.__class__
|
||||
try:
|
||||
return cls.model_validate(denormalized)
|
||||
except Exception:
|
||||
try:
|
||||
return cls.model_construct(**denormalized)
|
||||
except Exception:
|
||||
return denormalized
|
||||
elif isinstance(obj, str):
|
||||
# Replace file-<uuid> patterns in strings (like in text content and citations)
|
||||
translated = obj
|
||||
for recorded_id, current_id in id_translation.items():
|
||||
# Handle both bare file IDs and citation format <|file-...|>
|
||||
translated = translated.replace(recorded_id, current_id)
|
||||
return translated
|
||||
else:
|
||||
return obj
|
||||
def _deterministic_id_override(kind: str, factory: Callable[[], str]) -> str:
|
||||
deterministic_id = _allocate_test_scoped_id(kind)
|
||||
if deterministic_id is not None:
|
||||
return deterministic_id
|
||||
return factory()
|
||||
|
||||
|
||||
def normalize_inference_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
|
||||
|
|
@ -337,13 +259,10 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
|
|||
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Normalize file IDs in the body to ensure consistent hashing across test runs
|
||||
normalized_body = _normalize_file_ids(body)
|
||||
|
||||
normalized: dict[str, Any] = {
|
||||
"method": method.upper(),
|
||||
"endpoint": parsed.path,
|
||||
"body": normalized_body,
|
||||
"body": _canonicalize_for_hashing(body),
|
||||
}
|
||||
|
||||
# Include test_id for isolation, except for shared infrastructure endpoints
|
||||
|
|
@ -357,10 +276,11 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
|
|||
|
||||
def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
|
||||
"""Create a normalized hash of the tool request for consistent matching."""
|
||||
# Normalize file IDs and vector store IDs in kwargs
|
||||
normalized_kwargs = _normalize_file_ids(kwargs)
|
||||
|
||||
normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": normalized_kwargs}
|
||||
normalized = {
|
||||
"provider": provider_name,
|
||||
"tool_name": tool_name,
|
||||
"kwargs": _canonicalize_for_hashing(kwargs),
|
||||
}
|
||||
|
||||
# Create hash - sort_keys=True ensures deterministic ordering
|
||||
normalized_json = json.dumps(normalized, sort_keys=True)
|
||||
|
|
@ -514,9 +434,6 @@ def _normalize_response(data: dict[str, Any], request_hash: str) -> dict[str, An
|
|||
if "eval_duration" in data and data["eval_duration"] is not None:
|
||||
data["eval_duration"] = 0
|
||||
|
||||
# Normalize file IDs and vector store IDs to ensure consistent hashing across replays
|
||||
data = _normalize_file_ids(data)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
|
|
@ -565,6 +482,8 @@ class ResponseStorage:
|
|||
def __init__(self, base_dir: Path):
|
||||
self.base_dir = base_dir
|
||||
# Don't create responses_dir here - determine it per-test at runtime
|
||||
self._legacy_index: dict[str, Path] = {}
|
||||
self._scanned_dirs: set[Path] = set()
|
||||
|
||||
def _get_test_dir(self) -> Path:
|
||||
"""Get the recordings directory in the test file's parent directory.
|
||||
|
|
@ -627,7 +546,7 @@ class ResponseStorage:
|
|||
"test_id": _test_context.get(),
|
||||
"request": request,
|
||||
"response": serialized_response,
|
||||
"id_normalization_mapping": _get_current_test_id_mappings(),
|
||||
"id_normalization_mapping": {},
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
|
|
@ -635,6 +554,8 @@ class ResponseStorage:
|
|||
f.write("\n")
|
||||
f.flush()
|
||||
|
||||
self._legacy_index[request_hash] = response_path
|
||||
|
||||
def find_recording(self, request_hash: str) -> dict[str, Any] | None:
|
||||
"""Find a recorded response by request hash.
|
||||
|
||||
|
|
@ -658,6 +579,52 @@ class ResponseStorage:
|
|||
if fallback_path.exists():
|
||||
return _recording_from_file(fallback_path)
|
||||
|
||||
return self._find_in_legacy_index(request_hash, [test_dir, fallback_dir])
|
||||
|
||||
def _find_in_legacy_index(self, request_hash: str, directories: list[Path]) -> dict[str, Any] | None:
|
||||
for directory in directories:
|
||||
if not directory.exists() or directory in self._scanned_dirs:
|
||||
continue
|
||||
|
||||
for path in directory.glob("*.json"):
|
||||
try:
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
request = data.get("request")
|
||||
if not request:
|
||||
continue
|
||||
|
||||
body = request.get("body")
|
||||
canonical_body = _canonicalize_for_hashing(body) if isinstance(body, dict | list) else body
|
||||
|
||||
token = None
|
||||
test_id = data.get("test_id")
|
||||
if test_id:
|
||||
token = _test_context.set(test_id)
|
||||
|
||||
try:
|
||||
legacy_hash = normalize_inference_request(
|
||||
request.get("method", ""),
|
||||
request.get("url", ""),
|
||||
request.get("headers", {}),
|
||||
canonical_body,
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
_test_context.reset(token)
|
||||
|
||||
if legacy_hash not in self._legacy_index:
|
||||
self._legacy_index[legacy_hash] = path
|
||||
|
||||
self._scanned_dirs.add(directory)
|
||||
|
||||
legacy_path = self._legacy_index.get(request_hash)
|
||||
if legacy_path and legacy_path.exists():
|
||||
return _recording_from_file(legacy_path)
|
||||
|
||||
return None
|
||||
|
||||
def _model_list_responses(self, request_hash: str) -> list[dict[str, Any]]:
|
||||
|
|
@ -685,6 +652,14 @@ def _recording_from_file(response_path) -> dict[str, Any]:
|
|||
with open(response_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
mapping = data.get("id_normalization_mapping") or {}
|
||||
if mapping:
|
||||
serialized = json.dumps(data)
|
||||
for normalized, original in mapping.items():
|
||||
serialized = serialized.replace(original, normalized)
|
||||
data = json.loads(serialized)
|
||||
data["id_normalization_mapping"] = {}
|
||||
|
||||
# Deserialize response body if needed
|
||||
if "response" in data and "body" in data["response"]:
|
||||
if isinstance(data["response"]["body"], list):
|
||||
|
|
@ -759,7 +734,7 @@ async def _patched_tool_invoke_method(
|
|||
original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
|
||||
):
|
||||
"""Patched version of tool runtime invoke_tool method for recording/replay."""
|
||||
global _current_mode, _current_storage, _id_normalizers
|
||||
global _current_mode, _current_storage
|
||||
|
||||
if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
|
||||
# Normal operation
|
||||
|
|
@ -808,7 +783,7 @@ async def _patched_tool_invoke_method(
|
|||
|
||||
|
||||
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
|
||||
global _current_mode, _current_storage, _id_normalizers
|
||||
global _current_mode, _current_storage
|
||||
|
||||
mode = _current_mode
|
||||
storage = _current_storage
|
||||
|
|
@ -864,45 +839,6 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
if recording["response"].get("is_streaming", False) and isinstance(response_body, list):
|
||||
response_body = _coalesce_streaming_chunks(response_body)
|
||||
|
||||
recording["response"]["body"] = response_body
|
||||
|
||||
# Denormalize the response: replace recorded IDs with current runtime IDs
|
||||
recorded_mapping = recording.get("id_normalization_mapping", {})
|
||||
current_mapping = _get_current_test_id_mappings()
|
||||
|
||||
if recorded_mapping or current_mapping:
|
||||
import sys
|
||||
print(f"\n=== DENORM DEBUG ===", file=sys.stderr)
|
||||
print(f"Recorded mapping: {recorded_mapping}", file=sys.stderr)
|
||||
print(f"Current mapping: {current_mapping}", file=sys.stderr)
|
||||
print(f"_id_normalizers before: {_id_normalizers.get(_test_context.get(), {})}", file=sys.stderr)
|
||||
|
||||
response_body = _denormalize_response(response_body, recorded_mapping, current_mapping)
|
||||
|
||||
# Update _id_normalizers to register the current IDs with their normalized values
|
||||
# This ensures that if these IDs appear in future requests, they get the same
|
||||
# normalized value as they had in the recording
|
||||
test_id = _test_context.get()
|
||||
if test_id and recorded_mapping:
|
||||
for normalized_id, recorded_id in recorded_mapping.items():
|
||||
if normalized_id in current_mapping:
|
||||
current_id = current_mapping[normalized_id]
|
||||
# Extract ID type from normalized_id (e.g., "file-1" -> "file")
|
||||
id_type = normalized_id.rsplit("-", 1)[0]
|
||||
|
||||
# Ensure structures exist
|
||||
if test_id not in _id_normalizers:
|
||||
_id_normalizers[test_id] = {}
|
||||
if id_type not in _id_normalizers[test_id]:
|
||||
_id_normalizers[test_id][id_type] = {}
|
||||
|
||||
# Register: current_id -> normalized_id
|
||||
_id_normalizers[test_id][id_type][current_id] = normalized_id
|
||||
print(f"Registered {current_id} -> {normalized_id}", file=sys.stderr)
|
||||
|
||||
print(f"_id_normalizers after: {_id_normalizers.get(_test_context.get(), {})}", file=sys.stderr)
|
||||
print(f"=== END DENORM DEBUG ===\n", file=sys.stderr)
|
||||
|
||||
if recording["response"].get("is_streaming", False):
|
||||
|
||||
async def replay_stream():
|
||||
|
|
@ -1125,6 +1061,7 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
|
|||
# Store previous state
|
||||
prev_mode = _current_mode
|
||||
prev_storage = _current_storage
|
||||
override_token = None
|
||||
|
||||
try:
|
||||
_current_mode = mode
|
||||
|
|
@ -1133,7 +1070,9 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
|
|||
if storage_dir is None:
|
||||
raise ValueError("storage_dir is required for record, replay, and record-if-missing modes")
|
||||
_current_storage = ResponseStorage(Path(storage_dir))
|
||||
_id_counters.clear()
|
||||
patch_inference_clients()
|
||||
override_token = set_id_override(_deterministic_id_override)
|
||||
|
||||
yield
|
||||
|
||||
|
|
@ -1141,6 +1080,8 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
|
|||
# Restore previous state
|
||||
if mode in ["record", "replay", "record-if-missing"]:
|
||||
unpatch_inference_clients()
|
||||
if override_token is not None:
|
||||
reset_id_override(override_token)
|
||||
|
||||
_current_mode = prev_mode
|
||||
_current_storage = prev_storage
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue