mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 23:49:28 +00:00
score normalization
This commit is contained in:
parent
e0b32b8bdf
commit
aeb91f67ac
24 changed files with 34629 additions and 3 deletions
|
|
@ -9,6 +9,7 @@ from __future__ import annotations # for forward references
|
|||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from enum import StrEnum
|
||||
|
|
@ -61,6 +62,35 @@ _ID_KIND_PREFIXES: dict[str, str] = {
|
|||
}
|
||||
|
||||
|
||||
_FLOAT_IN_STRING_PATTERN = re.compile(r"(-?\d+\.\d{4,})")
|
||||
|
||||
|
||||
def _normalize_numeric_literal_strings(value: str) -> str:
|
||||
"""Round any long decimal literals embedded in strings for stable hashing."""
|
||||
|
||||
def _replace(match: re.Match[str]) -> str:
|
||||
number = float(match.group(0))
|
||||
return f"{number:.6f}"
|
||||
|
||||
return _FLOAT_IN_STRING_PATTERN.sub(_replace, value)
|
||||
|
||||
|
||||
def _normalize_body_for_hash(value: Any) -> Any:
|
||||
"""Recursively normalize a JSON-like value to improve hash stability."""
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {key: _normalize_body_for_hash(item) for key, item in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_normalize_body_for_hash(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_normalize_body_for_hash(item) for item in value)
|
||||
if isinstance(value, float):
|
||||
return round(value, 6)
|
||||
if isinstance(value, str):
|
||||
return _normalize_numeric_literal_strings(value)
|
||||
return value
|
||||
|
||||
|
||||
def _allocate_test_scoped_id(kind: str) -> str | None:
|
||||
"""Return the next deterministic ID for the given kind within the current test."""
|
||||
|
||||
|
|
@ -108,22 +138,24 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
|
|||
Exception: Model list endpoints (/v1/models, /api/tags) exclude test_id since
|
||||
they are infrastructure/shared and need to work across session setup and tests.
|
||||
"""
|
||||
|
||||
# Extract just the endpoint path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(url)
|
||||
|
||||
body_for_hash = _normalize_body_for_hash(body)
|
||||
|
||||
normalized: dict[str, Any] = {
|
||||
"method": method.upper(),
|
||||
"endpoint": parsed.path,
|
||||
"body": body,
|
||||
"body": body_for_hash,
|
||||
}
|
||||
|
||||
# Include test_id for isolation, except for shared infrastructure endpoints
|
||||
if parsed.path not in ("/api/tags", "/v1/models"):
|
||||
normalized["test_id"] = get_test_context()
|
||||
|
||||
# Create hash - sort_keys=True ensures deterministic ordering
|
||||
normalized_json = json.dumps(normalized, sort_keys=True)
|
||||
return hashlib.sha256(normalized_json.encode()).hexdigest()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue