Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-16 19:59:26 +00:00)

Commit 00fd27be1f (parent b47bf340db): "a lot of simplification finally. all works"

39 changed files with 16027 additions and 1969 deletions

@@ -9,7 +9,6 @@ from __future__ import annotations  # for forward references
import hashlib
import json
import os
import re
from collections.abc import Callable, Generator
from contextlib import contextmanager
from enum import StrEnum

@@ -35,12 +34,10 @@ _original_methods: dict[str, Any] = {}
_id_counters: dict[str, dict[str, int]] = {}

# Test context uses ContextVar since it changes per-test and needs async isolation
from contextvars import ContextVar

_test_context: ContextVar[str | None] = ContextVar("_test_context", default=None)

from openai.types.completion_choice import CompletionChoice

from llama_stack.core.testing_context import get_test_context

# update the "finish_reason" field, since its type definition is wrong (no None is accepted)
CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None
CompletionChoice.model_rebuild()
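
Editor's note: the two lines above work around the upstream type stub by widening the finish_reason annotation to also accept None and then rebuilding the pydantic model so the new schema takes effect. A toy version of the same pattern on a throwaway model (force=True is my own addition to make the rebuild unconditional; the upstream code relies on the plain model_rebuild() call):

    from typing import Literal

    from pydantic import BaseModel

    class Choice(BaseModel):
        finish_reason: Literal["stop", "length"]

    # Widen the annotation to also accept None, then regenerate the validation schema.
    Choice.model_fields["finish_reason"].annotation = Literal["stop", "length"] | None
    Choice.model_rebuild(force=True)

    print(Choice(finish_reason=None))  # validates now that None is allowed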

@@ -60,6 +57,7 @@ _ID_KIND_PREFIXES: dict[str, str] = {
"file": "file-",
"vector_store": "vs_",
"vector_store_file_batch": "batch_",
"tool_call": "call_",
}

@@ -68,176 +66,32 @@ def _allocate_test_scoped_id(kind: str) -> str | None:

global _id_counters

test_id = _test_context.get()
test_id = get_test_context()
prefix = _ID_KIND_PREFIXES.get(kind)

if prefix is None:
return None

key = test_id or "__global__"
if not test_id:
raise ValueError(f"Test ID is required for {kind} ID allocation")

key = test_id
if key not in _id_counters:
_id_counters[key] = {}

counter = _id_counters[key].get(kind, 0) + 1
# each test should get a contiguous block of IDs otherwise we will get
# collisions between tests inside other systems (like file storage) which
# expect IDs to be unique
test_hash = hashlib.sha256(test_id.encode()).hexdigest()
test_hash_int = int(test_hash, 16)
counter = test_hash_int % 1000000000000

counter = _id_counters[key].get(kind, counter) + 1
_id_counters[key][kind] = counter

return f"{prefix}{counter}"
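
Editor's note: the new logic above seeds each test's counter from a SHA-256 hash of the test ID, so every test draws IDs from its own contiguous range and cannot collide with another test's IDs in shared systems such as file storage. A standalone sketch of that allocation scheme (the names here are illustrative, not the module's real API):

    import hashlib

    _PREFIXES = {"file": "file-", "vector_store": "vs_"}
    _counters: dict[str, dict[str, int]] = {}

    def allocate_id(test_id: str, kind: str) -> str:
        # Seed the counter from the test ID so every test gets its own ID range.
        seed = int(hashlib.sha256(test_id.encode()).hexdigest(), 16) % 1_000_000_000_000
        per_test = _counters.setdefault(test_id, {})
        counter = per_test.get(kind, seed) + 1
        per_test[kind] = counter
        return f"{_PREFIXES[kind]}{counter}"

    # Two calls in the same test are consecutive; a different test lands elsewhere.
    print(allocate_id("tests/a.py::test_x", "file"))
    print(allocate_id("tests/a.py::test_x", "file"))
    print(allocate_id("tests/b.py::test_y", "file"))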

class _IdCanonicalizer:
PATTERN = re.compile(r"(file-[A-Za-z0-9_-]+|vs_[A-Za-z0-9_-]+|batch_[A-Za-z0-9_-]+)")

def __init__(self) -> None:
self._mappings: dict[str, dict[str, str]] = {kind: {} for kind in _ID_KIND_PREFIXES}
self._counters: dict[str, int] = dict.fromkeys(_ID_KIND_PREFIXES, 0)

def canonicalize(self, obj: Any) -> Any:
if isinstance(obj, dict):
return {k: self._canonicalize_value(k, v) for k, v in obj.items()}
if isinstance(obj, list):
return [self.canonicalize(item) for item in obj]
if isinstance(obj, str):
return self._canonicalize_string(obj)
return obj

def _canonicalize_value(self, key: str, value: Any) -> Any:
if key in {"vector_db_id", "vector_store_id", "bank_id"} and isinstance(value, str):
return self._canonicalize_string(value)
if key == "document_id" and isinstance(value, str) and value.startswith("file-"):
return self._canonicalize_string(value)
return self.canonicalize(value)

def _canonicalize_string(self, value: str) -> str:
def replace(match: re.Match[str]) -> str:
token = match.group(0)
if token.startswith("file-"):
return self._mapped_value("file", token)
if token.startswith("vs_"):
return self._mapped_value("vector_store", token)
if token.startswith("batch_"):
return self._mapped_value("vector_store_file_batch", token)
return token

return self.PATTERN.sub(replace, value)

def _mapped_value(self, kind: str, original: str) -> str:
mapping = self._mappings[kind]
if original not in mapping:
self._counters[kind] += 1
mapping[original] = f"{_ID_KIND_PREFIXES[kind]}{self._counters[kind]}"
return mapping[original]


def _canonicalize_for_hashing(obj: Any) -> Any:
canonicalizer = _IdCanonicalizer()
return canonicalizer.canonicalize(obj)
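
Editor's note: _IdCanonicalizer is part of the machinery this commit deletes. It rewrote volatile server-generated identifiers (file-*, vs_*, batch_*) into stable placeholders before hashing, so the same logical request always produced the same lookup key across runs. A self-contained sketch of that regex-substitution idea (an illustrative helper, not the removed class):

    import re
    from collections import defaultdict

    _PATTERN = re.compile(r"(file-[A-Za-z0-9_-]+|vs_[A-Za-z0-9_-]+)")

    def make_canonicalizer():
        mapping: dict[str, str] = {}
        counters: dict[str, int] = defaultdict(int)

        def canonicalize(text: str) -> str:
            def _sub(match: re.Match[str]) -> str:
                token = match.group(0)
                if token not in mapping:
                    # First occurrence of an ID gets the next placeholder for its kind.
                    kind = "file-" if token.startswith("file-") else "vs_"
                    counters[kind] += 1
                    mapping[token] = f"{kind}{counters[kind]}"
                return mapping[token]
            return _PATTERN.sub(_sub, text)

        return canonicalize

    canon = make_canonicalizer()
    print(canon("uploaded file-a8Xk3 to vs_9QzT"))  # uploaded file-1 to vs_1
    print(canon("read file-a8Xk3 again"))           # read file-1 again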

def _chunk_text_content(chunk: Any) -> tuple[str | None, bool]:
"""Return (content, has_structured_fields) for OpenAI chat completion chunks."""
choices = getattr(chunk, "choices", None)
if not choices:
return None, False

delta = choices[0].delta
content = getattr(delta, "content", None)
if not content:
return None, False

has_structured = bool(getattr(delta, "tool_calls", None) or getattr(delta, "function_call", None))
return content, has_structured


def _chunk_with_content(chunk: Any, content: str) -> Any:
"""Return a copy of the chunk with delta.content replaced by the provided string."""
choices = getattr(chunk, "choices", None)
if not choices:
return chunk

updated_choices = []
for choice in choices:
delta = choice.delta
if getattr(delta, "content", None) is not None:
new_delta = delta.model_copy(update={"content": content})
updated_choices.append(choice.model_copy(update={"delta": new_delta}))
else:
updated_choices.append(choice)

return chunk.model_copy(update={"choices": updated_choices})


def _ends_with_partial_identifier(text: str) -> bool:
"""Return True if text ends in an incomplete file identifier."""
match = re.search(r"(?:<\|)?file-[A-Za-z0-9_-]*$", text)
if not match:
return False

token = match.group()
enclosed = token.startswith("<|")
if enclosed and not token.endswith("|>"):
return True

if enclosed:
core = token[2:-2] if token.endswith("|>") else token[2:]
else:
core = token

suffix = core[len("file-") :]
if len(suffix) < 16:
return True
if not re.fullmatch(r"[A-Za-z0-9_-]+", suffix):
return True

return False


def _has_safe_boundary(text: str) -> bool:
if not text:
return False

last_char = text[-1]
if last_char.isspace():
return True

return last_char in ".,?!;:)]}>\"'"


def _coalesce_streaming_chunks(chunks: list[Any]) -> list[Any]:
"""Merge adjacent text chunks to avoid breaking identifiers across boundaries."""
result: list[Any] = []
pending_chunk: Any | None = None
pending_content = ""

for chunk in chunks:
content, has_structured = _chunk_text_content(chunk)

if content is None or has_structured:
if pending_chunk is not None:
result.append(_chunk_with_content(pending_chunk, pending_content))
pending_chunk = None
pending_content = ""

result.append(chunk)
continue

if pending_chunk is None:
pending_chunk = chunk
pending_content = content
else:
pending_content += content

if (not _ends_with_partial_identifier(pending_content)) and _has_safe_boundary(pending_content):
result.append(_chunk_with_content(pending_chunk, pending_content))
pending_chunk = None
pending_content = ""

if pending_chunk is not None:
result.append(_chunk_with_content(pending_chunk, pending_content))

return result
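
Editor's note: these helpers exist so that replayed streams never split a file identifier such as file-abc123 across two chunks: adjacent text deltas are buffered until the accumulated text does not end mid-identifier and ends on a safe boundary. A toy illustration of that buffering loop on plain strings (simplified; the real code operates on OpenAI chunk objects):

    import re

    def ends_with_partial_identifier(text: str) -> bool:
        # An identifier is "partial" if the trailing file-... token is shorter than 16 chars.
        match = re.search(r"file-[A-Za-z0-9_-]*$", text)
        return bool(match) and len(match.group()) - len("file-") < 16

    def has_safe_boundary(text: str) -> bool:
        return bool(text) and (text[-1].isspace() or text[-1] in ".,?!;:)]}>\"'")

    def coalesce(fragments: list[str]) -> list[str]:
        result: list[str] = []
        pending = ""
        for fragment in fragments:
            pending += fragment
            if not ends_with_partial_identifier(pending) and has_safe_boundary(pending):
                result.append(pending)
                pending = ""
        if pending:
            result.append(pending)
        return result

    # "file-a" alone would be a broken identifier, so it stays buffered until complete.
    print(coalesce(["see file-a", "bcdefgh1234567890", " for details. ", "done. "]))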


def _deterministic_id_override(kind: str, factory: Callable[[], str]) -> str:
deterministic_id = _allocate_test_scoped_id(kind)
if deterministic_id is not None:

@@ -262,12 +116,12 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
normalized: dict[str, Any] = {
"method": method.upper(),
"endpoint": parsed.path,
"body": _canonicalize_for_hashing(body),
"body": body,
}

# Include test_id for isolation, except for shared infrastructure endpoints
if parsed.path not in ("/api/tags", "/v1/models"):
normalized["test_id"] = _test_context.get()
normalized["test_id"] = get_test_context()

# Create hash - sort_keys=True ensures deterministic ordering
normalized_json = json.dumps(normalized, sort_keys=True)
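
Editor's note: sorting the keys before hashing is what makes the lookup key stable: two requests with the same method, endpoint, body, and test ID always serialize to the same JSON string and therefore the same SHA-256 digest. A compact sketch of that normalization (illustrative names, not the module's exact signature):

    import hashlib
    import json

    def request_hash(method: str, endpoint: str, body: dict, test_id: str | None) -> str:
        normalized = {"method": method.upper(), "endpoint": endpoint, "body": body}
        # Shared infrastructure endpoints (model listings) are not scoped to a test.
        if endpoint not in ("/api/tags", "/v1/models"):
            normalized["test_id"] = test_id
        # sort_keys=True gives a deterministic serialization, hence a stable digest.
        return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()

    h1 = request_hash("post", "/v1/chat/completions", {"model": "m", "stream": False}, "tests/a.py::t")
    h2 = request_hash("POST", "/v1/chat/completions", {"stream": False, "model": "m"}, "tests/a.py::t")
    assert h1 == h2  # key order and method case do not change the hash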

@@ -279,7 +133,7 @@ def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str,
normalized = {
"provider": provider_name,
"tool_name": tool_name,
"kwargs": _canonicalize_for_hashing(kwargs),
"kwargs": kwargs,
}

# Create hash - sort_keys=True ensures deterministic ordering

@@ -287,33 +141,6 @@ def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str,
return hashlib.sha256(normalized_json.encode()).hexdigest()


def _sync_test_context_from_provider_data():
"""In server mode, sync test ID from provider_data to _test_context.

This ensures that storage operations (which read from _test_context) work correctly
in server mode where the test ID arrives via HTTP header → provider_data.

Returns a token to reset _test_context, or None if no sync was needed.
"""
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")

if stack_config_type != "server":
return None

try:
from llama_stack.core.request_headers import PROVIDER_DATA_VAR

provider_data = PROVIDER_DATA_VAR.get()

if provider_data and "__test_id" in provider_data:
test_id = provider_data["__test_id"]
return _test_context.set(test_id)
except ImportError:
pass

return None
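
Editor's note: the helper being deleted here leaned on the standard ContextVar token protocol: set() returns a token that the caller later passes to reset(), restoring whatever value was in place before, which keeps concurrently running requests isolated. A minimal sketch of that pattern (illustrative, not the deleted helper itself):

    from contextvars import ContextVar

    _test_context: ContextVar[str | None] = ContextVar("_test_context", default=None)

    def run_with_test_id(test_id: str | None):
        # set() returns a token; reset(token) restores the previous value, even if it was None.
        token = _test_context.set(test_id) if test_id else None
        try:
            return _test_context.get()
        finally:
            if token is not None:
                _test_context.reset(token)

    assert run_with_test_id("tests/a.py::test_x") == "tests/a.py::test_x"
    assert _test_context.get() is None  # previous value restored after reset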


def patch_httpx_for_test_id():
"""Patch client _prepare_request methods to inject test ID into provider data header.

@@ -335,13 +162,12 @@ def patch_httpx_for_test_id():
def patched_prepare_request(self, request):
# Call original first (it's a sync method that returns None)
# Determine which original to call based on client type
if "llama_stack_client" in self.__class__.__module__:
_original_methods["llama_stack_client_prepare_request"](self, request)
_original_methods["openai_prepare_request"](self, request)
_original_methods["llama_stack_client_prepare_request"](self, request)
_original_methods["openai_prepare_request"](self, request)

# Only inject test ID in server mode
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
test_id = _test_context.get()
test_id = get_test_context()

if stack_config_type == "server" and test_id:
provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
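
Editor's note: in server mode the test ID has to travel with the HTTP request, so the patched _prepare_request merges a __test_id key into the JSON carried by the X-LlamaStack-Provider-Data header. A rough sketch of that merge step on a plain header dict (a hypothetical helper, not the patched method itself):

    import json

    def inject_test_id(headers: dict[str, str], test_id: str) -> dict[str, str]:
        # Preserve any provider data already present in the header, then add the test ID.
        existing = headers.get("X-LlamaStack-Provider-Data")
        provider_data = json.loads(existing) if existing else {}
        provider_data["__test_id"] = test_id
        headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
        return headers

    headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "..."})}
    print(inject_test_id(headers, "tests/a.py::test_x")["X-LlamaStack-Provider-Data"])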

@@ -482,8 +308,6 @@ class ResponseStorage:
def __init__(self, base_dir: Path):
self.base_dir = base_dir
# Don't create responses_dir here - determine it per-test at runtime
self._legacy_index: dict[str, Path] = {}
self._scanned_dirs: set[Path] = set()

def _get_test_dir(self) -> Path:
"""Get the recordings directory in the test file's parent directory.

@@ -491,7 +315,7 @@ class ResponseStorage:
For test at "tests/integration/inference/test_foo.py::test_bar",
returns "tests/integration/inference/recordings/".
"""
test_id = _test_context.get()
test_id = get_test_context()
if test_id:
# Extract the directory path from the test nodeid
# e.g., "tests/integration/inference/test_basic.py::test_foo[params]"

@@ -506,7 +330,7 @@ class ResponseStorage:
# Fallback for non-test contexts
return self.base_dir / "recordings"
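
Editor's note: recordings now live next to the tests that produced them: the directory is derived from the pytest nodeid by taking the test file's parent directory and appending recordings/. A quick sketch of that mapping (hypothetical helper name):

    from pathlib import Path

    def recordings_dir_for(test_id: str | None, base_dir: Path) -> Path:
        if test_id:
            # e.g. "tests/integration/inference/test_basic.py::test_foo[params]"
            test_file = test_id.split("::")[0]
            return Path(test_file).parent / "recordings"
        # Fallback for non-test contexts
        return base_dir / "recordings"

    print(recordings_dir_for("tests/integration/inference/test_basic.py::test_foo[params]", Path(".")))
    # tests/integration/inference/recordings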

def _ensure_directories(self):
def _ensure_directory(self):
"""Ensure test-specific directories exist."""
test_dir = self._get_test_dir()
test_dir.mkdir(parents=True, exist_ok=True)

@@ -514,7 +338,7 @@ class ResponseStorage:

def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
"""Store a request/response pair."""
responses_dir = self._ensure_directories()
responses_dir = self._ensure_directory()

# Use FULL hash (not truncated)
response_file = f"{request_hash}.json"

@@ -543,7 +367,7 @@ class ResponseStorage:
with open(response_path, "w") as f:
json.dump(
{
"test_id": _test_context.get(),
"test_id": get_test_context(),
"request": request,
"response": serialized_response,
"id_normalization_mapping": {},

@@ -554,8 +378,6 @@ class ResponseStorage:
f.write("\n")
f.flush()

self._legacy_index[request_hash] = response_path

def find_recording(self, request_hash: str) -> dict[str, Any] | None:
"""Find a recorded response by request hash.

@@ -579,52 +401,6 @@ class ResponseStorage:
if fallback_path.exists():
return _recording_from_file(fallback_path)

return self._find_in_legacy_index(request_hash, [test_dir, fallback_dir])

def _find_in_legacy_index(self, request_hash: str, directories: list[Path]) -> dict[str, Any] | None:
for directory in directories:
if not directory.exists() or directory in self._scanned_dirs:
continue

for path in directory.glob("*.json"):
try:
with open(path) as f:
data = json.load(f)
except Exception:
continue

request = data.get("request")
if not request:
continue

body = request.get("body")
canonical_body = _canonicalize_for_hashing(body) if isinstance(body, dict | list) else body

token = None
test_id = data.get("test_id")
if test_id:
token = _test_context.set(test_id)

try:
legacy_hash = normalize_inference_request(
request.get("method", ""),
request.get("url", ""),
request.get("headers", {}),
canonical_body,
)
finally:
if token is not None:
_test_context.reset(token)

if legacy_hash not in self._legacy_index:
self._legacy_index[legacy_hash] = path

self._scanned_dirs.add(directory)

legacy_path = self._legacy_index.get(request_hash)
if legacy_path and legacy_path.exists():
return _recording_from_file(legacy_path)

return None
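
Editor's note: the legacy-index path being removed here rebuilt a hash-to-file map on demand: every JSON recording in a directory was reopened, its stored request re-normalized under that recording's own test context, and the resulting hash remembered so old recordings stayed findable after the hashing scheme changed. A stripped-down sketch of that lazy scan (hypothetical names, with a plain dict hash standing in for normalize_inference_request):

    import hashlib
    import json
    from pathlib import Path

    def rehash(request: dict) -> str:
        # Stand-in for the real request normalization: hash the stored request deterministically.
        return hashlib.sha256(json.dumps(request, sort_keys=True).encode()).hexdigest()

    def build_legacy_index(directory: Path) -> dict[str, Path]:
        index: dict[str, Path] = {}
        for path in directory.glob("*.json"):
            try:
                data = json.loads(path.read_text())
            except Exception:
                continue  # skip unreadable or non-recording files
            request = data.get("request")
            if request:
                index.setdefault(rehash(request), path)
        return index

    # index = build_legacy_index(Path("tests/integration/inference/recordings"))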

def _model_list_responses(self, request_hash: str) -> list[dict[str, Any]]:

@@ -740,46 +516,38 @@ async def _patched_tool_invoke_method(
# Normal operation
return await original_method(self, tool_name, kwargs)

# In server mode, sync test ID from provider_data to _test_context for storage operations
test_context_token = _sync_test_context_from_provider_data()
request_hash = normalize_tool_request(provider_name, tool_name, kwargs)

try:
request_hash = normalize_tool_request(provider_name, tool_name, kwargs)
if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING):
recording = _current_storage.find_recording(request_hash)
if recording:
return recording["response"]["body"]
elif _current_mode == APIRecordingMode.REPLAY:
raise RuntimeError(
f"No recorded tool result found for {provider_name}.{tool_name}\n"
f"Request: {kwargs}\n"
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)
# If RECORD_IF_MISSING and no recording found, fall through to record

if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING):
recording = _current_storage.find_recording(request_hash)
if recording:
return recording["response"]["body"]
elif _current_mode == APIRecordingMode.REPLAY:
raise RuntimeError(
f"No recorded tool result found for {provider_name}.{tool_name}\n"
f"Request: {kwargs}\n"
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)
# If RECORD_IF_MISSING and no recording found, fall through to record
if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
# Make the tool call and record it
result = await original_method(self, tool_name, kwargs)

if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
# Make the tool call and record it
result = await original_method(self, tool_name, kwargs)
request_data = {
"test_id": get_test_context(),
"provider": provider_name,
"tool_name": tool_name,
"kwargs": kwargs,
}
response_data = {"body": result, "is_streaming": False}

request_data = {
"test_id": _test_context.get(),
"provider": provider_name,
"tool_name": tool_name,
"kwargs": kwargs,
}
response_data = {"body": result, "is_streaming": False}
# Store the recording
_current_storage.store_recording(request_hash, request_data, response_data)
return result

# Store the recording
_current_storage.store_recording(request_hash, request_data, response_data)
return result

else:
raise AssertionError(f"Invalid mode: {_current_mode}")
finally:
# Reset test context if we set it in server mode
if test_context_token is not None:
_test_context.reset(test_context_token)
else:
raise AssertionError(f"Invalid mode: {_current_mode}")


async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):

@@ -794,120 +562,108 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
else:
return await original_method(self, *args, **kwargs)

# In server mode, sync test ID from provider_data to _test_context for storage operations
test_context_token = _sync_test_context_from_provider_data()
# Get base URL based on client type
if client_type == "openai":
base_url = str(self._client.base_url)

try:
# Get base URL based on client type
if client_type == "openai":
base_url = str(self._client.base_url)
# the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
elif client_type == "ollama":
# Get base URL from the client (Ollama client uses host attribute)
base_url = getattr(self, "host", "http://localhost:11434")
if not base_url.startswith("http"):
base_url = f"http://{base_url}"
else:
raise ValueError(f"Unknown client type: {client_type}")

# the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
elif client_type == "ollama":
# Get base URL from the client (Ollama client uses host attribute)
base_url = getattr(self, "host", "http://localhost:11434")
if not base_url.startswith("http"):
base_url = f"http://{base_url}"
url = base_url.rstrip("/") + endpoint
# Special handling for Databricks URLs to avoid leaking workspace info
# e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com
if "cloud.databricks.com" in url:
url = "__databricks__" + url.split("cloud.databricks.com")[-1]
method = "POST"
headers = {}
body = kwargs

request_hash = normalize_inference_request(method, url, headers, body)

# Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
recording = None
if mode == APIRecordingMode.REPLAY or mode == APIRecordingMode.RECORD_IF_MISSING:
# Special handling for model-list endpoints: merge all recordings with this hash
if endpoint in ("/api/tags", "/v1/models"):
records = storage._model_list_responses(request_hash)
recording = _combine_model_list_responses(endpoint, records)
else:
raise ValueError(f"Unknown client type: {client_type}")
recording = storage.find_recording(request_hash)

url = base_url.rstrip("/") + endpoint
# Special handling for Databricks URLs to avoid leaking workspace info
# e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com
if "cloud.databricks.com" in url:
url = "__databricks__" + url.split("cloud.databricks.com")[-1]
method = "POST"
headers = {}
body = kwargs
if recording:
response_body = recording["response"]["body"]

request_hash = normalize_inference_request(method, url, headers, body)
if recording["response"].get("is_streaming", False):

# Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
recording = None
if mode == APIRecordingMode.REPLAY or mode == APIRecordingMode.RECORD_IF_MISSING:
# Special handling for model-list endpoints: merge all recordings with this hash
if endpoint in ("/api/tags", "/v1/models"):
records = storage._model_list_responses(request_hash)
recording = _combine_model_list_responses(endpoint, records)
else:
recording = storage.find_recording(request_hash)

if recording:
response_body = recording["response"]["body"]

if recording["response"].get("is_streaming", False) and isinstance(response_body, list):
response_body = _coalesce_streaming_chunks(response_body)

if recording["response"].get("is_streaming", False):

async def replay_stream():
for chunk in response_body:
yield chunk

return replay_stream()
else:
return response_body
elif mode == APIRecordingMode.REPLAY:
# REPLAY mode requires recording to exist
raise RuntimeError(
f"No recorded response found for request hash: {request_hash}\n"
f"Request: {method} {url} {body}\n"
f"Model: {body.get('model', 'unknown')}\n"
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)

if mode == APIRecordingMode.RECORD or (mode == APIRecordingMode.RECORD_IF_MISSING and not recording):
if endpoint == "/v1/models":
response = original_method(self, *args, **kwargs)
else:
response = await original_method(self, *args, **kwargs)

# we want to store the result of the iterator, not the iterator itself
if endpoint == "/v1/models":
response = [m async for m in response]

request_data = {
"method": method,
"url": url,
"headers": headers,
"body": body,
"endpoint": endpoint,
"model": body.get("model", ""),
}

# Determine if this is a streaming request based on request parameters
is_streaming = body.get("stream", False)

if is_streaming:
# For streaming responses, we need to collect all chunks immediately before yielding
# This ensures the recording is saved even if the generator isn't fully consumed
raw_chunks: list[Any] = []
async for chunk in response:
raw_chunks.append(chunk)

chunks = _coalesce_streaming_chunks(raw_chunks)

# Store the recording immediately
response_data = {"body": chunks, "is_streaming": True}
storage.store_recording(request_hash, request_data, response_data)

# Return a generator that replays the stored chunks
async def replay_recorded_stream():
for chunk in chunks:
async def replay_stream():
for chunk in response_body:
yield chunk

return replay_recorded_stream()
return replay_stream()
else:
response_data = {"body": response, "is_streaming": False}
storage.store_recording(request_hash, request_data, response_data)
return response
return response_body
elif mode == APIRecordingMode.REPLAY:
# REPLAY mode requires recording to exist
raise RuntimeError(
f"No recorded response found for request hash: {request_hash}\n"
f"Request: {method} {url} {body}\n"
f"Model: {body.get('model', 'unknown')}\n"
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)

if mode == APIRecordingMode.RECORD or (mode == APIRecordingMode.RECORD_IF_MISSING and not recording):
if endpoint == "/v1/models":
response = original_method(self, *args, **kwargs)
else:
raise AssertionError(f"Invalid mode: {mode}")
finally:
if test_context_token:
_test_context.reset(test_context_token)
response = await original_method(self, *args, **kwargs)

# we want to store the result of the iterator, not the iterator itself
if endpoint == "/v1/models":
response = [m async for m in response]

request_data = {
"method": method,
"url": url,
"headers": headers,
"body": body,
"endpoint": endpoint,
"model": body.get("model", ""),
}

# Determine if this is a streaming request based on request parameters
is_streaming = body.get("stream", False)

if is_streaming:
# For streaming responses, we need to collect all chunks immediately before yielding
# This ensures the recording is saved even if the generator isn't fully consumed
chunks: list[Any] = []
async for chunk in response:
chunks.append(chunk)

# Store the recording immediately
response_data = {"body": chunks, "is_streaming": True}
storage.store_recording(request_hash, request_data, response_data)

# Return a generator that replays the stored chunks
async def replay_recorded_stream():
for chunk in chunks:
yield chunk

return replay_recorded_stream()
else:
response_data = {"body": response, "is_streaming": False}
storage.store_recording(request_hash, request_data, response_data)
return response

else:
raise AssertionError(f"Invalid mode: {mode}")
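
Editor's note: both patched methods follow the same mode dispatch: in REPLAY (or RECORD_IF_MISSING with a hit) return the stored body, in REPLAY with a miss raise, and in RECORD drain any streaming response into a list before storing it, so the recording is written even if the caller never finishes the stream. A condensed sketch of that flow (illustrative names; find_recording/store_recording stand in for the storage object above):

    from collections.abc import AsyncIterator
    from typing import Any

    async def record_or_replay(mode: str, request_hash: str, call, storage) -> Any:
        if mode in ("replay", "record-if-missing"):
            recording = storage.find_recording(request_hash)
            if recording:
                body = recording["response"]["body"]
                if recording["response"].get("is_streaming"):
                    async def replay_stream() -> AsyncIterator[Any]:
                        for chunk in body:
                            yield chunk
                    return replay_stream()
                return body
            if mode == "replay":
                raise RuntimeError(f"No recorded response for {request_hash}")

        # record (or record-if-missing with no hit): call through and persist the result
        response = await call()
        if hasattr(response, "__aiter__"):
            chunks = [chunk async for chunk in response]  # materialize before storing
            storage.store_recording(request_hash, {}, {"body": chunks, "is_streaming": True})
            async def replay_recorded() -> AsyncIterator[Any]:
                for chunk in chunks:
                    yield chunk
            return replay_recorded()
        storage.store_recording(request_hash, {}, {"body": response, "is_streaming": False})
        return response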


def patch_inference_clients():