mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-14 23:22:35 +00:00

simpler ID generation to remove denormalization

parent 39aa17f975
commit b47bf340db

5 changed files with 190 additions and 202 deletions
llama_stack/core/id_generation.py (new file, +44)

@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from __future__ import annotations
+
+from collections.abc import Callable
+
+IdFactory = Callable[[], str]
+IdOverride = Callable[[str, IdFactory], str]
+
+_id_override: IdOverride | None = None
+
+
+def generate_object_id(kind: str, factory: IdFactory) -> str:
+    """Generate an identifier for the given kind using the provided factory.
+
+    Allows tests to override ID generation deterministically by installing an
+    override callback via :func:`set_id_override`.
+    """
+
+    override = _id_override
+    if override is not None:
+        return override(kind, factory)
+    return factory()
+
+
+def set_id_override(override: IdOverride) -> IdOverride | None:
+    """Install an override used to generate deterministic identifiers."""
+
+    global _id_override
+
+    previous = _id_override
+    _id_override = override
+    return previous
+
+
+def reset_id_override(token: IdOverride | None) -> None:
+    """Restore the previous override returned by :func:`set_id_override`."""
+
+    global _id_override
+    _id_override = token
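A minimal usage sketch of the new hook (the test-side override below is illustrative, not part of this commit): production code calls generate_object_id with a kind string and a factory, and a test can temporarily install an override to make the returned IDs deterministic.

import uuid

from llama_stack.core.id_generation import generate_object_id, reset_id_override, set_id_override

# Production path: no override installed, so the factory runs and a random ID comes back.
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")

# Test path (illustrative): number IDs per kind instead of calling the factory.
_fake_counters: dict[str, int] = {}

def _fake_ids(kind, factory):
    _fake_counters[kind] = _fake_counters.get(kind, 0) + 1
    return f"{kind}-{_fake_counters[kind]}"

token = set_id_override(_fake_ids)  # returns the previously installed override
try:
    assert generate_object_id("file", lambda: "ignored") == "file-1"
    assert generate_object_id("file", lambda: "ignored") == "file-2"
finally:
    reset_id_override(token)  # restore whatever override was active before

Because set_id_override returns the previous override and reset_id_override takes it back, installing and restoring overrides nests cleanly.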
Localfs files provider (LocalfsFilesImpl):

@@ -22,6 +22,7 @@ from llama_stack.apis.files import (
     OpenAIFilePurpose,
 )
 from llama_stack.core.datatypes import AccessRule
+from llama_stack.core.id_generation import generate_object_id
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.files.form_data import parse_expires_after
 from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
@@ -65,7 +66,7 @@ class LocalfsFilesImpl(Files):
 
     def _generate_file_id(self) -> str:
         """Generate a unique file ID for OpenAI API."""
-        return f"file-{uuid.uuid4().hex}"
+        return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
 
     def _get_file_path(self, file_id: str) -> Path:
         """Get the filesystem path for a file ID."""
S3 files provider (S3FilesImpl):

@@ -23,6 +23,7 @@ from llama_stack.apis.files import (
     OpenAIFilePurpose,
 )
 from llama_stack.core.datatypes import AccessRule
+from llama_stack.core.id_generation import generate_object_id
 from llama_stack.providers.utils.files.form_data import parse_expires_after
 from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
 from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
@@ -198,7 +199,7 @@ class S3FilesImpl(Files):
         purpose: Annotated[OpenAIFilePurpose, Form()],
         expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
     ) -> OpenAIFileObject:
-        file_id = f"file-{uuid.uuid4().hex}"
+        file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
 
         filename = getattr(file, "filename", None) or "uploaded_file"
 
OpenAI vector store mixin (OpenAIVectorStoreMixin):

@@ -40,6 +40,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreSearchResponse,
     VectorStoreSearchResponsePage,
 )
+from llama_stack.core.id_generation import generate_object_id
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import (
@@ -352,7 +353,7 @@ class OpenAIVectorStoreMixin(ABC):
         """Creates a vector store."""
         created_at = int(time.time())
         # Derive the canonical vector_db_id (allow override, else generate)
-        vector_db_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
+        vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
 
         if provider_id is None:
             raise ValueError("Provider ID is required")
@@ -986,7 +987,7 @@ class OpenAIVectorStoreMixin(ABC):
         chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
 
         created_at = int(time.time())
-        batch_id = f"batch_{uuid.uuid4()}"
+        batch_id = generate_object_id("vector_store_file_batch", lambda: f"batch_{uuid.uuid4()}")
         # File batches expire after 7 days
         expires_at = created_at + (7 * 24 * 60 * 60)
 
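Taken together, the three provider call sites above share one call shape; the snippet below restates it in isolation (provider plumbing omitted, so it is a sketch rather than actual provider code) to make the kind strings and their production prefixes easy to see.

import uuid

from llama_stack.core.id_generation import generate_object_id

# Same pattern as the hunks above: a kind string plus the production factory.
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
vector_db_id = generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
batch_id = generate_object_id("vector_store_file_batch", lambda: f"batch_{uuid.uuid4()}")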
API recording/replay test harness:

@@ -10,7 +10,7 @@ import hashlib
 import json
 import os
 import re
-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from contextlib import contextmanager
 from enum import StrEnum
 from pathlib import Path
@@ -18,6 +18,7 @@ from typing import Any, Literal, cast
 
 from openai import NOT_GIVEN, OpenAI
 
+from llama_stack.core.id_generation import reset_id_override, set_id_override
 from llama_stack.log import get_logger
 
 logger = get_logger(__name__, category="testing")
@@ -30,9 +31,8 @@ _current_mode: str | None = None
 _current_storage: ResponseStorage | None = None
 _original_methods: dict[str, Any] = {}
 
-# ID normalization state: maps test_id -> (id_type -> {original_id: normalized_id})
-_id_normalizers: dict[str, dict[str, dict[str, str]]] = {}
-_id_counters: dict[str, dict[str, int]] = {}  # test_id -> (id_type -> counter)
+# Per-test deterministic ID counters (test_id -> id_kind -> counter)
+_id_counters: dict[str, dict[str, int]] = {}
 
 # Test context uses ContextVar since it changes per-test and needs async isolation
 from contextvars import ContextVar
@@ -56,104 +56,82 @@ class APIRecordingMode(StrEnum):
     RECORD_IF_MISSING = "record-if-missing"
 
 
-def _get_normalized_id(original_id: str, id_type: str) -> str:
-    """Get a normalized ID using a test-specific counter.
-
-    Each unique ID within a test gets assigned a sequential number (file-1, file-2, uuid-1, etc).
-    This ensures consistency across requests within the same test while keeping IDs human-readable.
-    """
-    global _id_normalizers, _id_counters
-
-    test_id = _test_context.get()
-    if not test_id:
-        # No test context, return original ID
-        return original_id
-
-    # Initialize structures for this test if needed
-    if test_id not in _id_normalizers:
-        _id_normalizers[test_id] = {}
-        _id_counters[test_id] = {}
-
-    if id_type not in _id_normalizers[test_id]:
-        _id_normalizers[test_id][id_type] = {}
-        _id_counters[test_id][id_type] = 0
-
-    # Check if we've seen this ID before
-    if original_id in _id_normalizers[test_id][id_type]:
-        return _id_normalizers[test_id][id_type][original_id]
-
-    # New ID - assign next counter value
-    _id_counters[test_id][id_type] += 1
-    counter = _id_counters[test_id][id_type]
-    normalized_id = f"{id_type}-{counter}"
-
-    # Store mapping
-    _id_normalizers[test_id][id_type][original_id] = normalized_id
-    return normalized_id
-
-
-def _normalize_file_ids(obj: Any) -> Any:
-    """Recursively replace file IDs and vector store IDs with test-specific normalized values.
-
-    Each unique file ID or UUID gets a unique normalized value using a sequential counter
-    within each test (file-1, file-2, uuid-1, etc).
-    """
-    import re
-
-    if isinstance(obj, dict):
-        result = {}
-        for k, v in obj.items():
-            # Normalize file IDs in document_id fields
-            if k == "document_id" and isinstance(v, str) and v.startswith("file-"):
-                result[k] = _get_normalized_id(v, "file")
-            # Normalize vector database/store IDs with UUID patterns
-            elif k in ("vector_db_id", "vector_store_id", "bank_id") and isinstance(v, str):
-                # Replace UUIDs in the ID deterministically
-                def replace_uuid(match):
-                    uuid_val = match.group(0)
-                    return _get_normalized_id(uuid_val, "uuid")
-
-                normalized = re.sub(
-                    r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}",
-                    replace_uuid,
-                    v,
-                )
-                result[k] = normalized
-            else:
-                result[k] = _normalize_file_ids(v)
-        return result
-    elif isinstance(obj, list):
-        return [_normalize_file_ids(item) for item in obj]
-    elif isinstance(obj, str):
-        # Replace file-<uuid> patterns in strings (like in text content)
-        def replace_file_id(match):
-            file_id = match.group(0)
-            return _get_normalized_id(file_id, "file")
-
-        return re.sub(r"file-[a-f0-9]{32}", replace_file_id, obj)
-    else:
-        return obj
-
-
-def _get_current_test_id_mappings() -> dict[str, str]:
-    """Get the current test's ID normalization mappings.
-
-    Returns a dict mapping normalized_id -> original_id (e.g., "file-1" -> "file-abc123...")
-    This is the inverse of what's stored in _id_normalizers.
-    """
-    global _id_normalizers
-
-    test_id = _test_context.get()
-    if not test_id or test_id not in _id_normalizers:
-        return {}
-
-    # Invert the mapping: normalized_id -> original_id
-    result = {}
-    for id_type, mappings in _id_normalizers[test_id].items():
-        for original_id, normalized_id in mappings.items():
-            result[normalized_id] = original_id
-
-    return result
+_ID_KIND_PREFIXES: dict[str, str] = {
+    "file": "file-",
+    "vector_store": "vs_",
+    "vector_store_file_batch": "batch_",
+}
+
+
+def _allocate_test_scoped_id(kind: str) -> str | None:
+    """Return the next deterministic ID for the given kind within the current test."""
+
+    global _id_counters
+
+    test_id = _test_context.get()
+    prefix = _ID_KIND_PREFIXES.get(kind)
+
+    if prefix is None:
+        return None
+
+    key = test_id or "__global__"
+
+    if key not in _id_counters:
+        _id_counters[key] = {}
+
+    counter = _id_counters[key].get(kind, 0) + 1
+    _id_counters[key][kind] = counter
+
+    return f"{prefix}{counter}"
+
+
+class _IdCanonicalizer:
+    PATTERN = re.compile(r"(file-[A-Za-z0-9_-]+|vs_[A-Za-z0-9_-]+|batch_[A-Za-z0-9_-]+)")
+
+    def __init__(self) -> None:
+        self._mappings: dict[str, dict[str, str]] = {kind: {} for kind in _ID_KIND_PREFIXES}
+        self._counters: dict[str, int] = dict.fromkeys(_ID_KIND_PREFIXES, 0)
+
+    def canonicalize(self, obj: Any) -> Any:
+        if isinstance(obj, dict):
+            return {k: self._canonicalize_value(k, v) for k, v in obj.items()}
+        if isinstance(obj, list):
+            return [self.canonicalize(item) for item in obj]
+        if isinstance(obj, str):
+            return self._canonicalize_string(obj)
+        return obj
+
+    def _canonicalize_value(self, key: str, value: Any) -> Any:
+        if key in {"vector_db_id", "vector_store_id", "bank_id"} and isinstance(value, str):
+            return self._canonicalize_string(value)
+        if key == "document_id" and isinstance(value, str) and value.startswith("file-"):
+            return self._canonicalize_string(value)
+        return self.canonicalize(value)
+
+    def _canonicalize_string(self, value: str) -> str:
+        def replace(match: re.Match[str]) -> str:
+            token = match.group(0)
+            if token.startswith("file-"):
+                return self._mapped_value("file", token)
+            if token.startswith("vs_"):
+                return self._mapped_value("vector_store", token)
+            if token.startswith("batch_"):
+                return self._mapped_value("vector_store_file_batch", token)
+            return token
+
+        return self.PATTERN.sub(replace, value)
+
+    def _mapped_value(self, kind: str, original: str) -> str:
+        mapping = self._mappings[kind]
+        if original not in mapping:
+            self._counters[kind] += 1
+            mapping[original] = f"{_ID_KIND_PREFIXES[kind]}{self._counters[kind]}"
+        return mapping[original]
+
+
+def _canonicalize_for_hashing(obj: Any) -> Any:
+    canonicalizer = _IdCanonicalizer()
+    return canonicalizer.canonicalize(obj)
 
 
 def _chunk_text_content(chunk: Any) -> tuple[str | None, bool]:
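A small illustration of the hashing-side half of the change (the IDs below are invented): _canonicalize_for_hashing maps every file, vector-store, and batch token in a request body to a stable counter-based placeholder before the request is hashed, so the hash no longer depends on which random UUIDs a particular run produced.

# Illustrative input only; the IDs are made up. Assumes this runs where
# _canonicalize_for_hashing (added above) is defined.
body = {
    "document_id": "file-9f2c4ab07e5d4e6c8a1b2c3d4e5f6a7b",
    "vector_store_id": "vs_6b1f0d2e-3c4a-4b5c-8d9e-0f1a2b3c4d5e",
    "messages": ["please summarize file-9f2c4ab07e5d4e6c8a1b2c3d4e5f6a7b"],
}

print(_canonicalize_for_hashing(body))
# {'document_id': 'file-1',
#  'vector_store_id': 'vs_1',
#  'messages': ['please summarize file-1']}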
@@ -205,7 +183,7 @@ def _ends_with_partial_identifier(text: str) -> bool:
    else:
        core = token

-    suffix = core[len("file-"):]
+    suffix = core[len("file-") :]
    if len(suffix) < 16:
        return True
    if not re.fullmatch(r"[A-Za-z0-9_-]+", suffix):
@@ -260,67 +238,11 @@ def _coalesce_streaming_chunks(chunks: list[Any]) -> list[Any]:
     return result
 
 
-def _denormalize_response(obj: Any, recorded_mapping: dict[str, str], current_mapping: dict[str, str]) -> Any:
-    """Replace recorded IDs with current runtime IDs in a response object.
-
-    Args:
-        obj: The response object to denormalize
-        recorded_mapping: normalized_id -> recorded_original_id (from recording file)
-        current_mapping: normalized_id -> current_original_id (from current runtime)
-
-    Returns:
-        The response with all recorded IDs replaced with current IDs
-    """
-    import re
-
-    # Build reverse mapping: recorded_original_id -> current_original_id
-    # via the normalized intermediary
-    id_translation = {}
-    for normalized_id, recorded_id in recorded_mapping.items():
-        if normalized_id in current_mapping:
-            current_id = current_mapping[normalized_id]
-            id_translation[recorded_id] = current_id
-
-    if isinstance(obj, dict):
-        result = {}
-        for k, v in obj.items():
-            # Translate document_id fields
-            if k == "document_id" and isinstance(v, str):
-                result[k] = id_translation.get(v, v)
-            # Translate vector database/store IDs
-            elif k in ("vector_db_id", "vector_store_id", "bank_id") and isinstance(v, str):
-                # Replace any recorded IDs in the value
-                translated = v
-                for recorded_id, current_id in id_translation.items():
-                    translated = translated.replace(recorded_id, current_id)
-                result[k] = translated
-            else:
-                result[k] = _denormalize_response(v, recorded_mapping, current_mapping)
-        return result
-    elif isinstance(obj, list):
-        return [_denormalize_response(item, recorded_mapping, current_mapping) for item in obj]
-    elif hasattr(obj, "model_dump"):
-        # Handle Pydantic/BaseModel instances by denormalizing their dict form
-        data = obj.model_dump(mode="python")
-        denormalized = _denormalize_response(data, recorded_mapping, current_mapping)
-
-        cls = obj.__class__
-        try:
-            return cls.model_validate(denormalized)
-        except Exception:
-            try:
-                return cls.model_construct(**denormalized)
-            except Exception:
-                return denormalized
-    elif isinstance(obj, str):
-        # Replace file-<uuid> patterns in strings (like in text content and citations)
-        translated = obj
-        for recorded_id, current_id in id_translation.items():
-            # Handle both bare file IDs and citation format <|file-...|>
-            translated = translated.replace(recorded_id, current_id)
-        return translated
-    else:
-        return obj
+def _deterministic_id_override(kind: str, factory: Callable[[], str]) -> str:
+    deterministic_id = _allocate_test_scoped_id(kind)
+    if deterministic_id is not None:
+        return deterministic_id
+    return factory()
 
 
 def normalize_inference_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
@@ -337,13 +259,10 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
 
     parsed = urlparse(url)
 
-    # Normalize file IDs in the body to ensure consistent hashing across test runs
-    normalized_body = _normalize_file_ids(body)
-
     normalized: dict[str, Any] = {
         "method": method.upper(),
         "endpoint": parsed.path,
-        "body": normalized_body,
+        "body": _canonicalize_for_hashing(body),
     }
 
     # Include test_id for isolation, except for shared infrastructure endpoints
@@ -357,10 +276,11 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
 
 def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
     """Create a normalized hash of the tool request for consistent matching."""
-    # Normalize file IDs and vector store IDs in kwargs
-    normalized_kwargs = _normalize_file_ids(kwargs)
-
-    normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": normalized_kwargs}
+    normalized = {
+        "provider": provider_name,
+        "tool_name": tool_name,
+        "kwargs": _canonicalize_for_hashing(kwargs),
+    }
 
     # Create hash - sort_keys=True ensures deterministic ordering
     normalized_json = json.dumps(normalized, sort_keys=True)
@@ -514,9 +434,6 @@ def _normalize_response(data: dict[str, Any], request_hash: str) -> dict[str, An
     if "eval_duration" in data and data["eval_duration"] is not None:
         data["eval_duration"] = 0
 
-    # Normalize file IDs and vector store IDs to ensure consistent hashing across replays
-    data = _normalize_file_ids(data)
-
     return data
 
 
@@ -565,6 +482,8 @@ class ResponseStorage:
     def __init__(self, base_dir: Path):
         self.base_dir = base_dir
         # Don't create responses_dir here - determine it per-test at runtime
+        self._legacy_index: dict[str, Path] = {}
+        self._scanned_dirs: set[Path] = set()
 
     def _get_test_dir(self) -> Path:
         """Get the recordings directory in the test file's parent directory.
@@ -627,7 +546,7 @@ class ResponseStorage:
                     "test_id": _test_context.get(),
                     "request": request,
                     "response": serialized_response,
-                    "id_normalization_mapping": _get_current_test_id_mappings(),
+                    "id_normalization_mapping": {},
                 },
                 f,
                 indent=2,
@@ -635,6 +554,8 @@ class ResponseStorage:
             f.write("\n")
             f.flush()
 
+        self._legacy_index[request_hash] = response_path
+
     def find_recording(self, request_hash: str) -> dict[str, Any] | None:
         """Find a recorded response by request hash.
 
@@ -658,6 +579,52 @@ class ResponseStorage:
         if fallback_path.exists():
             return _recording_from_file(fallback_path)
 
+        return self._find_in_legacy_index(request_hash, [test_dir, fallback_dir])
+
+    def _find_in_legacy_index(self, request_hash: str, directories: list[Path]) -> dict[str, Any] | None:
+        for directory in directories:
+            if not directory.exists() or directory in self._scanned_dirs:
+                continue
+
+            for path in directory.glob("*.json"):
+                try:
+                    with open(path) as f:
+                        data = json.load(f)
+                except Exception:
+                    continue
+
+                request = data.get("request")
+                if not request:
+                    continue
+
+                body = request.get("body")
+                canonical_body = _canonicalize_for_hashing(body) if isinstance(body, dict | list) else body
+
+                token = None
+                test_id = data.get("test_id")
+                if test_id:
+                    token = _test_context.set(test_id)
+
+                try:
+                    legacy_hash = normalize_inference_request(
+                        request.get("method", ""),
+                        request.get("url", ""),
+                        request.get("headers", {}),
+                        canonical_body,
+                    )
+                finally:
+                    if token is not None:
+                        _test_context.reset(token)
+
+                if legacy_hash not in self._legacy_index:
+                    self._legacy_index[legacy_hash] = path
+
+            self._scanned_dirs.add(directory)
+
+        legacy_path = self._legacy_index.get(request_hash)
+        if legacy_path and legacy_path.exists():
+            return _recording_from_file(legacy_path)
+
         return None
 
     def _model_list_responses(self, request_hash: str) -> list[dict[str, Any]]:
@@ -685,6 +652,14 @@ def _recording_from_file(response_path) -> dict[str, Any]:
     with open(response_path) as f:
         data = json.load(f)
 
+    mapping = data.get("id_normalization_mapping") or {}
+    if mapping:
+        serialized = json.dumps(data)
+        for normalized, original in mapping.items():
+            serialized = serialized.replace(original, normalized)
+        data = json.loads(serialized)
+        data["id_normalization_mapping"] = {}
+
     # Deserialize response body if needed
     if "response" in data and "body" in data["response"]:
         if isinstance(data["response"]["body"], list):
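For clarity, a standalone sketch of that fold-back (the recording content and mapping below are invented): when a legacy recording still carries an id_normalization_mapping, loading it rewrites the recorded IDs back to their normalized form and then clears the mapping, so recordings written before this commit keep replaying.

import json

# Hypothetical legacy recording (shape only; IDs and mapping are made up).
data = {
    "request": {"body": {"document_id": "file-3c0ffee5deadbeef0123456789abcdef"}},
    "id_normalization_mapping": {"file-1": "file-3c0ffee5deadbeef0123456789abcdef"},
}

mapping = data.get("id_normalization_mapping") or {}
serialized = json.dumps(data)
for normalized, original in mapping.items():
    serialized = serialized.replace(original, normalized)
data = json.loads(serialized)
data["id_normalization_mapping"] = {}

print(data["request"]["body"]["document_id"])  # file-1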
@@ -759,7 +734,7 @@ async def _patched_tool_invoke_method(
     original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
 ):
     """Patched version of tool runtime invoke_tool method for recording/replay."""
-    global _current_mode, _current_storage, _id_normalizers
+    global _current_mode, _current_storage
 
     if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
         # Normal operation
@@ -808,7 +783,7 @@ async def _patched_tool_invoke_method(
 
 
 async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
-    global _current_mode, _current_storage, _id_normalizers
+    global _current_mode, _current_storage
 
     mode = _current_mode
     storage = _current_storage
@@ -864,45 +839,6 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         if recording["response"].get("is_streaming", False) and isinstance(response_body, list):
             response_body = _coalesce_streaming_chunks(response_body)
-            recording["response"]["body"] = response_body
-
-        # Denormalize the response: replace recorded IDs with current runtime IDs
-        recorded_mapping = recording.get("id_normalization_mapping", {})
-        current_mapping = _get_current_test_id_mappings()
-
-        if recorded_mapping or current_mapping:
-            import sys
-            print(f"\n=== DENORM DEBUG ===", file=sys.stderr)
-            print(f"Recorded mapping: {recorded_mapping}", file=sys.stderr)
-            print(f"Current mapping: {current_mapping}", file=sys.stderr)
-            print(f"_id_normalizers before: {_id_normalizers.get(_test_context.get(), {})}", file=sys.stderr)
-
-        response_body = _denormalize_response(response_body, recorded_mapping, current_mapping)
-
-        # Update _id_normalizers to register the current IDs with their normalized values
-        # This ensures that if these IDs appear in future requests, they get the same
-        # normalized value as they had in the recording
-        test_id = _test_context.get()
-        if test_id and recorded_mapping:
-            for normalized_id, recorded_id in recorded_mapping.items():
-                if normalized_id in current_mapping:
-                    current_id = current_mapping[normalized_id]
-                    # Extract ID type from normalized_id (e.g., "file-1" -> "file")
-                    id_type = normalized_id.rsplit("-", 1)[0]
-
-                    # Ensure structures exist
-                    if test_id not in _id_normalizers:
-                        _id_normalizers[test_id] = {}
-                    if id_type not in _id_normalizers[test_id]:
-                        _id_normalizers[test_id][id_type] = {}
-
-                    # Register: current_id -> normalized_id
-                    _id_normalizers[test_id][id_type][current_id] = normalized_id
-                    print(f"Registered {current_id} -> {normalized_id}", file=sys.stderr)
-
-            print(f"_id_normalizers after: {_id_normalizers.get(_test_context.get(), {})}", file=sys.stderr)
-            print(f"=== END DENORM DEBUG ===\n", file=sys.stderr)
 
         if recording["response"].get("is_streaming", False):
 
             async def replay_stream():
@@ -1125,6 +1061,7 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
     # Store previous state
     prev_mode = _current_mode
     prev_storage = _current_storage
+    override_token = None
 
     try:
         _current_mode = mode
@@ -1133,7 +1070,9 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
             if storage_dir is None:
                 raise ValueError("storage_dir is required for record, replay, and record-if-missing modes")
             _current_storage = ResponseStorage(Path(storage_dir))
+            _id_counters.clear()
             patch_inference_clients()
+            override_token = set_id_override(_deterministic_id_override)
 
         yield
 
@@ -1141,6 +1080,8 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
         # Restore previous state
         if mode in ["record", "replay", "record-if-missing"]:
             unpatch_inference_clients()
+            if override_token is not None:
+                reset_id_override(override_token)
 
         _current_mode = prev_mode
         _current_storage = prev_storage