A lot of simplification, finally. All works.

Ashwin Bharambe 2025-10-09 11:15:00 -07:00
parent b47bf340db
commit 00fd27be1f
39 changed files with 16027 additions and 1969 deletions


@@ -232,14 +232,25 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
await log_request_pre_validation(request)
test_context_token = None
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user):
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
from llama_stack.core.testing_context import (
TEST_CONTEXT,
reset_test_context,
sync_test_context_from_provider_data,
)
test_context_token = sync_test_context_from_provider_data()
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
gen = preserve_contexts_async_generator(
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR, TEST_CONTEXT]
)
return StreamingResponse(gen, media_type="text/event-stream")
else:
@@ -258,6 +269,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
else:
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
raise translate_exception(e) from e
finally:
if test_context_token is not None:
reset_test_context(test_context_token)
sig = inspect.signature(func)
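
Conceptually, adding TEST_CONTEXT to preserve_contexts_async_generator means the test id set while handling the request stays visible when the SSE stream is consumed later, possibly after the original context has been reset. A minimal standalone sketch of that idea (not the real helper; all names below are illustrative):

import asyncio
from contextvars import ContextVar

TEST_CONTEXT: ContextVar[str | None] = ContextVar("llama_stack_test_context", default=None)

async def fake_stream():
    for i in range(3):
        yield f"{TEST_CONTEXT.get()}:{i}"

def preserve(gen, context_vars):
    # Capture the current values now; re-apply them when the generator is pulled.
    captured = [(var, var.get()) for var in context_vars]

    async def wrapper():
        for var, value in captured:
            var.set(value)
        async for item in gen:
            yield item

    return wrapper()

async def main():
    token = TEST_CONTEXT.set("test_a")
    stream = preserve(fake_stream(), [TEST_CONTEXT])
    TEST_CONTEXT.reset(token)          # the handler's context is gone by now...
    async for item in stream:
        print(item)                    # ...yet each chunk still sees "test_a"

asyncio.run(main())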


@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from contextvars import ContextVar
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
TEST_CONTEXT: ContextVar[str | None] = ContextVar("llama_stack_test_context", default=None)
def get_test_context() -> str | None:
return TEST_CONTEXT.get()
def set_test_context(value: str | None):
return TEST_CONTEXT.set(value)
def reset_test_context(token) -> None:
TEST_CONTEXT.reset(token)
def sync_test_context_from_provider_data():
"""Sync test context from provider data when running in server test mode."""
if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
return None
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
if stack_config_type != "server":
return None
try:
provider_data = PROVIDER_DATA_VAR.get()
except LookupError:
provider_data = None
if provider_data and "__test_id" in provider_data:
return TEST_CONTEXT.set(provider_data["__test_id"])
return None
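
A minimal usage sketch of the token-based set/reset pattern this new module is built around, kept standalone here with a local ContextVar rather than importing llama_stack:

from contextvars import ContextVar

TEST_CONTEXT: ContextVar[str | None] = ContextVar("llama_stack_test_context", default=None)

def run_with_test_id(test_id: str) -> None:
    token = TEST_CONTEXT.set(test_id)   # remember the previous value
    try:
        assert TEST_CONTEXT.get() == test_id
    finally:
        TEST_CONTEXT.reset(token)       # restore it even if the body raises

run_with_test_id("tests/integration/inference/test_basic.py::test_foo")
assert TEST_CONTEXT.get() is None       # back to the default after reset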


@@ -9,7 +9,6 @@ from __future__ import annotations # for forward references
import hashlib
import json
import os
import re
from collections.abc import Callable, Generator
from contextlib import contextmanager
from enum import StrEnum
@@ -35,12 +34,10 @@ _original_methods: dict[str, Any] = {}
_id_counters: dict[str, dict[str, int]] = {}
# Test context uses ContextVar since it changes per-test and needs async isolation
from contextvars import ContextVar
_test_context: ContextVar[str | None] = ContextVar("_test_context", default=None)
from openai.types.completion_choice import CompletionChoice
from llama_stack.core.testing_context import get_test_context
# update the "finish_reason" field, since its type definition is wrong (no None is accepted)
CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None
CompletionChoice.model_rebuild()
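
The two lines above widen an upstream pydantic model in place. A standalone sketch of the same trick on a toy model, assuming pydantic v2 (force=True is used here only because the toy model's schema is already complete):

from typing import Literal

from pydantic import BaseModel

class Choice(BaseModel):
    finish_reason: Literal["stop", "length"]

# Widen the annotation, then rebuild the core schema so validation uses it.
Choice.model_fields["finish_reason"].annotation = Literal["stop", "length"] | None
Choice.model_rebuild(force=True)

print(Choice(finish_reason=None))  # now validates instead of raising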
@@ -60,6 +57,7 @@ _ID_KIND_PREFIXES: dict[str, str] = {
"file": "file-",
"vector_store": "vs_",
"vector_store_file_batch": "batch_",
"tool_call": "call_",
}
@@ -68,176 +66,32 @@ def _allocate_test_scoped_id(kind: str) -> str | None:
global _id_counters
test_id = _test_context.get()
test_id = get_test_context()
prefix = _ID_KIND_PREFIXES.get(kind)
if prefix is None:
return None
key = test_id or "__global__"
if not test_id:
raise ValueError(f"Test ID is required for {kind} ID allocation")
key = test_id
if key not in _id_counters:
_id_counters[key] = {}
counter = _id_counters[key].get(kind, 0) + 1
# each test should get a contiguous block of IDs otherwise we will get
# collisions between tests inside other systems (like file storage) which
# expect IDs to be unique
test_hash = hashlib.sha256(test_id.encode()).hexdigest()
test_hash_int = int(test_hash, 16)
counter = test_hash_int % 1000000000000
counter = _id_counters[key].get(kind, counter) + 1
_id_counters[key][kind] = counter
return f"{prefix}{counter}"
class _IdCanonicalizer:
PATTERN = re.compile(r"(file-[A-Za-z0-9_-]+|vs_[A-Za-z0-9_-]+|batch_[A-Za-z0-9_-]+)")
def __init__(self) -> None:
self._mappings: dict[str, dict[str, str]] = {kind: {} for kind in _ID_KIND_PREFIXES}
self._counters: dict[str, int] = dict.fromkeys(_ID_KIND_PREFIXES, 0)
def canonicalize(self, obj: Any) -> Any:
if isinstance(obj, dict):
return {k: self._canonicalize_value(k, v) for k, v in obj.items()}
if isinstance(obj, list):
return [self.canonicalize(item) for item in obj]
if isinstance(obj, str):
return self._canonicalize_string(obj)
return obj
def _canonicalize_value(self, key: str, value: Any) -> Any:
if key in {"vector_db_id", "vector_store_id", "bank_id"} and isinstance(value, str):
return self._canonicalize_string(value)
if key == "document_id" and isinstance(value, str) and value.startswith("file-"):
return self._canonicalize_string(value)
return self.canonicalize(value)
def _canonicalize_string(self, value: str) -> str:
def replace(match: re.Match[str]) -> str:
token = match.group(0)
if token.startswith("file-"):
return self._mapped_value("file", token)
if token.startswith("vs_"):
return self._mapped_value("vector_store", token)
if token.startswith("batch_"):
return self._mapped_value("vector_store_file_batch", token)
return token
return self.PATTERN.sub(replace, value)
def _mapped_value(self, kind: str, original: str) -> str:
mapping = self._mappings[kind]
if original not in mapping:
self._counters[kind] += 1
mapping[original] = f"{_ID_KIND_PREFIXES[kind]}{self._counters[kind]}"
return mapping[original]
def _canonicalize_for_hashing(obj: Any) -> Any:
canonicalizer = _IdCanonicalizer()
return canonicalizer.canonicalize(obj)
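
For reference, the effect the now-removed canonicalizer had on a request body, assuming the _IdCanonicalizer class above is in scope (the IDs are made up):

canon = _IdCanonicalizer()
body = {
    "vector_store_id": "vs_9f8e7d6c",
    "file_ids": ["file-aaa111", "file-bbb222", "file-aaa111"],
}
print(canon.canonicalize(body))
# {'vector_store_id': 'vs_1', 'file_ids': ['file-1', 'file-2', 'file-1']}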
def _chunk_text_content(chunk: Any) -> tuple[str | None, bool]:
"""Return (content, has_structured_fields) for OpenAI chat completion chunks."""
choices = getattr(chunk, "choices", None)
if not choices:
return None, False
delta = choices[0].delta
content = getattr(delta, "content", None)
if not content:
return None, False
has_structured = bool(getattr(delta, "tool_calls", None) or getattr(delta, "function_call", None))
return content, has_structured
def _chunk_with_content(chunk: Any, content: str) -> Any:
"""Return a copy of the chunk with delta.content replaced by the provided string."""
choices = getattr(chunk, "choices", None)
if not choices:
return chunk
updated_choices = []
for choice in choices:
delta = choice.delta
if getattr(delta, "content", None) is not None:
new_delta = delta.model_copy(update={"content": content})
updated_choices.append(choice.model_copy(update={"delta": new_delta}))
else:
updated_choices.append(choice)
return chunk.model_copy(update={"choices": updated_choices})
def _ends_with_partial_identifier(text: str) -> bool:
"""Return True if text ends in an incomplete file identifier."""
match = re.search(r"(?:<\|)?file-[A-Za-z0-9_-]*$", text)
if not match:
return False
token = match.group()
enclosed = token.startswith("<|")
if enclosed and not token.endswith("|>"):
return True
if enclosed:
core = token[2:-2] if token.endswith("|>") else token[2:]
else:
core = token
suffix = core[len("file-") :]
if len(suffix) < 16:
return True
if not re.fullmatch(r"[A-Za-z0-9_-]+", suffix):
return True
return False
def _has_safe_boundary(text: str) -> bool:
if not text:
return False
last_char = text[-1]
if last_char.isspace():
return True
return last_char in ".,?!;:)]}>\"'"
def _coalesce_streaming_chunks(chunks: list[Any]) -> list[Any]:
"""Merge adjacent text chunks to avoid breaking identifiers across boundaries."""
result: list[Any] = []
pending_chunk: Any | None = None
pending_content = ""
for chunk in chunks:
content, has_structured = _chunk_text_content(chunk)
if content is None or has_structured:
if pending_chunk is not None:
result.append(_chunk_with_content(pending_chunk, pending_content))
pending_chunk = None
pending_content = ""
result.append(chunk)
continue
if pending_chunk is None:
pending_chunk = chunk
pending_content = content
else:
pending_content += content
if (not _ends_with_partial_identifier(pending_content)) and _has_safe_boundary(pending_content):
result.append(_chunk_with_content(pending_chunk, pending_content))
pending_chunk = None
pending_content = ""
if pending_chunk is not None:
result.append(_chunk_with_content(pending_chunk, pending_content))
return result
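
The motivation for the removed coalescing step, shown on plain strings standing in for streamed deltas (the identifier is made up): a per-chunk regex sees only a truncated id, while the joined text yields the full one.

import re

PATTERN = re.compile(r"file-[A-Za-z0-9_-]+")
chunks = ["The answer is in <|file-9f8", "e7d6c0a1b2c3d4|> attached above."]

print([PATTERN.findall(c) for c in chunks])  # [['file-9f8'], []]  -> truncated match, then a miss
print(PATTERN.findall("".join(chunks)))      # ['file-9f8e7d6c0a1b2c3d4']  -> full identifier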
def _deterministic_id_override(kind: str, factory: Callable[[], str]) -> str:
deterministic_id = _allocate_test_scoped_id(kind)
if deterministic_id is not None:
@@ -262,12 +116,12 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
normalized: dict[str, Any] = {
"method": method.upper(),
"endpoint": parsed.path,
"body": _canonicalize_for_hashing(body),
"body": body,
}
# Include test_id for isolation, except for shared infrastructure endpoints
if parsed.path not in ("/api/tags", "/v1/models"):
normalized["test_id"] = _test_context.get()
normalized["test_id"] = get_test_context()
# Create hash - sort_keys=True ensures deterministic ordering
normalized_json = json.dumps(normalized, sort_keys=True)
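
Condensed, standalone version of the hashing above (the model name and test id are illustrative): the request is reduced to method, endpoint, body, and, for non-shared endpoints, the test id, then hashed over a sorted JSON dump so the digest is stable across runs.

import hashlib
import json

def request_hash(method: str, endpoint: str, body: dict, test_id: str | None) -> str:
    normalized = {"method": method.upper(), "endpoint": endpoint, "body": body}
    if endpoint not in ("/api/tags", "/v1/models"):  # model lists are shared across tests
        normalized["test_id"] = test_id
    return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()

print(request_hash(
    "post",
    "/v1/chat/completions",
    {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "hi"}]},
    "tests/integration/inference/test_basic.py::test_foo",
))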
@@ -279,7 +133,7 @@ def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str,
normalized = {
"provider": provider_name,
"tool_name": tool_name,
"kwargs": _canonicalize_for_hashing(kwargs),
"kwargs": kwargs,
}
# Create hash - sort_keys=True ensures deterministic ordering
@@ -287,33 +141,6 @@ def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str,
return hashlib.sha256(normalized_json.encode()).hexdigest()
def _sync_test_context_from_provider_data():
"""In server mode, sync test ID from provider_data to _test_context.
This ensures that storage operations (which read from _test_context) work correctly
in server mode where the test ID arrives via HTTP header provider_data.
Returns a token to reset _test_context, or None if no sync was needed.
"""
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
if stack_config_type != "server":
return None
try:
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
provider_data = PROVIDER_DATA_VAR.get()
if provider_data and "__test_id" in provider_data:
test_id = provider_data["__test_id"]
return _test_context.set(test_id)
except ImportError:
pass
return None
def patch_httpx_for_test_id():
"""Patch client _prepare_request methods to inject test ID into provider data header.
@@ -335,13 +162,12 @@ def patch_httpx_for_test_id():
def patched_prepare_request(self, request):
# Call original first (it's a sync method that returns None)
# Determine which original to call based on client type
if "llama_stack_client" in self.__class__.__module__:
_original_methods["llama_stack_client_prepare_request"](self, request)
_original_methods["openai_prepare_request"](self, request)
_original_methods["llama_stack_client_prepare_request"](self, request)
_original_methods["openai_prepare_request"](self, request)
# Only inject test ID in server mode
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
test_id = _test_context.get()
test_id = get_test_context()
if stack_config_type == "server" and test_id:
provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
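
The rest of the injection is truncated in this hunk; a hedged sketch of what merging the test id into the provider-data header could look like (the real patch may merge differently):

import json

def inject_test_id(headers: dict[str, str], test_id: str) -> None:
    existing = headers.get("X-LlamaStack-Provider-Data")
    provider_data = json.loads(existing) if existing else {}
    provider_data["__test_id"] = test_id  # same key the server-side sync reads back
    headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)

headers = {"X-LlamaStack-Provider-Data": json.dumps({"api_key": "dummy"})}
inject_test_id(headers, "tests/integration/inference/test_basic.py::test_foo")
print(headers["X-LlamaStack-Provider-Data"])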
@@ -482,8 +308,6 @@ class ResponseStorage:
def __init__(self, base_dir: Path):
self.base_dir = base_dir
# Don't create responses_dir here - determine it per-test at runtime
self._legacy_index: dict[str, Path] = {}
self._scanned_dirs: set[Path] = set()
def _get_test_dir(self) -> Path:
"""Get the recordings directory in the test file's parent directory.
@@ -491,7 +315,7 @@ class ResponseStorage:
For test at "tests/integration/inference/test_foo.py::test_bar",
returns "tests/integration/inference/recordings/".
"""
test_id = _test_context.get()
test_id = get_test_context()
if test_id:
# Extract the directory path from the test nodeid
# e.g., "tests/integration/inference/test_basic.py::test_foo[params]"
@@ -506,7 +330,7 @@ class ResponseStorage:
# Fallback for non-test contexts
return self.base_dir / "recordings"
def _ensure_directories(self):
def _ensure_directory(self):
"""Ensure test-specific directories exist."""
test_dir = self._get_test_dir()
test_dir.mkdir(parents=True, exist_ok=True)
@@ -514,7 +338,7 @@ class ResponseStorage:
def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
"""Store a request/response pair."""
responses_dir = self._ensure_directories()
responses_dir = self._ensure_directory()
# Use FULL hash (not truncated)
response_file = f"{request_hash}.json"
@@ -543,7 +367,7 @@ class ResponseStorage:
with open(response_path, "w") as f:
json.dump(
{
"test_id": _test_context.get(),
"test_id": get_test_context(),
"request": request,
"response": serialized_response,
"id_normalization_mapping": {},
@@ -554,8 +378,6 @@ class ResponseStorage:
f.write("\n")
f.flush()
self._legacy_index[request_hash] = response_path
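
For orientation, a recording file written by store_recording has roughly this shape (all values below are invented; the real request and response bodies come from the serialized client objects):

import json

record = {
    "test_id": "tests/integration/inference/test_basic.py::test_foo",
    "request": {
        "method": "POST",
        "url": "http://localhost:11434/v1/chat/completions",
        "headers": {},
        "body": {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "hi"}]},
        "endpoint": "/v1/chat/completions",
        "model": "llama3.2:3b",
    },
    "response": {"body": {"choices": [{"message": {"content": "Hello!"}}]}, "is_streaming": False},
    "id_normalization_mapping": {},
}
print(json.dumps(record, indent=2))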
def find_recording(self, request_hash: str) -> dict[str, Any] | None:
"""Find a recorded response by request hash.
@@ -579,52 +401,6 @@ class ResponseStorage:
if fallback_path.exists():
return _recording_from_file(fallback_path)
return self._find_in_legacy_index(request_hash, [test_dir, fallback_dir])
def _find_in_legacy_index(self, request_hash: str, directories: list[Path]) -> dict[str, Any] | None:
for directory in directories:
if not directory.exists() or directory in self._scanned_dirs:
continue
for path in directory.glob("*.json"):
try:
with open(path) as f:
data = json.load(f)
except Exception:
continue
request = data.get("request")
if not request:
continue
body = request.get("body")
canonical_body = _canonicalize_for_hashing(body) if isinstance(body, dict | list) else body
token = None
test_id = data.get("test_id")
if test_id:
token = _test_context.set(test_id)
try:
legacy_hash = normalize_inference_request(
request.get("method", ""),
request.get("url", ""),
request.get("headers", {}),
canonical_body,
)
finally:
if token is not None:
_test_context.reset(token)
if legacy_hash not in self._legacy_index:
self._legacy_index[legacy_hash] = path
self._scanned_dirs.add(directory)
legacy_path = self._legacy_index.get(request_hash)
if legacy_path and legacy_path.exists():
return _recording_from_file(legacy_path)
return None
def _model_list_responses(self, request_hash: str) -> list[dict[str, Any]]:
@@ -740,46 +516,38 @@ async def _patched_tool_invoke_method(
# Normal operation
return await original_method(self, tool_name, kwargs)
# In server mode, sync test ID from provider_data to _test_context for storage operations
test_context_token = _sync_test_context_from_provider_data()
request_hash = normalize_tool_request(provider_name, tool_name, kwargs)
try:
request_hash = normalize_tool_request(provider_name, tool_name, kwargs)
if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING):
recording = _current_storage.find_recording(request_hash)
if recording:
return recording["response"]["body"]
elif _current_mode == APIRecordingMode.REPLAY:
raise RuntimeError(
f"No recorded tool result found for {provider_name}.{tool_name}\n"
f"Request: {kwargs}\n"
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)
# If RECORD_IF_MISSING and no recording found, fall through to record
if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING):
recording = _current_storage.find_recording(request_hash)
if recording:
return recording["response"]["body"]
elif _current_mode == APIRecordingMode.REPLAY:
raise RuntimeError(
f"No recorded tool result found for {provider_name}.{tool_name}\n"
f"Request: {kwargs}\n"
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)
# If RECORD_IF_MISSING and no recording found, fall through to record
if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
# Make the tool call and record it
result = await original_method(self, tool_name, kwargs)
if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
# Make the tool call and record it
result = await original_method(self, tool_name, kwargs)
request_data = {
"test_id": get_test_context(),
"provider": provider_name,
"tool_name": tool_name,
"kwargs": kwargs,
}
response_data = {"body": result, "is_streaming": False}
request_data = {
"test_id": _test_context.get(),
"provider": provider_name,
"tool_name": tool_name,
"kwargs": kwargs,
}
response_data = {"body": result, "is_streaming": False}
# Store the recording
_current_storage.store_recording(request_hash, request_data, response_data)
return result
# Store the recording
_current_storage.store_recording(request_hash, request_data, response_data)
return result
else:
raise AssertionError(f"Invalid mode: {_current_mode}")
finally:
# Reset test context if we set it in server mode
if test_context_token is not None:
_test_context.reset(test_context_token)
else:
raise AssertionError(f"Invalid mode: {_current_mode}")
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
@@ -794,120 +562,108 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
else:
return await original_method(self, *args, **kwargs)
# In server mode, sync test ID from provider_data to _test_context for storage operations
test_context_token = _sync_test_context_from_provider_data()
# Get base URL based on client type
if client_type == "openai":
base_url = str(self._client.base_url)
try:
# Get base URL based on client type
if client_type == "openai":
base_url = str(self._client.base_url)
# the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
elif client_type == "ollama":
# Get base URL from the client (Ollama client uses host attribute)
base_url = getattr(self, "host", "http://localhost:11434")
if not base_url.startswith("http"):
base_url = f"http://{base_url}"
else:
raise ValueError(f"Unknown client type: {client_type}")
# the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
elif client_type == "ollama":
# Get base URL from the client (Ollama client uses host attribute)
base_url = getattr(self, "host", "http://localhost:11434")
if not base_url.startswith("http"):
base_url = f"http://{base_url}"
url = base_url.rstrip("/") + endpoint
# Special handling for Databricks URLs to avoid leaking workspace info
# e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com
if "cloud.databricks.com" in url:
url = "__databricks__" + url.split("cloud.databricks.com")[-1]
method = "POST"
headers = {}
body = kwargs
request_hash = normalize_inference_request(method, url, headers, body)
# Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
recording = None
if mode == APIRecordingMode.REPLAY or mode == APIRecordingMode.RECORD_IF_MISSING:
# Special handling for model-list endpoints: merge all recordings with this hash
if endpoint in ("/api/tags", "/v1/models"):
records = storage._model_list_responses(request_hash)
recording = _combine_model_list_responses(endpoint, records)
else:
raise ValueError(f"Unknown client type: {client_type}")
recording = storage.find_recording(request_hash)
url = base_url.rstrip("/") + endpoint
# Special handling for Databricks URLs to avoid leaking workspace info
# e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com
if "cloud.databricks.com" in url:
url = "__databricks__" + url.split("cloud.databricks.com")[-1]
method = "POST"
headers = {}
body = kwargs
if recording:
response_body = recording["response"]["body"]
request_hash = normalize_inference_request(method, url, headers, body)
if recording["response"].get("is_streaming", False):
# Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
recording = None
if mode == APIRecordingMode.REPLAY or mode == APIRecordingMode.RECORD_IF_MISSING:
# Special handling for model-list endpoints: merge all recordings with this hash
if endpoint in ("/api/tags", "/v1/models"):
records = storage._model_list_responses(request_hash)
recording = _combine_model_list_responses(endpoint, records)
else:
recording = storage.find_recording(request_hash)
if recording:
response_body = recording["response"]["body"]
if recording["response"].get("is_streaming", False) and isinstance(response_body, list):
response_body = _coalesce_streaming_chunks(response_body)
if recording["response"].get("is_streaming", False):
async def replay_stream():
for chunk in response_body:
yield chunk
return replay_stream()
else:
return response_body
elif mode == APIRecordingMode.REPLAY:
# REPLAY mode requires recording to exist
raise RuntimeError(
f"No recorded response found for request hash: {request_hash}\n"
f"Request: {method} {url} {body}\n"
f"Model: {body.get('model', 'unknown')}\n"
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)
if mode == APIRecordingMode.RECORD or (mode == APIRecordingMode.RECORD_IF_MISSING and not recording):
if endpoint == "/v1/models":
response = original_method(self, *args, **kwargs)
else:
response = await original_method(self, *args, **kwargs)
# we want to store the result of the iterator, not the iterator itself
if endpoint == "/v1/models":
response = [m async for m in response]
request_data = {
"method": method,
"url": url,
"headers": headers,
"body": body,
"endpoint": endpoint,
"model": body.get("model", ""),
}
# Determine if this is a streaming request based on request parameters
is_streaming = body.get("stream", False)
if is_streaming:
# For streaming responses, we need to collect all chunks immediately before yielding
# This ensures the recording is saved even if the generator isn't fully consumed
raw_chunks: list[Any] = []
async for chunk in response:
raw_chunks.append(chunk)
chunks = _coalesce_streaming_chunks(raw_chunks)
# Store the recording immediately
response_data = {"body": chunks, "is_streaming": True}
storage.store_recording(request_hash, request_data, response_data)
# Return a generator that replays the stored chunks
async def replay_recorded_stream():
for chunk in chunks:
async def replay_stream():
for chunk in response_body:
yield chunk
return replay_recorded_stream()
return replay_stream()
else:
response_data = {"body": response, "is_streaming": False}
storage.store_recording(request_hash, request_data, response_data)
return response
return response_body
elif mode == APIRecordingMode.REPLAY:
# REPLAY mode requires recording to exist
raise RuntimeError(
f"No recorded response found for request hash: {request_hash}\n"
f"Request: {method} {url} {body}\n"
f"Model: {body.get('model', 'unknown')}\n"
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
)
if mode == APIRecordingMode.RECORD or (mode == APIRecordingMode.RECORD_IF_MISSING and not recording):
if endpoint == "/v1/models":
response = original_method(self, *args, **kwargs)
else:
raise AssertionError(f"Invalid mode: {mode}")
finally:
if test_context_token:
_test_context.reset(test_context_token)
response = await original_method(self, *args, **kwargs)
# we want to store the result of the iterator, not the iterator itself
if endpoint == "/v1/models":
response = [m async for m in response]
request_data = {
"method": method,
"url": url,
"headers": headers,
"body": body,
"endpoint": endpoint,
"model": body.get("model", ""),
}
# Determine if this is a streaming request based on request parameters
is_streaming = body.get("stream", False)
if is_streaming:
# For streaming responses, we need to collect all chunks immediately before yielding
# This ensures the recording is saved even if the generator isn't fully consumed
chunks: list[Any] = []
async for chunk in response:
chunks.append(chunk)
# Store the recording immediately
response_data = {"body": chunks, "is_streaming": True}
storage.store_recording(request_hash, request_data, response_data)
# Return a generator that replays the stored chunks
async def replay_recorded_stream():
for chunk in chunks:
yield chunk
return replay_recorded_stream()
else:
response_data = {"body": response, "is_streaming": False}
storage.store_recording(request_hash, request_data, response_data)
return response
else:
raise AssertionError(f"Invalid mode: {mode}")
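
The streaming branch above drains the upstream generator before returning anything, so the recording exists even if the caller abandons the stream early. A generic standalone sketch of that record-then-replay pattern (names are illustrative):

import asyncio

_recordings: dict[str, list[str]] = {}

async def record_stream(key: str, upstream):
    chunks = [chunk async for chunk in upstream]  # drain eagerly so nothing is lost
    _recordings[key] = chunks

    async def replay():
        for chunk in chunks:
            yield chunk

    return replay()

async def fake_upstream():
    for part in ["Hel", "lo ", "world"]:
        yield part

async def main():
    stream = await record_stream("demo", fake_upstream())
    print(await anext(stream))      # the caller may stop after one chunk...
    print(_recordings["demo"])      # ...but the full recording is already stored

asyncio.run(main())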
def patch_inference_clients():