mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 20:12:33 +00:00
redid api_recorder.py now, step 0
This commit is contained in:
parent
9205731cd6
commit
8414c30859
1 changed files with 241 additions and 368 deletions
|
|
@ -15,20 +15,19 @@ from enum import StrEnum
|
|||
from pathlib import Path
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
from openai import NOT_GIVEN, OpenAI
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="testing")
|
||||
|
||||
# Global state for the API recording system
|
||||
# Global state for the recording system
|
||||
# Note: Using module globals instead of ContextVars because the session-scoped
|
||||
# client initialization happens in one async context, but tests run in different
|
||||
# contexts, and we need the mode/storage to persist across all contexts.
|
||||
_current_mode: str | None = None
|
||||
_current_storage: ResponseStorage | None = None
|
||||
_original_methods: dict[str, Any] = {}
|
||||
_memory_cache: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# Test context uses ContextVar since it changes per-test and needs async isolation
|
||||
from contextvars import ContextVar
|
||||
|
|
@ -52,29 +51,162 @@ class APIRecordingMode(StrEnum):
|
|||
RECORD_IF_MISSING = "record-if-missing"
|
||||
|
||||
|
||||
def _normalize_file_ids(obj: Any) -> Any:
|
||||
"""Recursively replace file IDs with a canonical placeholder for consistent hashing."""
|
||||
import re
|
||||
def normalize_inference_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
|
||||
"""Create a normalized hash of the request for consistent matching.
|
||||
|
||||
if isinstance(obj, dict):
|
||||
result = {}
|
||||
for k, v in obj.items():
|
||||
# Normalize file IDs in attribute dictionaries
|
||||
if k == "document_id" and isinstance(v, str) and v.startswith("file-"):
|
||||
result[k] = "file-NORMALIZED"
|
||||
Includes test_id from context to ensure test isolation - identical requests
|
||||
from different tests will have different hashes.
|
||||
|
||||
Exception: Model list endpoints (/v1/models, /api/tags) exclude test_id since
|
||||
they are infrastructure/shared and need to work across session setup and tests.
|
||||
"""
|
||||
# Extract just the endpoint path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(url)
|
||||
normalized: dict[str, Any] = {
|
||||
"method": method.upper(),
|
||||
"endpoint": parsed.path,
|
||||
"body": body,
|
||||
}
|
||||
|
||||
# Include test_id for isolation, except for shared infrastructure endpoints
|
||||
if parsed.path not in ("/api/tags", "/v1/models"):
|
||||
normalized["test_id"] = _test_context.get()
|
||||
|
||||
# Create hash - sort_keys=True ensures deterministic ordering
|
||||
normalized_json = json.dumps(normalized, sort_keys=True)
|
||||
return hashlib.sha256(normalized_json.encode()).hexdigest()
|
||||
|
||||
|
||||
def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
|
||||
"""Create a normalized hash of the tool request for consistent matching."""
|
||||
normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": kwargs}
|
||||
|
||||
# Create hash - sort_keys=True ensures deterministic ordering
|
||||
normalized_json = json.dumps(normalized, sort_keys=True)
|
||||
return hashlib.sha256(normalized_json.encode()).hexdigest()
|
||||
|
||||
|
||||
def _sync_test_context_from_provider_data():
|
||||
"""In server mode, sync test ID from provider_data to _test_context.
|
||||
|
||||
This ensures that storage operations (which read from _test_context) work correctly
|
||||
in server mode where the test ID arrives via HTTP header → provider_data.
|
||||
|
||||
Returns a token to reset _test_context, or None if no sync was needed.
|
||||
"""
|
||||
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
|
||||
|
||||
if stack_config_type != "server":
|
||||
return None
|
||||
|
||||
try:
|
||||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
|
||||
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
|
||||
if provider_data and "__test_id" in provider_data:
|
||||
test_id = provider_data["__test_id"]
|
||||
return _test_context.set(test_id)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def patch_httpx_for_test_id():
|
||||
"""Patch client _prepare_request methods to inject test ID into provider data header.
|
||||
|
||||
This is needed for server mode where the test ID must be transported from
|
||||
client to server via HTTP headers. In library_client mode, this patch is a no-op
|
||||
since everything runs in the same process.
|
||||
|
||||
We use the _prepare_request hook that Stainless clients provide for mutating
|
||||
requests after construction but before sending.
|
||||
"""
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
if "llama_stack_client_prepare_request" in _original_methods:
|
||||
return
|
||||
|
||||
_original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request
|
||||
_original_methods["openai_prepare_request"] = OpenAI._prepare_request
|
||||
|
||||
def patched_prepare_request(self, request):
|
||||
# Call original first (it's a sync method that returns None)
|
||||
# Determine which original to call based on client type
|
||||
if "llama_stack_client" in self.__class__.__module__:
|
||||
_original_methods["llama_stack_client_prepare_request"](self, request)
|
||||
_original_methods["openai_prepare_request"](self, request)
|
||||
|
||||
# Only inject test ID in server mode
|
||||
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
|
||||
test_id = _test_context.get()
|
||||
|
||||
if stack_config_type == "server" and test_id:
|
||||
provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
|
||||
|
||||
if provider_data_header:
|
||||
provider_data = json.loads(provider_data_header)
|
||||
else:
|
||||
result[k] = _normalize_file_ids(v)
|
||||
return result
|
||||
elif isinstance(obj, list):
|
||||
return [_normalize_file_ids(item) for item in obj]
|
||||
elif isinstance(obj, str):
|
||||
# Replace file-<uuid> patterns in strings (like in text content)
|
||||
return re.sub(r"file-[a-f0-9]{32}", "file-NORMALIZED", obj)
|
||||
else:
|
||||
return obj
|
||||
provider_data = {}
|
||||
|
||||
provider_data["__test_id"] = test_id
|
||||
request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
|
||||
|
||||
return None
|
||||
|
||||
LlamaStackClient._prepare_request = patched_prepare_request
|
||||
OpenAI._prepare_request = patched_prepare_request
|
||||
|
||||
|
||||
def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[str, Any]:
|
||||
# currently, unpatch is never called
|
||||
def unpatch_httpx_for_test_id():
|
||||
"""Remove client _prepare_request patches for test ID injection."""
|
||||
if "llama_stack_client_prepare_request" not in _original_methods:
|
||||
return
|
||||
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"]
|
||||
del _original_methods["llama_stack_client_prepare_request"]
|
||||
OpenAI._prepare_request = _original_methods["openai_prepare_request"]
|
||||
del _original_methods["openai_prepare_request"]
|
||||
|
||||
|
||||
def get_api_recording_mode() -> APIRecordingMode:
|
||||
return APIRecordingMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
|
||||
|
||||
|
||||
def setup_api_recording():
|
||||
"""
|
||||
Returns a context manager that can be used to record or replay API requests (inference and tools).
|
||||
This is to be used in tests to increase their reliability and reduce reliance on expensive, external services.
|
||||
|
||||
Currently supports:
|
||||
- Inference: OpenAI and Ollama clients
|
||||
- Tools: Search providers (Tavily)
|
||||
|
||||
Two environment variables are supported:
|
||||
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'record-if-missing'. Default is 'replay'.
|
||||
- 'live': Make all requests live without recording
|
||||
- 'record': Record all requests (overwrites existing recordings)
|
||||
- 'replay': Use only recorded responses (fails if recording not found)
|
||||
- 'record-if-missing': Use recorded responses when available, record new ones when not found
|
||||
- LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'.
|
||||
|
||||
The recordings are stored as JSON files.
|
||||
"""
|
||||
mode = get_api_recording_mode()
|
||||
if mode == APIRecordingMode.LIVE:
|
||||
return None
|
||||
|
||||
storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR)
|
||||
return api_recording(mode=mode, storage_dir=storage_dir)
|
||||
|
||||
|
||||
def _normalize_response(data: dict[str, Any], request_hash: str) -> dict[str, Any]:
|
||||
"""Normalize fields that change between recordings but don't affect functionality.
|
||||
|
||||
This reduces noise in git diffs by making IDs deterministic and timestamps constant.
|
||||
|
|
@ -106,184 +238,11 @@ def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[st
|
|||
return data
|
||||
|
||||
|
||||
def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
|
||||
"""Create a normalized hash of the request for consistent matching.
|
||||
|
||||
Includes test_id from context to ensure test isolation - identical requests
|
||||
from different tests will have different hashes.
|
||||
"""
|
||||
# Extract just the endpoint path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Normalize file IDs in the body to ensure consistent hashing across test runs
|
||||
normalized_body = _normalize_file_ids(body)
|
||||
|
||||
normalized: dict[str, Any] = {"method": method.upper(), "endpoint": parsed.path, "body": normalized_body}
|
||||
|
||||
# Include test_id for isolation, except for shared infrastructure endpoints
|
||||
if parsed.path not in ("/api/tags", "/v1/models"):
|
||||
# Server mode: test ID was synced from provider_data to _test_context
|
||||
# by _sync_test_context_from_provider_data() at the start of the request.
|
||||
# We read from _test_context because it's available in all contexts (including
|
||||
# when making outgoing API calls), whereas PROVIDER_DATA_VAR is only set
|
||||
# for incoming HTTP requests.
|
||||
#
|
||||
# Library client mode: test ID in same-process ContextVar
|
||||
test_id = _test_context.get()
|
||||
normalized["test_id"] = test_id
|
||||
|
||||
# Create hash - sort_keys=True ensures deterministic ordering
|
||||
normalized_json = json.dumps(normalized, sort_keys=True)
|
||||
request_hash = hashlib.sha256(normalized_json.encode()).hexdigest()
|
||||
|
||||
return request_hash
|
||||
|
||||
|
||||
def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
|
||||
"""Create a normalized hash of the tool request for consistent matching."""
|
||||
normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": kwargs}
|
||||
|
||||
# Create hash - sort_keys=True ensures deterministic ordering
|
||||
normalized_json = json.dumps(normalized, sort_keys=True)
|
||||
return hashlib.sha256(normalized_json.encode()).hexdigest()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_test_context(test_id: str) -> Generator[None, None, None]:
|
||||
"""Set the test context for recording isolation.
|
||||
|
||||
Usage:
|
||||
with set_test_context("test_basic_completion"):
|
||||
# Make API calls that will be recorded with this test_id
|
||||
response = client.chat.completions.create(...)
|
||||
"""
|
||||
token = _test_context.set(test_id)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_test_context.reset(token)
|
||||
|
||||
|
||||
def patch_httpx_for_test_id():
|
||||
"""Patch client _prepare_request methods to inject test ID into provider data header.
|
||||
|
||||
Patches both LlamaStackClient and OpenAI client to ensure test ID is transported
|
||||
from client to server via HTTP headers in server mode.
|
||||
|
||||
This is needed for server mode where the test ID must be transported from
|
||||
client to server via HTTP headers. In library_client mode, this patch is a no-op
|
||||
since everything runs in the same process.
|
||||
|
||||
We use the _prepare_request hook that Stainless clients provide for mutating
|
||||
requests after construction but before sending.
|
||||
"""
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
if "llama_stack_client_prepare_request" in _original_methods:
|
||||
# Already patched
|
||||
return
|
||||
|
||||
# Save original methods
|
||||
_original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request
|
||||
|
||||
# Also patch OpenAI client if available (used in compat tests)
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
_original_methods["openai_prepare_request"] = OpenAI._prepare_request
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
def patched_prepare_request(self, request):
|
||||
# Call original first (it's a sync method that returns None)
|
||||
# Determine which original to call based on client type
|
||||
if "llama_stack_client" in self.__class__.__module__:
|
||||
_original_methods["llama_stack_client_prepare_request"](self, request)
|
||||
elif "openai_prepare_request" in _original_methods:
|
||||
_original_methods["openai_prepare_request"](self, request)
|
||||
|
||||
# Only inject test ID in server mode
|
||||
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
|
||||
test_id = _test_context.get()
|
||||
|
||||
if stack_config_type == "server" and test_id:
|
||||
# Get existing provider data header or create new dict
|
||||
provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
|
||||
|
||||
if provider_data_header:
|
||||
provider_data = json.loads(provider_data_header)
|
||||
else:
|
||||
provider_data = {}
|
||||
|
||||
# Inject test ID
|
||||
provider_data["__test_id"] = test_id
|
||||
request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
|
||||
|
||||
# Sync version returns None
|
||||
return None
|
||||
|
||||
# Apply patches
|
||||
LlamaStackClient._prepare_request = patched_prepare_request
|
||||
if "openai_prepare_request" in _original_methods:
|
||||
from openai import OpenAI
|
||||
|
||||
OpenAI._prepare_request = patched_prepare_request
|
||||
|
||||
|
||||
def unpatch_httpx_for_test_id():
|
||||
"""Remove client _prepare_request patches for test ID injection."""
|
||||
if "llama_stack_client_prepare_request" not in _original_methods:
|
||||
return
|
||||
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"]
|
||||
del _original_methods["llama_stack_client_prepare_request"]
|
||||
|
||||
# Also restore OpenAI client if it was patched
|
||||
if "openai_prepare_request" in _original_methods:
|
||||
from openai import OpenAI
|
||||
|
||||
OpenAI._prepare_request = _original_methods["openai_prepare_request"]
|
||||
del _original_methods["openai_prepare_request"]
|
||||
|
||||
|
||||
def get_api_recording_mode() -> APIRecordingMode:
|
||||
return APIRecordingMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
|
||||
|
||||
|
||||
def setup_api_recording():
|
||||
"""
|
||||
Returns a context manager that can be used to record or replay API requests (inference and tools).
|
||||
This is to be used in tests to increase their reliability and reduce reliance on expensive, external services.
|
||||
|
||||
Currently supports:
|
||||
- Inference: OpenAI, Ollama, and LiteLLM clients
|
||||
- Tools: Search providers (Tavily for now)
|
||||
|
||||
Two environment variables are supported:
|
||||
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Options:
|
||||
- 'live': Make real API calls, no recording
|
||||
- 'record': Record all API interactions (overwrites existing)
|
||||
- 'replay': Use recorded responses only (default)
|
||||
- 'record-if-missing': Replay when possible, record when recording doesn't exist
|
||||
- LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'.
|
||||
|
||||
The recordings are stored as JSON files.
|
||||
"""
|
||||
mode = get_api_recording_mode()
|
||||
if mode == APIRecordingMode.LIVE:
|
||||
return None
|
||||
|
||||
storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR)
|
||||
return api_recording(mode=mode, storage_dir=storage_dir)
|
||||
|
||||
|
||||
def _serialize_response(response: Any) -> Any:
|
||||
def _serialize_response(response: Any, request_hash: str = "") -> Any:
|
||||
if hasattr(response, "model_dump"):
|
||||
data = response.model_dump(mode="json")
|
||||
# Normalize fields to reduce noise
|
||||
data = _normalize_response(data, request_hash)
|
||||
return {
|
||||
"__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}",
|
||||
"__data__": data,
|
||||
|
|
@ -308,17 +267,22 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
|
|||
|
||||
return cls.model_validate(data["__data__"])
|
||||
except (ImportError, AttributeError, TypeError, ValueError) as e:
|
||||
logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}")
|
||||
return data["__data__"]
|
||||
logger.warning(f"Failed to deserialize object of type {data['__type__']} with model_validate: {e}")
|
||||
try:
|
||||
return cls.model_construct(**data["__data__"])
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to deserialize object of type {data['__type__']} with model_construct: {e}")
|
||||
return data["__data__"]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class ResponseStorage:
|
||||
"""Handles storage/retrieval for API recordings (inference and tools)."""
|
||||
"""Handles SQLite index + JSON file storage/retrieval for inference recordings."""
|
||||
|
||||
def __init__(self, base_dir: Path):
|
||||
self.base_dir = base_dir
|
||||
# Don't create responses_dir here - determine it per-test at runtime
|
||||
|
||||
def _get_test_dir(self) -> Path:
|
||||
"""Get the recordings directory in the test file's parent directory.
|
||||
|
|
@ -327,7 +291,6 @@ class ResponseStorage:
|
|||
returns "tests/integration/inference/recordings/".
|
||||
"""
|
||||
test_id = _test_context.get()
|
||||
|
||||
if test_id:
|
||||
# Extract the directory path from the test nodeid
|
||||
# e.g., "tests/integration/inference/test_basic.py::test_foo[params]"
|
||||
|
|
@ -342,21 +305,17 @@ class ResponseStorage:
|
|||
# Fallback for non-test contexts
|
||||
return self.base_dir / "recordings"
|
||||
|
||||
def _ensure_directories(self) -> Path:
|
||||
def _ensure_directories(self):
|
||||
"""Ensure test-specific directories exist."""
|
||||
test_dir = self._get_test_dir()
|
||||
test_dir.mkdir(parents=True, exist_ok=True)
|
||||
return test_dir
|
||||
|
||||
def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
|
||||
"""Store a request/response pair both in memory cache and on disk."""
|
||||
global _memory_cache
|
||||
|
||||
# Store in memory cache first
|
||||
_memory_cache[request_hash] = {"request": request, "response": response}
|
||||
|
||||
"""Store a request/response pair."""
|
||||
responses_dir = self._ensure_directories()
|
||||
|
||||
# Generate unique response filename using full hash
|
||||
# Use FULL hash (not truncated)
|
||||
response_file = f"{request_hash}.json"
|
||||
|
||||
# Serialize response body if needed
|
||||
|
|
@ -364,32 +323,45 @@ class ResponseStorage:
|
|||
if "body" in serialized_response:
|
||||
if isinstance(serialized_response["body"], list):
|
||||
# Handle streaming responses (list of chunks)
|
||||
serialized_response["body"] = [_serialize_response(chunk) for chunk in serialized_response["body"]]
|
||||
serialized_response["body"] = [
|
||||
_serialize_response(chunk, request_hash) for chunk in serialized_response["body"]
|
||||
]
|
||||
else:
|
||||
# Handle single response
|
||||
serialized_response["body"] = _serialize_response(serialized_response["body"])
|
||||
serialized_response["body"] = _serialize_response(serialized_response["body"], request_hash)
|
||||
|
||||
# If this is a model-list endpoint recording, include models digest in filename to distinguish variants
|
||||
# For model-list endpoints, include digest in filename to distinguish different model sets
|
||||
endpoint = request.get("endpoint")
|
||||
test_id = _test_context.get()
|
||||
if endpoint in ("/api/tags", "/v1/models"):
|
||||
test_id = None
|
||||
digest = _model_identifiers_digest(endpoint, response)
|
||||
response_file = f"models-{request_hash}-{digest}.json"
|
||||
|
||||
response_path = responses_dir / response_file
|
||||
|
||||
# Save response to JSON file
|
||||
# Save response to JSON file with metadata
|
||||
with open(response_path, "w") as f:
|
||||
json.dump({"test_id": test_id, "request": request, "response": serialized_response}, f, indent=2)
|
||||
json.dump(
|
||||
{
|
||||
"test_id": _test_context.get(),
|
||||
"request": request,
|
||||
"response": serialized_response,
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
f.write("\n")
|
||||
f.flush()
|
||||
|
||||
def find_recording(self, request_hash: str) -> dict[str, Any] | None:
|
||||
"""Find a recorded response by request hash."""
|
||||
"""Find a recorded response by request hash.
|
||||
|
||||
Uses fallback: first checks test-specific dir, then falls back to base recordings dir.
|
||||
This handles cases where recordings happen during session setup (no test context) but
|
||||
are requested during tests (with test context).
|
||||
"""
|
||||
response_file = f"{request_hash}.json"
|
||||
|
||||
# Check test-specific directory first
|
||||
# Try test-specific directory first
|
||||
test_dir = self._get_test_dir()
|
||||
response_path = test_dir / response_file
|
||||
|
||||
|
|
@ -529,23 +501,18 @@ async def _patched_tool_invoke_method(
|
|||
# If RECORD_IF_MISSING and no recording found, fall through to record
|
||||
|
||||
if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
|
||||
# Check in-memory cache first (collision detection)
|
||||
global _memory_cache
|
||||
if request_hash in _memory_cache:
|
||||
# Return the cached response instead of making a new tool call
|
||||
return _memory_cache[request_hash]["response"]["body"]
|
||||
|
||||
# No cached response, make the tool call and record it
|
||||
# Make the tool call and record it
|
||||
result = await original_method(self, tool_name, kwargs)
|
||||
|
||||
request_data = {
|
||||
"test_id": _test_context.get(),
|
||||
"provider": provider_name,
|
||||
"tool_name": tool_name,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
response_data = {"body": result, "is_streaming": False}
|
||||
|
||||
# Store the recording (both in memory and on disk)
|
||||
# Store the recording
|
||||
_current_storage.store_recording(request_hash, request_data, response_data)
|
||||
return result
|
||||
|
||||
|
|
@ -557,40 +524,15 @@ async def _patched_tool_invoke_method(
|
|||
_test_context.reset(test_context_token)
|
||||
|
||||
|
||||
def _sync_test_context_from_provider_data():
|
||||
"""In server mode, sync test ID from provider_data to _test_context.
|
||||
|
||||
This ensures that storage operations (which read from _test_context) work correctly
|
||||
in server mode where the test ID arrives via HTTP header → provider_data.
|
||||
|
||||
Returns a token to reset _test_context, or None if no sync was needed.
|
||||
"""
|
||||
stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
|
||||
|
||||
if stack_config_type != "server":
|
||||
return None
|
||||
|
||||
try:
|
||||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
|
||||
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
|
||||
if provider_data and "__test_id" in provider_data:
|
||||
test_id = provider_data["__test_id"]
|
||||
return _test_context.set(test_id)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
|
||||
global _current_mode, _current_storage
|
||||
|
||||
if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
|
||||
# Normal operation
|
||||
if client_type == "litellm":
|
||||
return await original_method(*args, **kwargs)
|
||||
mode = _current_mode
|
||||
storage = _current_storage
|
||||
|
||||
if mode == APIRecordingMode.LIVE or storage is None:
|
||||
if endpoint == "/v1/models":
|
||||
return original_method(self, *args, **kwargs)
|
||||
else:
|
||||
return await original_method(self, *args, **kwargs)
|
||||
|
||||
|
|
@ -609,30 +551,34 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
base_url = getattr(self, "host", "http://localhost:11434")
|
||||
if not base_url.startswith("http"):
|
||||
base_url = f"http://{base_url}"
|
||||
elif client_type == "litellm":
|
||||
# For LiteLLM, extract base URL from kwargs if available
|
||||
base_url = kwargs.get("api_base", "https://api.openai.com")
|
||||
else:
|
||||
raise ValueError(f"Unknown client type: {client_type}")
|
||||
|
||||
url = base_url.rstrip("/") + endpoint
|
||||
# Special handling for Databricks URLs to avoid leaking workspace info
|
||||
# e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com
|
||||
if "cloud.databricks.com" in url:
|
||||
url = "__databricks__" + url.split("cloud.databricks.com")[-1]
|
||||
method = "POST"
|
||||
headers = {}
|
||||
body = kwargs
|
||||
|
||||
request_hash = normalize_request(method, url, headers, body)
|
||||
request_hash = normalize_inference_request(method, url, headers, body)
|
||||
|
||||
if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING):
|
||||
# Special handling for model-list endpoints: return union of all responses
|
||||
# Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
|
||||
recording = None
|
||||
if mode == APIRecordingMode.REPLAY or mode == APIRecordingMode.RECORD_IF_MISSING:
|
||||
# Special handling for model-list endpoints: merge all recordings with this hash
|
||||
if endpoint in ("/api/tags", "/v1/models"):
|
||||
records = _current_storage._model_list_responses(request_hash)
|
||||
records = storage._model_list_responses(request_hash)
|
||||
recording = _combine_model_list_responses(endpoint, records)
|
||||
else:
|
||||
recording = _current_storage.find_recording(request_hash)
|
||||
recording = storage.find_recording(request_hash)
|
||||
|
||||
if recording:
|
||||
response_body = recording["response"]["body"]
|
||||
|
||||
if recording["response"].get("is_streaming", False) or recording["response"].get("is_paginated", False):
|
||||
if recording["response"].get("is_streaming", False):
|
||||
|
||||
async def replay_stream():
|
||||
for chunk in response_body:
|
||||
|
|
@ -641,41 +587,25 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
return replay_stream()
|
||||
else:
|
||||
return response_body
|
||||
elif _current_mode == APIRecordingMode.REPLAY:
|
||||
elif mode == APIRecordingMode.REPLAY:
|
||||
# REPLAY mode requires recording to exist
|
||||
raise RuntimeError(
|
||||
f"No recorded response found for request hash: {request_hash}\n"
|
||||
f"Request: {method} {url} {body}\n"
|
||||
f"Model: {body.get('model', 'unknown')}\n"
|
||||
f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
|
||||
)
|
||||
# If RECORD_IF_MISSING and no recording found, fall through to record
|
||||
|
||||
if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
|
||||
# Check in-memory cache first (collision detection)
|
||||
global _memory_cache
|
||||
if request_hash in _memory_cache:
|
||||
# Return the cached response instead of making a new API call
|
||||
cached_recording = _memory_cache[request_hash]
|
||||
response_body = cached_recording["response"]["body"]
|
||||
|
||||
if cached_recording["response"].get("is_streaming", False) or cached_recording["response"].get(
|
||||
"is_paginated", False
|
||||
):
|
||||
|
||||
async def replay_cached_stream():
|
||||
for chunk in response_body:
|
||||
yield chunk
|
||||
|
||||
return replay_cached_stream()
|
||||
else:
|
||||
return response_body
|
||||
|
||||
# No cached response, make the API call and record it
|
||||
if client_type == "litellm":
|
||||
response = await original_method(*args, **kwargs)
|
||||
if mode == APIRecordingMode.RECORD or (mode == APIRecordingMode.RECORD_IF_MISSING and not recording):
|
||||
if endpoint == "/v1/models":
|
||||
response = original_method(self, *args, **kwargs)
|
||||
else:
|
||||
response = await original_method(self, *args, **kwargs)
|
||||
|
||||
# we want to store the result of the iterator, not the iterator itself
|
||||
if endpoint == "/v1/models":
|
||||
response = [m async for m in response]
|
||||
|
||||
request_data = {
|
||||
"method": method,
|
||||
"url": url,
|
||||
|
|
@ -688,20 +618,16 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
# Determine if this is a streaming request based on request parameters
|
||||
is_streaming = body.get("stream", False)
|
||||
|
||||
# Special case: /v1/models is a paginated endpoint that returns an async iterator
|
||||
is_paginated = endpoint == "/v1/models"
|
||||
|
||||
if is_streaming or is_paginated:
|
||||
# For streaming/paginated responses, we need to collect all chunks immediately before yielding
|
||||
if is_streaming:
|
||||
# For streaming responses, we need to collect all chunks immediately before yielding
|
||||
# This ensures the recording is saved even if the generator isn't fully consumed
|
||||
chunks = []
|
||||
async for chunk in response:
|
||||
chunks.append(chunk)
|
||||
|
||||
# Store the recording immediately (both in memory and on disk)
|
||||
# For paginated endpoints, mark as paginated rather than streaming
|
||||
response_data = {"body": chunks, "is_streaming": is_streaming, "is_paginated": is_paginated}
|
||||
_current_storage.store_recording(request_hash, request_data, response_data)
|
||||
# Store the recording immediately
|
||||
response_data = {"body": chunks, "is_streaming": True}
|
||||
storage.store_recording(request_hash, request_data, response_data)
|
||||
|
||||
# Return a generator that replays the stored chunks
|
||||
async def replay_recorded_stream():
|
||||
|
|
@ -711,24 +637,20 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
return replay_recorded_stream()
|
||||
else:
|
||||
response_data = {"body": response, "is_streaming": False}
|
||||
# Store the response (both in memory and on disk)
|
||||
_current_storage.store_recording(request_hash, request_data, response_data)
|
||||
storage.store_recording(request_hash, request_data, response_data)
|
||||
return response
|
||||
|
||||
else:
|
||||
raise AssertionError(f"Invalid mode: {_current_mode}")
|
||||
|
||||
raise AssertionError(f"Invalid mode: {mode}")
|
||||
finally:
|
||||
# Reset test context if we set it in server mode
|
||||
if test_context_token is not None:
|
||||
if test_context_token:
|
||||
_test_context.reset(test_context_token)
|
||||
|
||||
|
||||
def patch_api_clients():
|
||||
"""Install monkey patches for inference clients and tool runtime methods."""
|
||||
def patch_inference_clients():
|
||||
"""Install monkey patches for OpenAI client methods, Ollama AsyncClient methods, and tool runtime methods."""
|
||||
global _original_methods
|
||||
|
||||
import litellm
|
||||
from ollama import AsyncClient as OllamaAsyncClient
|
||||
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
|
||||
from openai.resources.completions import AsyncCompletions
|
||||
|
|
@ -736,9 +658,8 @@ def patch_api_clients():
|
|||
from openai.resources.models import AsyncModels
|
||||
|
||||
from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
# Store original methods for OpenAI, Ollama, LiteLLM clients, and tool runtimes
|
||||
# Store original methods for OpenAI, Ollama clients, and tool runtimes
|
||||
_original_methods = {
|
||||
"chat_completions_create": AsyncChatCompletions.create,
|
||||
"completions_create": AsyncCompletions.create,
|
||||
|
|
@ -750,9 +671,6 @@ def patch_api_clients():
|
|||
"ollama_ps": OllamaAsyncClient.ps,
|
||||
"ollama_pull": OllamaAsyncClient.pull,
|
||||
"ollama_list": OllamaAsyncClient.list,
|
||||
"litellm_acompletion": litellm.acompletion,
|
||||
"litellm_atext_completion": litellm.atext_completion,
|
||||
"litellm_openai_mixin_get_api_key": LiteLLMOpenAIMixin.get_api_key,
|
||||
"tavily_invoke_tool": TavilySearchToolRuntimeImpl.invoke_tool,
|
||||
}
|
||||
|
||||
|
|
@ -774,18 +692,10 @@ def patch_api_clients():
|
|||
|
||||
def patched_models_list(self, *args, **kwargs):
|
||||
async def _iter():
|
||||
result = await _patched_inference_method(
|
||||
for item in await _patched_inference_method(
|
||||
_original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
|
||||
)
|
||||
# The result is either an async generator (streaming/paginated) or a list
|
||||
# If it's an async generator, iterate through it
|
||||
if hasattr(result, "__aiter__"):
|
||||
async for item in result:
|
||||
yield item
|
||||
else:
|
||||
# It's a list, yield each item
|
||||
for item in result:
|
||||
yield item
|
||||
):
|
||||
yield item
|
||||
|
||||
return _iter()
|
||||
|
||||
|
|
@ -834,33 +744,6 @@ def patch_api_clients():
|
|||
OllamaAsyncClient.pull = patched_ollama_pull
|
||||
OllamaAsyncClient.list = patched_ollama_list
|
||||
|
||||
# Create patched methods for LiteLLM
|
||||
async def patched_litellm_acompletion(*args, **kwargs):
|
||||
return await _patched_inference_method(
|
||||
_original_methods["litellm_acompletion"], None, "litellm", "/chat/completions", *args, **kwargs
|
||||
)
|
||||
|
||||
async def patched_litellm_atext_completion(*args, **kwargs):
|
||||
return await _patched_inference_method(
|
||||
_original_methods["litellm_atext_completion"], None, "litellm", "/completions", *args, **kwargs
|
||||
)
|
||||
|
||||
# Apply LiteLLM patches
|
||||
litellm.acompletion = patched_litellm_acompletion
|
||||
litellm.atext_completion = patched_litellm_atext_completion
|
||||
|
||||
# Create patched method for LiteLLMOpenAIMixin.get_api_key
|
||||
def patched_litellm_get_api_key(self):
|
||||
global _current_mode
|
||||
if _current_mode != APIRecordingMode.REPLAY:
|
||||
return _original_methods["litellm_openai_mixin_get_api_key"](self)
|
||||
else:
|
||||
# For record/replay modes, return a fake API key to avoid exposing real credentials
|
||||
return "fake-api-key-for-testing"
|
||||
|
||||
# Apply LiteLLMOpenAIMixin patch
|
||||
LiteLLMOpenAIMixin.get_api_key = patched_litellm_get_api_key
|
||||
|
||||
# Create patched methods for tool runtimes
|
||||
async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any]):
|
||||
return await _patched_tool_invoke_method(
|
||||
|
|
@ -871,15 +754,14 @@ def patch_api_clients():
|
|||
TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool
|
||||
|
||||
|
||||
def unpatch_api_clients():
|
||||
"""Remove monkey patches and restore original client methods and tool runtimes."""
|
||||
global _original_methods, _memory_cache
|
||||
def unpatch_inference_clients():
|
||||
"""Remove monkey patches and restore original OpenAI, Ollama client, and tool runtime methods."""
|
||||
global _original_methods
|
||||
|
||||
if not _original_methods:
|
||||
return
|
||||
|
||||
# Import here to avoid circular imports
|
||||
import litellm
|
||||
from ollama import AsyncClient as OllamaAsyncClient
|
||||
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
|
||||
from openai.resources.completions import AsyncCompletions
|
||||
|
|
@ -887,7 +769,6 @@ def unpatch_api_clients():
|
|||
from openai.resources.models import AsyncModels
|
||||
|
||||
from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
# Restore OpenAI client methods
|
||||
AsyncChatCompletions.create = _original_methods["chat_completions_create"]
|
||||
|
|
@ -903,19 +784,11 @@ def unpatch_api_clients():
|
|||
OllamaAsyncClient.pull = _original_methods["ollama_pull"]
|
||||
OllamaAsyncClient.list = _original_methods["ollama_list"]
|
||||
|
||||
# Restore LiteLLM methods
|
||||
litellm.acompletion = _original_methods["litellm_acompletion"]
|
||||
litellm.atext_completion = _original_methods["litellm_atext_completion"]
|
||||
LiteLLMOpenAIMixin.get_api_key = _original_methods["litellm_openai_mixin_get_api_key"]
|
||||
|
||||
# Restore tool runtime methods
|
||||
TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"]
|
||||
|
||||
_original_methods.clear()
|
||||
|
||||
# Clear memory cache to prevent memory leaks
|
||||
_memory_cache.clear()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
|
||||
|
|
@ -933,14 +806,14 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator
|
|||
if storage_dir is None:
|
||||
raise ValueError("storage_dir is required for record, replay, and record-if-missing modes")
|
||||
_current_storage = ResponseStorage(Path(storage_dir))
|
||||
patch_api_clients()
|
||||
patch_inference_clients()
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
# Restore previous state
|
||||
if mode in ["record", "replay", "record-if-missing"]:
|
||||
unpatch_api_clients()
|
||||
unpatch_inference_clients()
|
||||
|
||||
_current_mode = prev_mode
|
||||
_current_storage = prev_storage
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue