From b47bf340db3d481517f72c05d89dc425700f652e Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 9 Oct 2025 09:22:47 -0700 Subject: [PATCH] simpler ID generation to remove denormalization --- llama_stack/core/id_generation.py | 44 +++ .../providers/inline/files/localfs/files.py | 3 +- .../providers/remote/files/s3/files.py | 3 +- .../utils/memory/openai_vector_store_mixin.py | 5 +- llama_stack/testing/api_recorder.py | 337 ++++++++---------- 5 files changed, 190 insertions(+), 202 deletions(-) create mode 100644 llama_stack/core/id_generation.py diff --git a/llama_stack/core/id_generation.py b/llama_stack/core/id_generation.py new file mode 100644 index 000000000..6459e5cac --- /dev/null +++ b/llama_stack/core/id_generation.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from __future__ import annotations + +from collections.abc import Callable + +IdFactory = Callable[[], str] +IdOverride = Callable[[str, IdFactory], str] + +_id_override: IdOverride | None = None + + +def generate_object_id(kind: str, factory: IdFactory) -> str: + """Generate an identifier for the given kind using the provided factory. + + Allows tests to override ID generation deterministically by installing an + override callback via :func:`set_id_override`. + """ + + override = _id_override + if override is not None: + return override(kind, factory) + return factory() + + +def set_id_override(override: IdOverride) -> IdOverride | None: + """Install an override used to generate deterministic identifiers.""" + + global _id_override + + previous = _id_override + _id_override = override + return previous + + +def reset_id_override(token: IdOverride | None) -> None: + """Restore the previous override returned by :func:`set_id_override`.""" + + global _id_override + _id_override = token diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index 77af94681..a76b982ce 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -22,6 +22,7 @@ from llama_stack.apis.files import ( OpenAIFilePurpose, ) from llama_stack.core.datatypes import AccessRule +from llama_stack.core.id_generation import generate_object_id from llama_stack.log import get_logger from llama_stack.providers.utils.files.form_data import parse_expires_after from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType @@ -65,7 +66,7 @@ class LocalfsFilesImpl(Files): def _generate_file_id(self) -> str: """Generate a unique file ID for OpenAI API.""" - return f"file-{uuid.uuid4().hex}" + return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}") def _get_file_path(self, file_id: str) -> Path: """Get the filesystem path for a file ID.""" diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py index eb339b31e..c0e9f81d6 100644 --- a/llama_stack/providers/remote/files/s3/files.py +++ b/llama_stack/providers/remote/files/s3/files.py @@ -23,6 +23,7 @@ from llama_stack.apis.files import ( OpenAIFilePurpose, ) from llama_stack.core.datatypes import AccessRule +from llama_stack.core.id_generation import generate_object_id from llama_stack.providers.utils.files.form_data import parse_expires_after from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType 
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore @@ -198,7 +199,7 @@ class S3FilesImpl(Files): purpose: Annotated[OpenAIFilePurpose, Form()], expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None, ) -> OpenAIFileObject: - file_id = f"file-{uuid.uuid4().hex}" + file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}") filename = getattr(file, "filename", None) or "uploaded_file" diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index c179eba6c..ddfef9ba2 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -40,6 +40,7 @@ from llama_stack.apis.vector_io import ( VectorStoreSearchResponse, VectorStoreSearchResponsePage, ) +from llama_stack.core.id_generation import generate_object_id from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.vector_store import ( @@ -352,7 +353,7 @@ class OpenAIVectorStoreMixin(ABC): """Creates a vector store.""" created_at = int(time.time()) # Derive the canonical vector_db_id (allow override, else generate) - vector_db_id = provider_vector_db_id or f"vs_{uuid.uuid4()}" + vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}") if provider_id is None: raise ValueError("Provider ID is required") @@ -986,7 +987,7 @@ class OpenAIVectorStoreMixin(ABC): chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto() created_at = int(time.time()) - batch_id = f"batch_{uuid.uuid4()}" + batch_id = generate_object_id("vector_store_file_batch", lambda: f"batch_{uuid.uuid4()}") # File batches expire after 7 days expires_at = created_at + (7 * 24 * 60 * 60) diff --git a/llama_stack/testing/api_recorder.py b/llama_stack/testing/api_recorder.py index b1244f19f..a9fa4d8b0 100644 --- a/llama_stack/testing/api_recorder.py +++ b/llama_stack/testing/api_recorder.py @@ -10,7 +10,7 @@ import hashlib import json import os import re -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager from enum import StrEnum from pathlib import Path @@ -18,6 +18,7 @@ from typing import Any, Literal, cast from openai import NOT_GIVEN, OpenAI +from llama_stack.core.id_generation import reset_id_override, set_id_override from llama_stack.log import get_logger logger = get_logger(__name__, category="testing") @@ -30,9 +31,8 @@ _current_mode: str | None = None _current_storage: ResponseStorage | None = None _original_methods: dict[str, Any] = {} -# ID normalization state: maps test_id -> (id_type -> {original_id: normalized_id}) -_id_normalizers: dict[str, dict[str, dict[str, str]]] = {} -_id_counters: dict[str, dict[str, int]] = {} # test_id -> (id_type -> counter) +# Per-test deterministic ID counters (test_id -> id_kind -> counter) +_id_counters: dict[str, dict[str, int]] = {} # Test context uses ContextVar since it changes per-test and needs async isolation from contextvars import ContextVar @@ -56,104 +56,82 @@ class APIRecordingMode(StrEnum): RECORD_IF_MISSING = "record-if-missing" -def _get_normalized_id(original_id: str, id_type: str) -> str: - """Get a normalized ID using a test-specific counter. 
+_ID_KIND_PREFIXES: dict[str, str] = { + "file": "file-", + "vector_store": "vs_", + "vector_store_file_batch": "batch_", +} - Each unique ID within a test gets assigned a sequential number (file-1, file-2, uuid-1, etc). - This ensures consistency across requests within the same test while keeping IDs human-readable. - """ - global _id_normalizers, _id_counters + +def _allocate_test_scoped_id(kind: str) -> str | None: + """Return the next deterministic ID for the given kind within the current test.""" + + global _id_counters test_id = _test_context.get() - if not test_id: - # No test context, return original ID - return original_id + prefix = _ID_KIND_PREFIXES.get(kind) - # Initialize structures for this test if needed - if test_id not in _id_normalizers: - _id_normalizers[test_id] = {} - _id_counters[test_id] = {} + if prefix is None: + return None - if id_type not in _id_normalizers[test_id]: - _id_normalizers[test_id][id_type] = {} - _id_counters[test_id][id_type] = 0 + key = test_id or "__global__" - # Check if we've seen this ID before - if original_id in _id_normalizers[test_id][id_type]: - return _id_normalizers[test_id][id_type][original_id] + if key not in _id_counters: + _id_counters[key] = {} - # New ID - assign next counter value - _id_counters[test_id][id_type] += 1 - counter = _id_counters[test_id][id_type] - normalized_id = f"{id_type}-{counter}" + counter = _id_counters[key].get(kind, 0) + 1 + _id_counters[key][kind] = counter - # Store mapping - _id_normalizers[test_id][id_type][original_id] = normalized_id - return normalized_id + return f"{prefix}{counter}" -def _normalize_file_ids(obj: Any) -> Any: - """Recursively replace file IDs and vector store IDs with test-specific normalized values. +class _IdCanonicalizer: + PATTERN = re.compile(r"(file-[A-Za-z0-9_-]+|vs_[A-Za-z0-9_-]+|batch_[A-Za-z0-9_-]+)") - Each unique file ID or UUID gets a unique normalized value using a sequential counter - within each test (file-1, file-2, uuid-1, etc). 
- """ - import re + def __init__(self) -> None: + self._mappings: dict[str, dict[str, str]] = {kind: {} for kind in _ID_KIND_PREFIXES} + self._counters: dict[str, int] = dict.fromkeys(_ID_KIND_PREFIXES, 0) - if isinstance(obj, dict): - result = {} - for k, v in obj.items(): - # Normalize file IDs in document_id fields - if k == "document_id" and isinstance(v, str) and v.startswith("file-"): - result[k] = _get_normalized_id(v, "file") - # Normalize vector database/store IDs with UUID patterns - elif k in ("vector_db_id", "vector_store_id", "bank_id") and isinstance(v, str): - # Replace UUIDs in the ID deterministically - def replace_uuid(match): - uuid_val = match.group(0) - return _get_normalized_id(uuid_val, "uuid") - - normalized = re.sub( - r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", - replace_uuid, - v, - ) - result[k] = normalized - else: - result[k] = _normalize_file_ids(v) - return result - elif isinstance(obj, list): - return [_normalize_file_ids(item) for item in obj] - elif isinstance(obj, str): - # Replace file- patterns in strings (like in text content) - def replace_file_id(match): - file_id = match.group(0) - return _get_normalized_id(file_id, "file") - - return re.sub(r"file-[a-f0-9]{32}", replace_file_id, obj) - else: + def canonicalize(self, obj: Any) -> Any: + if isinstance(obj, dict): + return {k: self._canonicalize_value(k, v) for k, v in obj.items()} + if isinstance(obj, list): + return [self.canonicalize(item) for item in obj] + if isinstance(obj, str): + return self._canonicalize_string(obj) return obj + def _canonicalize_value(self, key: str, value: Any) -> Any: + if key in {"vector_db_id", "vector_store_id", "bank_id"} and isinstance(value, str): + return self._canonicalize_string(value) + if key == "document_id" and isinstance(value, str) and value.startswith("file-"): + return self._canonicalize_string(value) + return self.canonicalize(value) -def _get_current_test_id_mappings() -> dict[str, str]: - """Get the current test's ID normalization mappings. + def _canonicalize_string(self, value: str) -> str: + def replace(match: re.Match[str]) -> str: + token = match.group(0) + if token.startswith("file-"): + return self._mapped_value("file", token) + if token.startswith("vs_"): + return self._mapped_value("vector_store", token) + if token.startswith("batch_"): + return self._mapped_value("vector_store_file_batch", token) + return token - Returns a dict mapping normalized_id -> original_id (e.g., "file-1" -> "file-abc123...") - This is the inverse of what's stored in _id_normalizers. 
-    """
-    global _id_normalizers
+        return self.PATTERN.sub(replace, value)
 
-    test_id = _test_context.get()
-    if not test_id or test_id not in _id_normalizers:
-        return {}
+    def _mapped_value(self, kind: str, original: str) -> str:
+        mapping = self._mappings[kind]
+        if original not in mapping:
+            self._counters[kind] += 1
+            mapping[original] = f"{_ID_KIND_PREFIXES[kind]}{self._counters[kind]}"
+        return mapping[original]
 
-    # Invert the mapping: normalized_id -> original_id
-    result = {}
-    for id_type, mappings in _id_normalizers[test_id].items():
-        for original_id, normalized_id in mappings.items():
-            result[normalized_id] = original_id
-    return result
 
+def _canonicalize_for_hashing(obj: Any) -> Any:
+    canonicalizer = _IdCanonicalizer()
+    return canonicalizer.canonicalize(obj)
 
 
 def _chunk_text_content(chunk: Any) -> tuple[str | None, bool]:
@@ -205,7 +183,7 @@ def _ends_with_partial_identifier(text: str) -> bool:
         else:
             core = token
 
-        suffix = core[len("file-"):]
+        suffix = core[len("file-") :]
         if len(suffix) < 16:
             return True
         if not re.fullmatch(r"[A-Za-z0-9_-]+", suffix):
@@ -260,67 +238,11 @@ def _coalesce_streaming_chunks(chunks: list[Any]) -> list[Any]:
     return result
 
 
-def _denormalize_response(obj: Any, recorded_mapping: dict[str, str], current_mapping: dict[str, str]) -> Any:
-    """Replace recorded IDs with current runtime IDs in a response object.
-
-    Args:
-        obj: The response object to denormalize
-        recorded_mapping: normalized_id -> recorded_original_id (from recording file)
-        current_mapping: normalized_id -> current_original_id (from current runtime)
-
-    Returns:
-        The response with all recorded IDs replaced with current IDs
-    """
-    import re
-
-    # Build reverse mapping: recorded_original_id -> current_original_id
-    # via the normalized intermediary
-    id_translation = {}
-    for normalized_id, recorded_id in recorded_mapping.items():
-        if normalized_id in current_mapping:
-            current_id = current_mapping[normalized_id]
-            id_translation[recorded_id] = current_id
-
-    if isinstance(obj, dict):
-        result = {}
-        for k, v in obj.items():
-            # Translate document_id fields
-            if k == "document_id" and isinstance(v, str):
-                result[k] = id_translation.get(v, v)
-            # Translate vector database/store IDs
-            elif k in ("vector_db_id", "vector_store_id", "bank_id") and isinstance(v, str):
-                # Replace any recorded IDs in the value
-                translated = v
-                for recorded_id, current_id in id_translation.items():
-                    translated = translated.replace(recorded_id, current_id)
-                result[k] = translated
-            else:
-                result[k] = _denormalize_response(v, recorded_mapping, current_mapping)
-        return result
-    elif isinstance(obj, list):
-        return [_denormalize_response(item, recorded_mapping, current_mapping) for item in obj]
-    elif hasattr(obj, "model_dump"):
-        # Handle Pydantic/BaseModel instances by denormalizing their dict form
-        data = obj.model_dump(mode="python")
-        denormalized = _denormalize_response(data, recorded_mapping, current_mapping)
-
-        cls = obj.__class__
-        try:
-            return cls.model_validate(denormalized)
-        except Exception:
-            try:
-                return cls.model_construct(**denormalized)
-            except Exception:
-                return denormalized
-    elif isinstance(obj, str):
-        # Replace file- patterns in strings (like in text content and citations)
-        translated = obj
-        for recorded_id, current_id in id_translation.items():
-            # Handle both bare file IDs and citation format <|file-...|>
-            translated = translated.replace(recorded_id, current_id)
-        return translated
-    else:
-        return obj
+def _deterministic_id_override(kind: str, factory: Callable[[], str]) -> str:
+    deterministic_id = _allocate_test_scoped_id(kind)
+    if deterministic_id is not None:
+        return deterministic_id
+    return factory()
 
 
 def normalize_inference_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
@@ -337,13 +259,10 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
 
     parsed = urlparse(url)
 
-    # Normalize file IDs in the body to ensure consistent hashing across test runs
-    normalized_body = _normalize_file_ids(body)
-
     normalized: dict[str, Any] = {
         "method": method.upper(),
         "endpoint": parsed.path,
-        "body": normalized_body,
+        "body": _canonicalize_for_hashing(body),
     }
 
     # Include test_id for isolation, except for shared infrastructure endpoints
@@ -357,10 +276,11 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
 
 def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
     """Create a normalized hash of the tool request for consistent matching."""
-    # Normalize file IDs and vector store IDs in kwargs
-    normalized_kwargs = _normalize_file_ids(kwargs)
-
-    normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": normalized_kwargs}
+    normalized = {
+        "provider": provider_name,
+        "tool_name": tool_name,
+        "kwargs": _canonicalize_for_hashing(kwargs),
+    }
 
     # Create hash - sort_keys=True ensures deterministic ordering
     normalized_json = json.dumps(normalized, sort_keys=True)
@@ -514,9 +434,6 @@ def _normalize_response(data: dict[str, Any], request_hash: str) -> dict[str, An
     if "eval_duration" in data and data["eval_duration"] is not None:
         data["eval_duration"] = 0
 
-    # Normalize file IDs and vector store IDs to ensure consistent hashing across replays
-    data = _normalize_file_ids(data)
-
     return data
 
 
@@ -565,6 +482,8 @@ class ResponseStorage:
     def __init__(self, base_dir: Path):
         self.base_dir = base_dir
         # Don't create responses_dir here - determine it per-test at runtime
+        self._legacy_index: dict[str, Path] = {}
+        self._scanned_dirs: set[Path] = set()
 
     def _get_test_dir(self) -> Path:
         """Get the recordings directory in the test file's parent directory.
@@ -627,7 +546,7 @@ class ResponseStorage:
                     "test_id": _test_context.get(),
                     "request": request,
                     "response": serialized_response,
-                    "id_normalization_mapping": _get_current_test_id_mappings(),
+                    "id_normalization_mapping": {},
                 },
                 f,
                 indent=2,
@@ -635,6 +554,8 @@ class ResponseStorage:
             f.write("\n")
             f.flush()
 
+        self._legacy_index[request_hash] = response_path
+
     def find_recording(self, request_hash: str) -> dict[str, Any] | None:
         """Find a recorded response by request hash.
@@ -658,6 +579,52 @@ class ResponseStorage:
         if fallback_path.exists():
             return _recording_from_file(fallback_path)
 
+        return self._find_in_legacy_index(request_hash, [test_dir, fallback_dir])
+
+    def _find_in_legacy_index(self, request_hash: str, directories: list[Path]) -> dict[str, Any] | None:
+        for directory in directories:
+            if not directory.exists() or directory in self._scanned_dirs:
+                continue
+
+            for path in directory.glob("*.json"):
+                try:
+                    with open(path) as f:
+                        data = json.load(f)
+                except Exception:
+                    continue
+
+                request = data.get("request")
+                if not request:
+                    continue
+
+                body = request.get("body")
+                canonical_body = _canonicalize_for_hashing(body) if isinstance(body, dict | list) else body
+
+                token = None
+                test_id = data.get("test_id")
+                if test_id:
+                    token = _test_context.set(test_id)
+
+                try:
+                    legacy_hash = normalize_inference_request(
+                        request.get("method", ""),
+                        request.get("url", ""),
+                        request.get("headers", {}),
+                        canonical_body,
+                    )
+                finally:
+                    if token is not None:
+                        _test_context.reset(token)
+
+                if legacy_hash not in self._legacy_index:
+                    self._legacy_index[legacy_hash] = path
+
+            self._scanned_dirs.add(directory)
+
+        legacy_path = self._legacy_index.get(request_hash)
+        if legacy_path and legacy_path.exists():
+            return _recording_from_file(legacy_path)
+
         return None
 
     def _model_list_responses(self, request_hash: str) -> list[dict[str, Any]]:
@@ -685,6 +652,14 @@ def _recording_from_file(response_path) -> dict[str, Any]:
     with open(response_path) as f:
         data = json.load(f)
 
+    mapping = data.get("id_normalization_mapping") or {}
+    if mapping:
+        serialized = json.dumps(data)
+        for normalized, original in mapping.items():
+            serialized = serialized.replace(original, normalized)
+        data = json.loads(serialized)
+        data["id_normalization_mapping"] = {}
+
     # Deserialize response body if needed
    if "response" in data and "body" in data["response"]:
         if isinstance(data["response"]["body"], list):
@@ -759,7 +734,7 @@ async def _patched_tool_invoke_method(
     original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
 ):
     """Patched version of tool runtime invoke_tool method for recording/replay."""
-    global _current_mode, _current_storage, _id_normalizers
+    global _current_mode, _current_storage
 
     if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
         # Normal operation
@@ -808,7 +783,7 @@ async def _patched_tool_invoke_method(
 
 
 async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
-    global _current_mode, _current_storage, _id_normalizers
+    global _current_mode, _current_storage
 
     mode = _current_mode
     storage = _current_storage
@@ -864,45 +839,6 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         if recording["response"].get("is_streaming", False) and isinstance(response_body, list):
             response_body = _coalesce_streaming_chunks(response_body)
 
-        recording["response"]["body"] = response_body
-
-        # Denormalize the response: replace recorded IDs with current runtime IDs
-        recorded_mapping = recording.get("id_normalization_mapping", {})
-        current_mapping = _get_current_test_id_mappings()
-
-        if recorded_mapping or current_mapping:
-            import sys
-            print(f"\n=== DENORM DEBUG ===", file=sys.stderr)
-            print(f"Recorded mapping: {recorded_mapping}", file=sys.stderr)
-            print(f"Current mapping: {current_mapping}", file=sys.stderr)
-            print(f"_id_normalizers before: {_id_normalizers.get(_test_context.get(), {})}", file=sys.stderr)
-
-            response_body = _denormalize_response(response_body, recorded_mapping, current_mapping)
-
-            # Update _id_normalizers to register the current IDs with their normalized values
-            # This ensures that if these IDs appear in future requests, they get the same
-            # normalized value as they had in the recording
-            test_id = _test_context.get()
-            if test_id and recorded_mapping:
-                for normalized_id, recorded_id in recorded_mapping.items():
-                    if normalized_id in current_mapping:
-                        current_id = current_mapping[normalized_id]
-                        # Extract ID type from normalized_id (e.g., "file-1" -> "file")
-                        id_type = normalized_id.rsplit("-", 1)[0]
-
-                        # Ensure structures exist
-                        if test_id not in _id_normalizers:
-                            _id_normalizers[test_id] = {}
-                        if id_type not in _id_normalizers[test_id]:
-                            _id_normalizers[test_id][id_type] = {}
-
-                        # Register: current_id -> normalized_id
-                        _id_normalizers[test_id][id_type][current_id] = normalized_id
-                        print(f"Registered {current_id} -> {normalized_id}", file=sys.stderr)
-
-            print(f"_id_normalizers after: {_id_normalizers.get(_test_context.get(), {})}", file=sys.stderr)
-            print(f"=== END DENORM DEBUG ===\n", file=sys.stderr)
-
         if recording["response"].get("is_streaming", False):
 
             async def replay_stream():
@@ -1125,6 +1061,7 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
     # Store previous state
     prev_mode = _current_mode
     prev_storage = _current_storage
+    override_token = None
 
     try:
         _current_mode = mode
@@ -1133,7 +1070,9 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
             if storage_dir is None:
                 raise ValueError("storage_dir is required for record, replay, and record-if-missing modes")
            _current_storage = ResponseStorage(Path(storage_dir))
+            _id_counters.clear()
             patch_inference_clients()
+            override_token = set_id_override(_deterministic_id_override)
 
         yield
 
@@ -1141,6 +1080,8 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
         # Restore previous state
         if mode in ["record", "replay", "record-if-missing"]:
             unpatch_inference_clients()
+            if override_token is not None:
+                reset_id_override(override_token)
         _current_mode = prev_mode
         _current_storage = prev_storage
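
--

Reviewer notes (illustrative sketches, not part of the patch):

The override hook in llama_stack/core/id_generation.py can be exercised on its
own. A minimal sketch, assuming a simple per-kind counter scheme invented for
illustration; the recorder itself installs _deterministic_id_override, which
uses the "file-"/"vs_"/"batch_" prefixes from _ID_KIND_PREFIXES instead:

    from llama_stack.core.id_generation import (
        IdFactory,
        generate_object_id,
        reset_id_override,
        set_id_override,
    )

    _counters: dict[str, int] = {}

    def _sequential_override(kind: str, factory: IdFactory) -> str:
        # Hand out kind-1, kind-2, ... instead of calling the UUID factory.
        _counters[kind] = _counters.get(kind, 0) + 1
        return f"{kind}-{_counters[kind]}"

    token = set_id_override(_sequential_override)
    try:
        # Provider code that routes through generate_object_id now yields
        # stable IDs; the factory argument is ignored while the override
        # is installed.
        assert generate_object_id("file", lambda: "file-unused") == "file-1"
        assert generate_object_id("file", lambda: "file-unused") == "file-2"
    finally:
        reset_id_override(token)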
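Request hashing no longer rewrites recorded bodies; _canonicalize_for_hashing
maps raw IDs to stable placeholders before hashing instead. A small worked
example (the concrete IDs below are made up):

    from llama_stack.testing.api_recorder import _canonicalize_for_hashing

    body = {
        "vector_store_id": "vs_9f2c7d1e",                  # made-up runtime ID
        "input": "summarize file-abc123 and file-def456",  # made-up file IDs
    }

    print(_canonicalize_for_hashing(body))
    # {'vector_store_id': 'vs_1', 'input': 'summarize file-1 and file-2'}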
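Legacy recordings that still carry an id_normalization_mapping are folded back
to canonical form by _recording_from_file. The same rewrite, shown on a toy
record with fabricated IDs:

    import json

    data = {
        "request": {"body": {"document_id": "file-9a1b2c"}},
        "response": {"body": {"id": "file-9a1b2c"}},
        "id_normalization_mapping": {"file-1": "file-9a1b2c"},
    }

    mapping = data.get("id_normalization_mapping") or {}
    serialized = json.dumps(data)
    for normalized, original in mapping.items():
        serialized = serialized.replace(original, normalized)
    data = json.loads(serialized)
    data["id_normalization_mapping"] = {}
    # data now refers to "file-1" everywhere, matching the canonical hashes.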