simpler ID generation to remove denormalization

Ashwin Bharambe 2025-10-09 09:22:47 -07:00
parent 39aa17f975
commit b47bf340db
5 changed files with 190 additions and 202 deletions

View file

@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from __future__ import annotations
+
+from collections.abc import Callable
+
+IdFactory = Callable[[], str]
+IdOverride = Callable[[str, IdFactory], str]
+
+_id_override: IdOverride | None = None
+
+
+def generate_object_id(kind: str, factory: IdFactory) -> str:
+    """Generate an identifier for the given kind using the provided factory.
+
+    Allows tests to override ID generation deterministically by installing an
+    override callback via :func:`set_id_override`.
+    """
+    override = _id_override
+    if override is not None:
+        return override(kind, factory)
+    return factory()
+
+
+def set_id_override(override: IdOverride) -> IdOverride | None:
+    """Install an override used to generate deterministic identifiers."""
+    global _id_override
+    previous = _id_override
+    _id_override = override
+    return previous
+
+
+def reset_id_override(token: IdOverride | None) -> None:
+    """Restore the previous override returned by :func:`set_id_override`."""
+    global _id_override
+    _id_override = token
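
A minimal usage sketch (not part of the commit) of the new override hooks: a test installs a counter-based override, generate_object_id then returns predictable IDs regardless of the real factory, and the previous override is restored afterwards. The fake_ids helper and the printed values are illustrative assumptions.

# Illustrative sketch only; fake_ids and the example output are hypothetical.
import uuid

from llama_stack.core.id_generation import generate_object_id, reset_id_override, set_id_override

counters: dict[str, int] = {}

def fake_ids(kind, factory):
    # Ignore the real factory and hand out sequential IDs per kind.
    counters[kind] = counters.get(kind, 0) + 1
    return f"{kind}-{counters[kind]}"

token = set_id_override(fake_ids)
try:
    print(generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}"))  # file-1
    print(generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}"))  # file-2
finally:
    reset_id_override(token)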

View file

@@ -22,6 +22,7 @@ from llama_stack.apis.files import (
     OpenAIFilePurpose,
 )
 from llama_stack.core.datatypes import AccessRule
+from llama_stack.core.id_generation import generate_object_id
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.files.form_data import parse_expires_after
 from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
@@ -65,7 +66,7 @@ class LocalfsFilesImpl(Files):
     def _generate_file_id(self) -> str:
         """Generate a unique file ID for OpenAI API."""
-        return f"file-{uuid.uuid4().hex}"
+        return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
 
     def _get_file_path(self, file_id: str) -> Path:
         """Get the filesystem path for a file ID."""

View file

@@ -23,6 +23,7 @@ from llama_stack.apis.files import (
     OpenAIFilePurpose,
 )
 from llama_stack.core.datatypes import AccessRule
+from llama_stack.core.id_generation import generate_object_id
 from llama_stack.providers.utils.files.form_data import parse_expires_after
 from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
 from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
@@ -198,7 +199,7 @@ class S3FilesImpl(Files):
         purpose: Annotated[OpenAIFilePurpose, Form()],
         expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
     ) -> OpenAIFileObject:
-        file_id = f"file-{uuid.uuid4().hex}"
+        file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
 
         filename = getattr(file, "filename", None) or "uploaded_file"

View file

@@ -40,6 +40,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreSearchResponse,
     VectorStoreSearchResponsePage,
 )
+from llama_stack.core.id_generation import generate_object_id
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import (
@@ -352,7 +353,7 @@ class OpenAIVectorStoreMixin(ABC):
         """Creates a vector store."""
         created_at = int(time.time())
         # Derive the canonical vector_db_id (allow override, else generate)
-        vector_db_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
+        vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
 
         if provider_id is None:
             raise ValueError("Provider ID is required")
@@ -986,7 +987,7 @@ class OpenAIVectorStoreMixin(ABC):
         chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
 
         created_at = int(time.time())
-        batch_id = f"batch_{uuid.uuid4()}"
+        batch_id = generate_object_id("vector_store_file_batch", lambda: f"batch_{uuid.uuid4()}")
         # File batches expire after 7 days
         expires_at = created_at + (7 * 24 * 60 * 60)

View file

@@ -10,7 +10,7 @@ import hashlib
 import json
 import os
 import re
-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from contextlib import contextmanager
 from enum import StrEnum
 from pathlib import Path
@@ -18,6 +18,7 @@ from typing import Any, Literal, cast
 
 from openai import NOT_GIVEN, OpenAI
 
+from llama_stack.core.id_generation import reset_id_override, set_id_override
 from llama_stack.log import get_logger
 
 logger = get_logger(__name__, category="testing")
@@ -30,9 +31,8 @@ _current_mode: str | None = None
 _current_storage: ResponseStorage | None = None
 _original_methods: dict[str, Any] = {}
 
-# ID normalization state: maps test_id -> (id_type -> {original_id: normalized_id})
-_id_normalizers: dict[str, dict[str, dict[str, str]]] = {}
-_id_counters: dict[str, dict[str, int]] = {}  # test_id -> (id_type -> counter)
+# Per-test deterministic ID counters (test_id -> id_kind -> counter)
+_id_counters: dict[str, dict[str, int]] = {}
 
 # Test context uses ContextVar since it changes per-test and needs async isolation
 from contextvars import ContextVar
@@ -56,104 +56,82 @@ class APIRecordingMode(StrEnum):
     RECORD_IF_MISSING = "record-if-missing"
 
 
-def _get_normalized_id(original_id: str, id_type: str) -> str:
-    """Get a normalized ID using a test-specific counter.
-
-    Each unique ID within a test gets assigned a sequential number (file-1, file-2, uuid-1, etc).
-    This ensures consistency across requests within the same test while keeping IDs human-readable.
-    """
-    global _id_normalizers, _id_counters
-
-    test_id = _test_context.get()
-    if not test_id:
-        # No test context, return original ID
-        return original_id
-
-    # Initialize structures for this test if needed
-    if test_id not in _id_normalizers:
-        _id_normalizers[test_id] = {}
-        _id_counters[test_id] = {}
-
-    if id_type not in _id_normalizers[test_id]:
-        _id_normalizers[test_id][id_type] = {}
-        _id_counters[test_id][id_type] = 0
-
-    # Check if we've seen this ID before
-    if original_id in _id_normalizers[test_id][id_type]:
-        return _id_normalizers[test_id][id_type][original_id]
-
-    # New ID - assign next counter value
-    _id_counters[test_id][id_type] += 1
-    counter = _id_counters[test_id][id_type]
-    normalized_id = f"{id_type}-{counter}"
-
-    # Store mapping
-    _id_normalizers[test_id][id_type][original_id] = normalized_id
-
-    return normalized_id
-
-
-def _normalize_file_ids(obj: Any) -> Any:
-    """Recursively replace file IDs and vector store IDs with test-specific normalized values.
-
-    Each unique file ID or UUID gets a unique normalized value using a sequential counter
-    within each test (file-1, file-2, uuid-1, etc).
-    """
-    import re
-
-    if isinstance(obj, dict):
-        result = {}
-        for k, v in obj.items():
-            # Normalize file IDs in document_id fields
-            if k == "document_id" and isinstance(v, str) and v.startswith("file-"):
-                result[k] = _get_normalized_id(v, "file")
-            # Normalize vector database/store IDs with UUID patterns
-            elif k in ("vector_db_id", "vector_store_id", "bank_id") and isinstance(v, str):
-                # Replace UUIDs in the ID deterministically
-                def replace_uuid(match):
-                    uuid_val = match.group(0)
-                    return _get_normalized_id(uuid_val, "uuid")
-
-                normalized = re.sub(
-                    r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}",
-                    replace_uuid,
-                    v,
-                )
-                result[k] = normalized
-            else:
-                result[k] = _normalize_file_ids(v)
-        return result
-    elif isinstance(obj, list):
-        return [_normalize_file_ids(item) for item in obj]
-    elif isinstance(obj, str):
-        # Replace file-<uuid> patterns in strings (like in text content)
-        def replace_file_id(match):
-            file_id = match.group(0)
-            return _get_normalized_id(file_id, "file")
-
-        return re.sub(r"file-[a-f0-9]{32}", replace_file_id, obj)
-    else:
-        return obj
-
-
-def _get_current_test_id_mappings() -> dict[str, str]:
-    """Get the current test's ID normalization mappings.
-
-    Returns a dict mapping normalized_id -> original_id (e.g., "file-1" -> "file-abc123...")
-    This is the inverse of what's stored in _id_normalizers.
-    """
-    global _id_normalizers
-
-    test_id = _test_context.get()
-    if not test_id or test_id not in _id_normalizers:
-        return {}
-
-    # Invert the mapping: normalized_id -> original_id
-    result = {}
-    for id_type, mappings in _id_normalizers[test_id].items():
-        for original_id, normalized_id in mappings.items():
-            result[normalized_id] = original_id
-
-    return result
+_ID_KIND_PREFIXES: dict[str, str] = {
+    "file": "file-",
+    "vector_store": "vs_",
+    "vector_store_file_batch": "batch_",
+}
+
+
+def _allocate_test_scoped_id(kind: str) -> str | None:
+    """Return the next deterministic ID for the given kind within the current test."""
+    global _id_counters
+
+    test_id = _test_context.get()
+    prefix = _ID_KIND_PREFIXES.get(kind)
+    if prefix is None:
+        return None
+
+    key = test_id or "__global__"
+    if key not in _id_counters:
+        _id_counters[key] = {}
+
+    counter = _id_counters[key].get(kind, 0) + 1
+    _id_counters[key][kind] = counter
+
+    return f"{prefix}{counter}"
+
+
+class _IdCanonicalizer:
+    PATTERN = re.compile(r"(file-[A-Za-z0-9_-]+|vs_[A-Za-z0-9_-]+|batch_[A-Za-z0-9_-]+)")
+
+    def __init__(self) -> None:
+        self._mappings: dict[str, dict[str, str]] = {kind: {} for kind in _ID_KIND_PREFIXES}
+        self._counters: dict[str, int] = dict.fromkeys(_ID_KIND_PREFIXES, 0)
+
+    def canonicalize(self, obj: Any) -> Any:
+        if isinstance(obj, dict):
+            return {k: self._canonicalize_value(k, v) for k, v in obj.items()}
+        if isinstance(obj, list):
+            return [self.canonicalize(item) for item in obj]
+        if isinstance(obj, str):
+            return self._canonicalize_string(obj)
+        return obj
+
+    def _canonicalize_value(self, key: str, value: Any) -> Any:
+        if key in {"vector_db_id", "vector_store_id", "bank_id"} and isinstance(value, str):
+            return self._canonicalize_string(value)
+        if key == "document_id" and isinstance(value, str) and value.startswith("file-"):
+            return self._canonicalize_string(value)
+        return self.canonicalize(value)
+
+    def _canonicalize_string(self, value: str) -> str:
+        def replace(match: re.Match[str]) -> str:
+            token = match.group(0)
+            if token.startswith("file-"):
+                return self._mapped_value("file", token)
+            if token.startswith("vs_"):
+                return self._mapped_value("vector_store", token)
+            if token.startswith("batch_"):
+                return self._mapped_value("vector_store_file_batch", token)
+            return token
+
+        return self.PATTERN.sub(replace, value)
+
+    def _mapped_value(self, kind: str, original: str) -> str:
+        mapping = self._mappings[kind]
+        if original not in mapping:
+            self._counters[kind] += 1
+            mapping[original] = f"{_ID_KIND_PREFIXES[kind]}{self._counters[kind]}"
+        return mapping[original]
+
+
+def _canonicalize_for_hashing(obj: Any) -> Any:
+    canonicalizer = _IdCanonicalizer()
+    return canonicalizer.canonicalize(obj)
 
 
 def _chunk_text_content(chunk: Any) -> tuple[str | None, bool]:
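
For intuition, here is a standalone sketch of what the canonicalization above buys: request bodies that differ only in their raw file, vector store, or batch IDs collapse to the same canonical form, so they hash to the same recording key. The made-up IDs, the single shared counter, and the sha256-of-sorted-JSON step are simplifications standing in for the recorder's real code.

# Illustration only; not the recorder's exact hashing path.
import hashlib
import json
import re

PATTERN = re.compile(r"(file-[A-Za-z0-9_-]+|vs_[A-Za-z0-9_-]+|batch_[A-Za-z0-9_-]+)")

def canonicalize(text: str, mapping: dict[str, str]) -> str:
    def repl(match: re.Match[str]) -> str:
        token = match.group(0)
        prefix = "file-" if token.startswith("file-") else "vs_" if token.startswith("vs_") else "batch_"
        if token not in mapping:
            mapping[token] = f"{prefix}{len(mapping) + 1}"
        return mapping[token]
    return PATTERN.sub(repl, text)

def request_hash(body: dict) -> str:
    canonical = canonicalize(json.dumps(body, sort_keys=True), {})
    return hashlib.sha256(canonical.encode()).hexdigest()

# Two runs with different raw IDs produce the same hash once canonicalized.
print(request_hash({"document_id": "file-0a1b2c3d4e5f"}) == request_hash({"document_id": "file-9f8e7d6c5b4a"}))  # True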
@@ -205,7 +183,7 @@ def _ends_with_partial_identifier(text: str) -> bool:
     else:
         core = token
 
-    suffix = core[len("file-"):]
+    suffix = core[len("file-") :]
     if len(suffix) < 16:
         return True
     if not re.fullmatch(r"[A-Za-z0-9_-]+", suffix):
@@ -260,67 +238,11 @@ def _coalesce_streaming_chunks(chunks: list[Any]) -> list[Any]:
     return result
 
 
-def _denormalize_response(obj: Any, recorded_mapping: dict[str, str], current_mapping: dict[str, str]) -> Any:
-    """Replace recorded IDs with current runtime IDs in a response object.
-
-    Args:
-        obj: The response object to denormalize
-        recorded_mapping: normalized_id -> recorded_original_id (from recording file)
-        current_mapping: normalized_id -> current_original_id (from current runtime)
-
-    Returns:
-        The response with all recorded IDs replaced with current IDs
-    """
-    import re
-
-    # Build reverse mapping: recorded_original_id -> current_original_id
-    # via the normalized intermediary
-    id_translation = {}
-    for normalized_id, recorded_id in recorded_mapping.items():
-        if normalized_id in current_mapping:
-            current_id = current_mapping[normalized_id]
-            id_translation[recorded_id] = current_id
-
-    if isinstance(obj, dict):
-        result = {}
-        for k, v in obj.items():
-            # Translate document_id fields
-            if k == "document_id" and isinstance(v, str):
-                result[k] = id_translation.get(v, v)
-            # Translate vector database/store IDs
-            elif k in ("vector_db_id", "vector_store_id", "bank_id") and isinstance(v, str):
-                # Replace any recorded IDs in the value
-                translated = v
-                for recorded_id, current_id in id_translation.items():
-                    translated = translated.replace(recorded_id, current_id)
-                result[k] = translated
-            else:
-                result[k] = _denormalize_response(v, recorded_mapping, current_mapping)
-        return result
-    elif isinstance(obj, list):
-        return [_denormalize_response(item, recorded_mapping, current_mapping) for item in obj]
-    elif hasattr(obj, "model_dump"):
-        # Handle Pydantic/BaseModel instances by denormalizing their dict form
-        data = obj.model_dump(mode="python")
-        denormalized = _denormalize_response(data, recorded_mapping, current_mapping)
-        cls = obj.__class__
-        try:
-            return cls.model_validate(denormalized)
-        except Exception:
-            try:
-                return cls.model_construct(**denormalized)
-            except Exception:
-                return denormalized
-    elif isinstance(obj, str):
-        # Replace file-<uuid> patterns in strings (like in text content and citations)
-        translated = obj
-        for recorded_id, current_id in id_translation.items():
-            # Handle both bare file IDs and citation format <|file-...|>
-            translated = translated.replace(recorded_id, current_id)
-        return translated
-    else:
-        return obj
+def _deterministic_id_override(kind: str, factory: Callable[[], str]) -> str:
+    deterministic_id = _allocate_test_scoped_id(kind)
+    if deterministic_id is not None:
+        return deterministic_id
+    return factory()
 
 
 def normalize_inference_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
@@ -337,13 +259,10 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
     parsed = urlparse(url)
 
-    # Normalize file IDs in the body to ensure consistent hashing across test runs
-    normalized_body = _normalize_file_ids(body)
-
     normalized: dict[str, Any] = {
         "method": method.upper(),
         "endpoint": parsed.path,
-        "body": normalized_body,
+        "body": _canonicalize_for_hashing(body),
     }
 
     # Include test_id for isolation, except for shared infrastructure endpoints
@@ -357,10 +276,11 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
 
 def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
     """Create a normalized hash of the tool request for consistent matching."""
-    # Normalize file IDs and vector store IDs in kwargs
-    normalized_kwargs = _normalize_file_ids(kwargs)
-
-    normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": normalized_kwargs}
+    normalized = {
+        "provider": provider_name,
+        "tool_name": tool_name,
+        "kwargs": _canonicalize_for_hashing(kwargs),
+    }
 
     # Create hash - sort_keys=True ensures deterministic ordering
     normalized_json = json.dumps(normalized, sort_keys=True)
@@ -514,9 +434,6 @@ def _normalize_response(data: dict[str, Any], request_hash: str) -> dict[str, Any]:
     if "eval_duration" in data and data["eval_duration"] is not None:
         data["eval_duration"] = 0
 
-    # Normalize file IDs and vector store IDs to ensure consistent hashing across replays
-    data = _normalize_file_ids(data)
-
     return data
@@ -565,6 +482,8 @@ class ResponseStorage:
     def __init__(self, base_dir: Path):
         self.base_dir = base_dir
         # Don't create responses_dir here - determine it per-test at runtime
+        self._legacy_index: dict[str, Path] = {}
+        self._scanned_dirs: set[Path] = set()
 
     def _get_test_dir(self) -> Path:
         """Get the recordings directory in the test file's parent directory.
@@ -627,7 +546,7 @@ class ResponseStorage:
                     "test_id": _test_context.get(),
                     "request": request,
                     "response": serialized_response,
-                    "id_normalization_mapping": _get_current_test_id_mappings(),
+                    "id_normalization_mapping": {},
                 },
                 f,
                 indent=2,
@@ -635,6 +554,8 @@ class ResponseStorage:
             f.write("\n")
             f.flush()
 
+        self._legacy_index[request_hash] = response_path
+
     def find_recording(self, request_hash: str) -> dict[str, Any] | None:
         """Find a recorded response by request hash.
@@ -658,6 +579,52 @@ class ResponseStorage:
         if fallback_path.exists():
             return _recording_from_file(fallback_path)
 
+        return self._find_in_legacy_index(request_hash, [test_dir, fallback_dir])
+
+    def _find_in_legacy_index(self, request_hash: str, directories: list[Path]) -> dict[str, Any] | None:
+        for directory in directories:
+            if not directory.exists() or directory in self._scanned_dirs:
+                continue
+
+            for path in directory.glob("*.json"):
+                try:
+                    with open(path) as f:
+                        data = json.load(f)
+                except Exception:
+                    continue
+
+                request = data.get("request")
+                if not request:
+                    continue
+
+                body = request.get("body")
+                canonical_body = _canonicalize_for_hashing(body) if isinstance(body, dict | list) else body
+
+                token = None
+                test_id = data.get("test_id")
+                if test_id:
+                    token = _test_context.set(test_id)
+                try:
+                    legacy_hash = normalize_inference_request(
+                        request.get("method", ""),
+                        request.get("url", ""),
+                        request.get("headers", {}),
+                        canonical_body,
+                    )
+                finally:
+                    if token is not None:
+                        _test_context.reset(token)
+
+                if legacy_hash not in self._legacy_index:
+                    self._legacy_index[legacy_hash] = path
+
+            self._scanned_dirs.add(directory)
+
+        legacy_path = self._legacy_index.get(request_hash)
+        if legacy_path and legacy_path.exists():
+            return _recording_from_file(legacy_path)
+
         return None
 
     def _model_list_responses(self, request_hash: str) -> list[dict[str, Any]]:
@@ -685,6 +652,14 @@ def _recording_from_file(response_path) -> dict[str, Any]:
     with open(response_path) as f:
         data = json.load(f)
 
+    mapping = data.get("id_normalization_mapping") or {}
+    if mapping:
+        serialized = json.dumps(data)
+        for normalized, original in mapping.items():
+            serialized = serialized.replace(original, normalized)
+        data = json.loads(serialized)
+        data["id_normalization_mapping"] = {}
+
     # Deserialize response body if needed
     if "response" in data and "body" in data["response"]:
         if isinstance(data["response"]["body"], list):
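
To see what this legacy-compatibility branch does, here is an illustrative run against a made-up recording (the IDs and test name are invented for the example): the stored mapping is folded back into the payload by plain string replacement, so an old recording ends up carrying the normalized IDs that the new deterministic scheme produces directly.

# Made-up legacy recording; IDs and test_id are purely illustrative.
import json

data = {
    "test_id": "tests/integration/test_files.py::test_upload",
    "request": {"body": {"document_id": "file-0a1b2c3d4e5f"}},
    "response": {"body": {"id": "file-0a1b2c3d4e5f"}},
    "id_normalization_mapping": {"file-1": "file-0a1b2c3d4e5f"},
}

mapping = data.get("id_normalization_mapping") or {}
if mapping:
    serialized = json.dumps(data)
    for normalized, original in mapping.items():
        serialized = serialized.replace(original, normalized)
    data = json.loads(serialized)
    data["id_normalization_mapping"] = {}

print(data["request"]["body"]["document_id"])  # file-1
print(data["response"]["body"]["id"])          # file-1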
@@ -759,7 +734,7 @@ async def _patched_tool_invoke_method(
     original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
 ):
     """Patched version of tool runtime invoke_tool method for recording/replay."""
-    global _current_mode, _current_storage, _id_normalizers
+    global _current_mode, _current_storage
 
     if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
         # Normal operation
@@ -808,7 +783,7 @@ async def _patched_tool_invoke_method(
 
 async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
-    global _current_mode, _current_storage, _id_normalizers
+    global _current_mode, _current_storage
 
     mode = _current_mode
     storage = _current_storage
@@ -864,45 +839,6 @@ async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
         if recording["response"].get("is_streaming", False) and isinstance(response_body, list):
             response_body = _coalesce_streaming_chunks(response_body)
-            recording["response"]["body"] = response_body
-
-        # Denormalize the response: replace recorded IDs with current runtime IDs
-        recorded_mapping = recording.get("id_normalization_mapping", {})
-        current_mapping = _get_current_test_id_mappings()
-        if recorded_mapping or current_mapping:
-            import sys
-
-            print(f"\n=== DENORM DEBUG ===", file=sys.stderr)
-            print(f"Recorded mapping: {recorded_mapping}", file=sys.stderr)
-            print(f"Current mapping: {current_mapping}", file=sys.stderr)
-            print(f"_id_normalizers before: {_id_normalizers.get(_test_context.get(), {})}", file=sys.stderr)
-
-            response_body = _denormalize_response(response_body, recorded_mapping, current_mapping)
-
-            # Update _id_normalizers to register the current IDs with their normalized values
-            # This ensures that if these IDs appear in future requests, they get the same
-            # normalized value as they had in the recording
-            test_id = _test_context.get()
-            if test_id and recorded_mapping:
-                for normalized_id, recorded_id in recorded_mapping.items():
-                    if normalized_id in current_mapping:
-                        current_id = current_mapping[normalized_id]
-                        # Extract ID type from normalized_id (e.g., "file-1" -> "file")
-                        id_type = normalized_id.rsplit("-", 1)[0]
-                        # Ensure structures exist
-                        if test_id not in _id_normalizers:
-                            _id_normalizers[test_id] = {}
-                        if id_type not in _id_normalizers[test_id]:
-                            _id_normalizers[test_id][id_type] = {}
-                        # Register: current_id -> normalized_id
-                        _id_normalizers[test_id][id_type][current_id] = normalized_id
-                        print(f"Registered {current_id} -> {normalized_id}", file=sys.stderr)
-
-            print(f"_id_normalizers after: {_id_normalizers.get(_test_context.get(), {})}", file=sys.stderr)
-            print(f"=== END DENORM DEBUG ===\n", file=sys.stderr)
 
         if recording["response"].get("is_streaming", False):
 
             async def replay_stream():
@@ -1125,6 +1061,7 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
     # Store previous state
     prev_mode = _current_mode
     prev_storage = _current_storage
+    override_token = None
 
     try:
         _current_mode = mode
@@ -1133,7 +1070,9 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
             if storage_dir is None:
                 raise ValueError("storage_dir is required for record, replay, and record-if-missing modes")
             _current_storage = ResponseStorage(Path(storage_dir))
+            _id_counters.clear()
             patch_inference_clients()
+            override_token = set_id_override(_deterministic_id_override)
 
         yield
@@ -1141,6 +1080,8 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
         # Restore previous state
         if mode in ["record", "replay", "record-if-missing"]:
             unpatch_inference_clients()
+            if override_token is not None:
+                reset_id_override(override_token)
 
         _current_mode = prev_mode
         _current_storage = prev_storage