simpler ID generation to remove denormalization

This commit is contained in:
Ashwin Bharambe 2025-10-09 09:22:47 -07:00
parent 39aa17f975
commit b47bf340db
5 changed files with 190 additions and 202 deletions

View file

@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from __future__ import annotations
from collections.abc import Callable
IdFactory = Callable[[], str]
IdOverride = Callable[[str, IdFactory], str]
_id_override: IdOverride | None = None
def generate_object_id(kind: str, factory: IdFactory) -> str:
"""Generate an identifier for the given kind using the provided factory.
Allows tests to override ID generation deterministically by installing an
override callback via :func:`set_id_override`.
"""
override = _id_override
if override is not None:
return override(kind, factory)
return factory()
def set_id_override(override: IdOverride) -> IdOverride | None:
"""Install an override used to generate deterministic identifiers."""
global _id_override
previous = _id_override
_id_override = override
return previous
def reset_id_override(token: IdOverride | None) -> None:
"""Restore the previous override returned by :func:`set_id_override`."""
global _id_override
_id_override = token

View file

@ -22,6 +22,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.log import get_logger
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
@ -65,7 +66,7 @@ class LocalfsFilesImpl(Files):
def _generate_file_id(self) -> str:
"""Generate a unique file ID for OpenAI API."""
return f"file-{uuid.uuid4().hex}"
return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
def _get_file_path(self, file_id: str) -> Path:
"""Get the filesystem path for a file ID."""

View file

@ -23,6 +23,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
@ -198,7 +199,7 @@ class S3FilesImpl(Files):
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
filename = getattr(file, "filename", None) or "uploaded_file"

View file

@ -40,6 +40,7 @@ from llama_stack.apis.vector_io import (
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.core.id_generation import generate_object_id
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import (
@ -352,7 +353,7 @@ class OpenAIVectorStoreMixin(ABC):
"""Creates a vector store."""
created_at = int(time.time())
# Derive the canonical vector_db_id (allow override, else generate)
vector_db_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
if provider_id is None:
raise ValueError("Provider ID is required")
@ -986,7 +987,7 @@ class OpenAIVectorStoreMixin(ABC):
chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
created_at = int(time.time())
batch_id = f"batch_{uuid.uuid4()}"
batch_id = generate_object_id("vector_store_file_batch", lambda: f"batch_{uuid.uuid4()}")
# File batches expire after 7 days
expires_at = created_at + (7 * 24 * 60 * 60)

View file

@ -10,7 +10,7 @@ import hashlib
import json
import os
import re
from collections.abc import Generator
from collections.abc import Callable, Generator
from contextlib import contextmanager
from enum import StrEnum
from pathlib import Path
@ -18,6 +18,7 @@ from typing import Any, Literal, cast
from openai import NOT_GIVEN, OpenAI
from llama_stack.core.id_generation import reset_id_override, set_id_override
from llama_stack.log import get_logger
logger = get_logger(__name__, category="testing")
@ -30,9 +31,8 @@ _current_mode: str | None = None
_current_storage: ResponseStorage | None = None
_original_methods: dict[str, Any] = {}
# ID normalization state: maps test_id -> (id_type -> {original_id: normalized_id})
_id_normalizers: dict[str, dict[str, dict[str, str]]] = {}
_id_counters: dict[str, dict[str, int]] = {} # test_id -> (id_type -> counter)
# Per-test deterministic ID counters (test_id -> id_kind -> counter)
_id_counters: dict[str, dict[str, int]] = {}
# Test context uses ContextVar since it changes per-test and needs async isolation
from contextvars import ContextVar
@ -56,104 +56,82 @@ class APIRecordingMode(StrEnum):
RECORD_IF_MISSING = "record-if-missing"
def _get_normalized_id(original_id: str, id_type: str) -> str:
"""Get a normalized ID using a test-specific counter.
_ID_KIND_PREFIXES: dict[str, str] = {
"file": "file-",
"vector_store": "vs_",
"vector_store_file_batch": "batch_",
}
Each unique ID within a test gets assigned a sequential number (file-1, file-2, uuid-1, etc).
This ensures consistency across requests within the same test while keeping IDs human-readable.
"""
global _id_normalizers, _id_counters
def _allocate_test_scoped_id(kind: str) -> str | None:
"""Return the next deterministic ID for the given kind within the current test."""
global _id_counters
test_id = _test_context.get()
if not test_id:
# No test context, return original ID
return original_id
prefix = _ID_KIND_PREFIXES.get(kind)
# Initialize structures for this test if needed
if test_id not in _id_normalizers:
_id_normalizers[test_id] = {}
_id_counters[test_id] = {}
if prefix is None:
return None
if id_type not in _id_normalizers[test_id]:
_id_normalizers[test_id][id_type] = {}
_id_counters[test_id][id_type] = 0
key = test_id or "__global__"
# Check if we've seen this ID before
if original_id in _id_normalizers[test_id][id_type]:
return _id_normalizers[test_id][id_type][original_id]
if key not in _id_counters:
_id_counters[key] = {}
# New ID - assign next counter value
_id_counters[test_id][id_type] += 1
counter = _id_counters[test_id][id_type]
normalized_id = f"{id_type}-{counter}"
counter = _id_counters[key].get(kind, 0) + 1
_id_counters[key][kind] = counter
# Store mapping
_id_normalizers[test_id][id_type][original_id] = normalized_id
return normalized_id
return f"{prefix}{counter}"
def _normalize_file_ids(obj: Any) -> Any:
"""Recursively replace file IDs and vector store IDs with test-specific normalized values.
class _IdCanonicalizer:
PATTERN = re.compile(r"(file-[A-Za-z0-9_-]+|vs_[A-Za-z0-9_-]+|batch_[A-Za-z0-9_-]+)")
Each unique file ID or UUID gets a unique normalized value using a sequential counter
within each test (file-1, file-2, uuid-1, etc).
"""
import re
def __init__(self) -> None:
self._mappings: dict[str, dict[str, str]] = {kind: {} for kind in _ID_KIND_PREFIXES}
self._counters: dict[str, int] = dict.fromkeys(_ID_KIND_PREFIXES, 0)
if isinstance(obj, dict):
result = {}
for k, v in obj.items():
# Normalize file IDs in document_id fields
if k == "document_id" and isinstance(v, str) and v.startswith("file-"):
result[k] = _get_normalized_id(v, "file")
# Normalize vector database/store IDs with UUID patterns
elif k in ("vector_db_id", "vector_store_id", "bank_id") and isinstance(v, str):
# Replace UUIDs in the ID deterministically
def replace_uuid(match):
uuid_val = match.group(0)
return _get_normalized_id(uuid_val, "uuid")
normalized = re.sub(
r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}",
replace_uuid,
v,
)
result[k] = normalized
else:
result[k] = _normalize_file_ids(v)
return result
elif isinstance(obj, list):
return [_normalize_file_ids(item) for item in obj]
elif isinstance(obj, str):
# Replace file-<uuid> patterns in strings (like in text content)
def replace_file_id(match):
file_id = match.group(0)
return _get_normalized_id(file_id, "file")
return re.sub(r"file-[a-f0-9]{32}", replace_file_id, obj)
else:
def canonicalize(self, obj: Any) -> Any:
if isinstance(obj, dict):
return {k: self._canonicalize_value(k, v) for k, v in obj.items()}
if isinstance(obj, list):
return [self.canonicalize(item) for item in obj]
if isinstance(obj, str):
return self._canonicalize_string(obj)
return obj
def _canonicalize_value(self, key: str, value: Any) -> Any:
if key in {"vector_db_id", "vector_store_id", "bank_id"} and isinstance(value, str):
return self._canonicalize_string(value)
if key == "document_id" and isinstance(value, str) and value.startswith("file-"):
return self._canonicalize_string(value)
return self.canonicalize(value)
def _get_current_test_id_mappings() -> dict[str, str]:
"""Get the current test's ID normalization mappings.
def _canonicalize_string(self, value: str) -> str:
def replace(match: re.Match[str]) -> str:
token = match.group(0)
if token.startswith("file-"):
return self._mapped_value("file", token)
if token.startswith("vs_"):
return self._mapped_value("vector_store", token)
if token.startswith("batch_"):
return self._mapped_value("vector_store_file_batch", token)
return token
Returns a dict mapping normalized_id -> original_id (e.g., "file-1" -> "file-abc123...")
This is the inverse of what's stored in _id_normalizers.
"""
global _id_normalizers
return self.PATTERN.sub(replace, value)
test_id = _test_context.get()
if not test_id or test_id not in _id_normalizers:
return {}
def _mapped_value(self, kind: str, original: str) -> str:
mapping = self._mappings[kind]
if original not in mapping:
self._counters[kind] += 1
mapping[original] = f"{_ID_KIND_PREFIXES[kind]}{self._counters[kind]}"
return mapping[original]
# Invert the mapping: normalized_id -> original_id
result = {}
for id_type, mappings in _id_normalizers[test_id].items():
for original_id, normalized_id in mappings.items():
result[normalized_id] = original_id
return result
def _canonicalize_for_hashing(obj: Any) -> Any:
canonicalizer = _IdCanonicalizer()
return canonicalizer.canonicalize(obj)
def _chunk_text_content(chunk: Any) -> tuple[str | None, bool]:
@ -205,7 +183,7 @@ def _ends_with_partial_identifier(text: str) -> bool:
else:
core = token
suffix = core[len("file-"):]
suffix = core[len("file-") :]
if len(suffix) < 16:
return True
if not re.fullmatch(r"[A-Za-z0-9_-]+", suffix):
@ -260,67 +238,11 @@ def _coalesce_streaming_chunks(chunks: list[Any]) -> list[Any]:
return result
def _denormalize_response(obj: Any, recorded_mapping: dict[str, str], current_mapping: dict[str, str]) -> Any:
"""Replace recorded IDs with current runtime IDs in a response object.
Args:
obj: The response object to denormalize
recorded_mapping: normalized_id -> recorded_original_id (from recording file)
current_mapping: normalized_id -> current_original_id (from current runtime)
Returns:
The response with all recorded IDs replaced with current IDs
"""
import re
# Build reverse mapping: recorded_original_id -> current_original_id
# via the normalized intermediary
id_translation = {}
for normalized_id, recorded_id in recorded_mapping.items():
if normalized_id in current_mapping:
current_id = current_mapping[normalized_id]
id_translation[recorded_id] = current_id
if isinstance(obj, dict):
result = {}
for k, v in obj.items():
# Translate document_id fields
if k == "document_id" and isinstance(v, str):
result[k] = id_translation.get(v, v)
# Translate vector database/store IDs
elif k in ("vector_db_id", "vector_store_id", "bank_id") and isinstance(v, str):
# Replace any recorded IDs in the value
translated = v
for recorded_id, current_id in id_translation.items():
translated = translated.replace(recorded_id, current_id)
result[k] = translated
else:
result[k] = _denormalize_response(v, recorded_mapping, current_mapping)
return result
elif isinstance(obj, list):
return [_denormalize_response(item, recorded_mapping, current_mapping) for item in obj]
elif hasattr(obj, "model_dump"):
# Handle Pydantic/BaseModel instances by denormalizing their dict form
data = obj.model_dump(mode="python")
denormalized = _denormalize_response(data, recorded_mapping, current_mapping)
cls = obj.__class__
try:
return cls.model_validate(denormalized)
except Exception:
try:
return cls.model_construct(**denormalized)
except Exception:
return denormalized
elif isinstance(obj, str):
# Replace file-<uuid> patterns in strings (like in text content and citations)
translated = obj
for recorded_id, current_id in id_translation.items():
# Handle both bare file IDs and citation format <|file-...|>
translated = translated.replace(recorded_id, current_id)
return translated
else:
return obj
def _deterministic_id_override(kind: str, factory: Callable[[], str]) -> str:
deterministic_id = _allocate_test_scoped_id(kind)
if deterministic_id is not None:
return deterministic_id
return factory()
def normalize_inference_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
@ -337,13 +259,10 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
parsed = urlparse(url)
# Normalize file IDs in the body to ensure consistent hashing across test runs
normalized_body = _normalize_file_ids(body)
normalized: dict[str, Any] = {
"method": method.upper(),
"endpoint": parsed.path,
"body": normalized_body,
"body": _canonicalize_for_hashing(body),
}
# Include test_id for isolation, except for shared infrastructure endpoints
@ -357,10 +276,11 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
"""Create a normalized hash of the tool request for consistent matching."""
# Normalize file IDs and vector store IDs in kwargs
normalized_kwargs = _normalize_file_ids(kwargs)
normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": normalized_kwargs}
normalized = {
"provider": provider_name,
"tool_name": tool_name,
"kwargs": _canonicalize_for_hashing(kwargs),
}
# Create hash - sort_keys=True ensures deterministic ordering
normalized_json = json.dumps(normalized, sort_keys=True)
@ -514,9 +434,6 @@ def _normalize_response(data: dict[str, Any], request_hash: str) -> dict[str, An
if "eval_duration" in data and data["eval_duration"] is not None:
data["eval_duration"] = 0
# Normalize file IDs and vector store IDs to ensure consistent hashing across replays
data = _normalize_file_ids(data)
return data
@ -565,6 +482,8 @@ class ResponseStorage:
def __init__(self, base_dir: Path):
self.base_dir = base_dir
# Don't create responses_dir here - determine it per-test at runtime
self._legacy_index: dict[str, Path] = {}
self._scanned_dirs: set[Path] = set()
def _get_test_dir(self) -> Path:
"""Get the recordings directory in the test file's parent directory.
@ -627,7 +546,7 @@ class ResponseStorage:
"test_id": _test_context.get(),
"request": request,
"response": serialized_response,
"id_normalization_mapping": _get_current_test_id_mappings(),
"id_normalization_mapping": {},
},
f,
indent=2,
@ -635,6 +554,8 @@ class ResponseStorage:
f.write("\n")
f.flush()
self._legacy_index[request_hash] = response_path
def find_recording(self, request_hash: str) -> dict[str, Any] | None:
"""Find a recorded response by request hash.
@ -658,6 +579,52 @@ class ResponseStorage:
if fallback_path.exists():
return _recording_from_file(fallback_path)
return self._find_in_legacy_index(request_hash, [test_dir, fallback_dir])
def _find_in_legacy_index(self, request_hash: str, directories: list[Path]) -> dict[str, Any] | None:
for directory in directories:
if not directory.exists() or directory in self._scanned_dirs:
continue
for path in directory.glob("*.json"):
try:
with open(path) as f:
data = json.load(f)
except Exception:
continue
request = data.get("request")
if not request:
continue
body = request.get("body")
canonical_body = _canonicalize_for_hashing(body) if isinstance(body, dict | list) else body
token = None
test_id = data.get("test_id")
if test_id:
token = _test_context.set(test_id)
try:
legacy_hash = normalize_inference_request(
request.get("method", ""),
request.get("url", ""),
request.get("headers", {}),
canonical_body,
)
finally:
if token is not None:
_test_context.reset(token)
if legacy_hash not in self._legacy_index:
self._legacy_index[legacy_hash] = path
self._scanned_dirs.add(directory)
legacy_path = self._legacy_index.get(request_hash)
if legacy_path and legacy_path.exists():
return _recording_from_file(legacy_path)
return None
def _model_list_responses(self, request_hash: str) -> list[dict[str, Any]]:
@ -685,6 +652,14 @@ def _recording_from_file(response_path) -> dict[str, Any]:
with open(response_path) as f:
data = json.load(f)
mapping = data.get("id_normalization_mapping") or {}
if mapping:
serialized = json.dumps(data)
for normalized, original in mapping.items():
serialized = serialized.replace(original, normalized)
data = json.loads(serialized)
data["id_normalization_mapping"] = {}
# Deserialize response body if needed
if "response" in data and "body" in data["response"]:
if isinstance(data["response"]["body"], list):
@ -759,7 +734,7 @@ async def _patched_tool_invoke_method(
original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
):
"""Patched version of tool runtime invoke_tool method for recording/replay."""
global _current_mode, _current_storage, _id_normalizers
global _current_mode, _current_storage
if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
# Normal operation
@ -808,7 +783,7 @@ async def _patched_tool_invoke_method(
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
global _current_mode, _current_storage, _id_normalizers
global _current_mode, _current_storage
mode = _current_mode
storage = _current_storage
@ -864,45 +839,6 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
if recording["response"].get("is_streaming", False) and isinstance(response_body, list):
response_body = _coalesce_streaming_chunks(response_body)
recording["response"]["body"] = response_body
# Denormalize the response: replace recorded IDs with current runtime IDs
recorded_mapping = recording.get("id_normalization_mapping", {})
current_mapping = _get_current_test_id_mappings()
if recorded_mapping or current_mapping:
import sys
print(f"\n=== DENORM DEBUG ===", file=sys.stderr)
print(f"Recorded mapping: {recorded_mapping}", file=sys.stderr)
print(f"Current mapping: {current_mapping}", file=sys.stderr)
print(f"_id_normalizers before: {_id_normalizers.get(_test_context.get(), {})}", file=sys.stderr)
response_body = _denormalize_response(response_body, recorded_mapping, current_mapping)
# Update _id_normalizers to register the current IDs with their normalized values
# This ensures that if these IDs appear in future requests, they get the same
# normalized value as they had in the recording
test_id = _test_context.get()
if test_id and recorded_mapping:
for normalized_id, recorded_id in recorded_mapping.items():
if normalized_id in current_mapping:
current_id = current_mapping[normalized_id]
# Extract ID type from normalized_id (e.g., "file-1" -> "file")
id_type = normalized_id.rsplit("-", 1)[0]
# Ensure structures exist
if test_id not in _id_normalizers:
_id_normalizers[test_id] = {}
if id_type not in _id_normalizers[test_id]:
_id_normalizers[test_id][id_type] = {}
# Register: current_id -> normalized_id
_id_normalizers[test_id][id_type][current_id] = normalized_id
print(f"Registered {current_id} -> {normalized_id}", file=sys.stderr)
print(f"_id_normalizers after: {_id_normalizers.get(_test_context.get(), {})}", file=sys.stderr)
print(f"=== END DENORM DEBUG ===\n", file=sys.stderr)
if recording["response"].get("is_streaming", False):
async def replay_stream():
@ -1125,6 +1061,7 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
# Store previous state
prev_mode = _current_mode
prev_storage = _current_storage
override_token = None
try:
_current_mode = mode
@ -1133,7 +1070,9 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
if storage_dir is None:
raise ValueError("storage_dir is required for record, replay, and record-if-missing modes")
_current_storage = ResponseStorage(Path(storage_dir))
_id_counters.clear()
patch_inference_clients()
override_token = set_id_override(_deterministic_id_override)
yield
@ -1141,6 +1080,8 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
# Restore previous state
if mode in ["record", "replay", "record-if-missing"]:
unpatch_inference_clients()
if override_token is not None:
reset_id_override(override_token)
_current_mode = prev_mode
_current_storage = prev_storage