From 8414c3085979bf43799ae7738a076dfacf16a152 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 8 Oct 2025 12:43:58 -0700 Subject: [PATCH] redid api_recorder.py now, step 0 --- llama_stack/testing/api_recorder.py | 609 +++++++++++----------------- 1 file changed, 241 insertions(+), 368 deletions(-) diff --git a/llama_stack/testing/api_recorder.py b/llama_stack/testing/api_recorder.py index cc7c65465..522cf3282 100644 --- a/llama_stack/testing/api_recorder.py +++ b/llama_stack/testing/api_recorder.py @@ -15,20 +15,19 @@ from enum import StrEnum from pathlib import Path from typing import Any, Literal, cast -from openai import NOT_GIVEN +from openai import NOT_GIVEN, OpenAI from llama_stack.log import get_logger logger = get_logger(__name__, category="testing") -# Global state for the API recording system +# Global state for the recording system # Note: Using module globals instead of ContextVars because the session-scoped # client initialization happens in one async context, but tests run in different # contexts, and we need the mode/storage to persist across all contexts. _current_mode: str | None = None _current_storage: ResponseStorage | None = None _original_methods: dict[str, Any] = {} -_memory_cache: dict[str, dict[str, Any]] = {} # Test context uses ContextVar since it changes per-test and needs async isolation from contextvars import ContextVar @@ -52,29 +51,162 @@ class APIRecordingMode(StrEnum): RECORD_IF_MISSING = "record-if-missing" -def _normalize_file_ids(obj: Any) -> Any: - """Recursively replace file IDs with a canonical placeholder for consistent hashing.""" - import re +def normalize_inference_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str: + """Create a normalized hash of the request for consistent matching. - if isinstance(obj, dict): - result = {} - for k, v in obj.items(): - # Normalize file IDs in attribute dictionaries - if k == "document_id" and isinstance(v, str) and v.startswith("file-"): - result[k] = "file-NORMALIZED" + Includes test_id from context to ensure test isolation - identical requests + from different tests will have different hashes. + + Exception: Model list endpoints (/v1/models, /api/tags) exclude test_id since + they are infrastructure/shared and need to work across session setup and tests. + """ + # Extract just the endpoint path + from urllib.parse import urlparse + + parsed = urlparse(url) + normalized: dict[str, Any] = { + "method": method.upper(), + "endpoint": parsed.path, + "body": body, + } + + # Include test_id for isolation, except for shared infrastructure endpoints + if parsed.path not in ("/api/tags", "/v1/models"): + normalized["test_id"] = _test_context.get() + + # Create hash - sort_keys=True ensures deterministic ordering + normalized_json = json.dumps(normalized, sort_keys=True) + return hashlib.sha256(normalized_json.encode()).hexdigest() + + +def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str: + """Create a normalized hash of the tool request for consistent matching.""" + normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": kwargs} + + # Create hash - sort_keys=True ensures deterministic ordering + normalized_json = json.dumps(normalized, sort_keys=True) + return hashlib.sha256(normalized_json.encode()).hexdigest() + + +def _sync_test_context_from_provider_data(): + """In server mode, sync test ID from provider_data to _test_context. 
+
+    This ensures that storage operations (which read from _test_context) work correctly
+    in server mode where the test ID arrives via HTTP header → provider_data.
+
+    Returns a token to reset _test_context, or None if no sync was needed.
+    """
+    stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
+
+    if stack_config_type != "server":
+        return None
+
+    try:
+        from llama_stack.core.request_headers import PROVIDER_DATA_VAR
+
+        provider_data = PROVIDER_DATA_VAR.get()
+
+        if provider_data and "__test_id" in provider_data:
+            test_id = provider_data["__test_id"]
+            return _test_context.set(test_id)
+    except ImportError:
+        pass
+
+    return None
+
+
+def patch_httpx_for_test_id():
+    """Patch client _prepare_request methods to inject test ID into provider data header.
+
+    This is needed for server mode where the test ID must be transported from
+    client to server via HTTP headers. In library_client mode, this patch is a no-op
+    since everything runs in the same process.
+
+    We use the _prepare_request hook that Stainless clients provide for mutating
+    requests after construction but before sending.
+    """
+    from llama_stack_client import LlamaStackClient
+
+    if "llama_stack_client_prepare_request" in _original_methods:
+        return
+
+    _original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request
+    _original_methods["openai_prepare_request"] = OpenAI._prepare_request
+
+    def patched_prepare_request(self, request):
+        # Call original first (it's a sync method that returns None)
+        # Determine which original to call based on client type
+        if "llama_stack_client" in self.__class__.__module__:
+            _original_methods["llama_stack_client_prepare_request"](self, request)
+        else:
+            _original_methods["openai_prepare_request"](self, request)
+
+        # Only inject test ID in server mode
+        stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
+        test_id = _test_context.get()
+
+        if stack_config_type == "server" and test_id:
+            provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
+
+            if provider_data_header:
+                provider_data = json.loads(provider_data_header)
             else:
-                result[k] = _normalize_file_ids(v)
-        return result
-    elif isinstance(obj, list):
-        return [_normalize_file_ids(item) for item in obj]
-    elif isinstance(obj, str):
-        # Replace file- patterns in strings (like in text content)
-        return re.sub(r"file-[a-f0-9]{32}", "file-NORMALIZED", obj)
-    else:
-        return obj
+                provider_data = {}
+
+            provider_data["__test_id"] = test_id
+            request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
+
+        return None
+
+    LlamaStackClient._prepare_request = patched_prepare_request
+    OpenAI._prepare_request = patched_prepare_request
 
 
-def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[str, Any]:
+# currently, unpatch is never called
+def unpatch_httpx_for_test_id():
+    """Remove client _prepare_request patches for test ID injection."""
+    if "llama_stack_client_prepare_request" not in _original_methods:
+        return
+
+    from llama_stack_client import LlamaStackClient
+
+    LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"]
+    del _original_methods["llama_stack_client_prepare_request"]
+    OpenAI._prepare_request = _original_methods["openai_prepare_request"]
+    del _original_methods["openai_prepare_request"]
+
+
+def get_api_recording_mode() -> APIRecordingMode:
+    return APIRecordingMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
+
+
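For clarity, the round trip implemented by patch_httpx_for_test_id() and _sync_test_context_from_provider_data() above can be sketched as plain Python; the test id and the pre-existing provider data shown here are hypothetical, not values from this patch:

    import json

    test_id = "tests/integration/inference/test_basic.py::test_foo"  # hypothetical pytest node id

    # Client side (patched_prepare_request): merge the test id into the provider-data header.
    headers = {"X-LlamaStack-Provider-Data": json.dumps({"some_provider_key": "value"})}  # may be absent
    provider_data = json.loads(headers.get("X-LlamaStack-Provider-Data") or "{}")
    provider_data["__test_id"] = test_id
    headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)

    # Server side (_sync_test_context_from_provider_data): read it back and seed _test_context.
    incoming = json.loads(headers["X-LlamaStack-Provider-Data"])
    assert incoming["__test_id"] == test_id
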
+def setup_api_recording(): + """ + Returns a context manager that can be used to record or replay API requests (inference and tools). + This is to be used in tests to increase their reliability and reduce reliance on expensive, external services. + + Currently supports: + - Inference: OpenAI and Ollama clients + - Tools: Search providers (Tavily) + + Two environment variables are supported: + - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'record-if-missing'. Default is 'replay'. + - 'live': Make all requests live without recording + - 'record': Record all requests (overwrites existing recordings) + - 'replay': Use only recorded responses (fails if recording not found) + - 'record-if-missing': Use recorded responses when available, record new ones when not found + - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'. + + The recordings are stored as JSON files. + """ + mode = get_api_recording_mode() + if mode == APIRecordingMode.LIVE: + return None + + storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR) + return api_recording(mode=mode, storage_dir=storage_dir) + + +def _normalize_response(data: dict[str, Any], request_hash: str) -> dict[str, Any]: """Normalize fields that change between recordings but don't affect functionality. This reduces noise in git diffs by making IDs deterministic and timestamps constant. @@ -106,184 +238,11 @@ def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[st return data -def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str: - """Create a normalized hash of the request for consistent matching. - - Includes test_id from context to ensure test isolation - identical requests - from different tests will have different hashes. - """ - # Extract just the endpoint path - from urllib.parse import urlparse - - parsed = urlparse(url) - - # Normalize file IDs in the body to ensure consistent hashing across test runs - normalized_body = _normalize_file_ids(body) - - normalized: dict[str, Any] = {"method": method.upper(), "endpoint": parsed.path, "body": normalized_body} - - # Include test_id for isolation, except for shared infrastructure endpoints - if parsed.path not in ("/api/tags", "/v1/models"): - # Server mode: test ID was synced from provider_data to _test_context - # by _sync_test_context_from_provider_data() at the start of the request. - # We read from _test_context because it's available in all contexts (including - # when making outgoing API calls), whereas PROVIDER_DATA_VAR is only set - # for incoming HTTP requests. 
- # - # Library client mode: test ID in same-process ContextVar - test_id = _test_context.get() - normalized["test_id"] = test_id - - # Create hash - sort_keys=True ensures deterministic ordering - normalized_json = json.dumps(normalized, sort_keys=True) - request_hash = hashlib.sha256(normalized_json.encode()).hexdigest() - - return request_hash - - -def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str: - """Create a normalized hash of the tool request for consistent matching.""" - normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": kwargs} - - # Create hash - sort_keys=True ensures deterministic ordering - normalized_json = json.dumps(normalized, sort_keys=True) - return hashlib.sha256(normalized_json.encode()).hexdigest() - - -@contextmanager -def set_test_context(test_id: str) -> Generator[None, None, None]: - """Set the test context for recording isolation. - - Usage: - with set_test_context("test_basic_completion"): - # Make API calls that will be recorded with this test_id - response = client.chat.completions.create(...) - """ - token = _test_context.set(test_id) - try: - yield - finally: - _test_context.reset(token) - - -def patch_httpx_for_test_id(): - """Patch client _prepare_request methods to inject test ID into provider data header. - - Patches both LlamaStackClient and OpenAI client to ensure test ID is transported - from client to server via HTTP headers in server mode. - - This is needed for server mode where the test ID must be transported from - client to server via HTTP headers. In library_client mode, this patch is a no-op - since everything runs in the same process. - - We use the _prepare_request hook that Stainless clients provide for mutating - requests after construction but before sending. 
- """ - from llama_stack_client import LlamaStackClient - - if "llama_stack_client_prepare_request" in _original_methods: - # Already patched - return - - # Save original methods - _original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request - - # Also patch OpenAI client if available (used in compat tests) - try: - from openai import OpenAI - - _original_methods["openai_prepare_request"] = OpenAI._prepare_request - except ImportError: - pass - - def patched_prepare_request(self, request): - # Call original first (it's a sync method that returns None) - # Determine which original to call based on client type - if "llama_stack_client" in self.__class__.__module__: - _original_methods["llama_stack_client_prepare_request"](self, request) - elif "openai_prepare_request" in _original_methods: - _original_methods["openai_prepare_request"](self, request) - - # Only inject test ID in server mode - stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client") - test_id = _test_context.get() - - if stack_config_type == "server" and test_id: - # Get existing provider data header or create new dict - provider_data_header = request.headers.get("X-LlamaStack-Provider-Data") - - if provider_data_header: - provider_data = json.loads(provider_data_header) - else: - provider_data = {} - - # Inject test ID - provider_data["__test_id"] = test_id - request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data) - - # Sync version returns None - return None - - # Apply patches - LlamaStackClient._prepare_request = patched_prepare_request - if "openai_prepare_request" in _original_methods: - from openai import OpenAI - - OpenAI._prepare_request = patched_prepare_request - - -def unpatch_httpx_for_test_id(): - """Remove client _prepare_request patches for test ID injection.""" - if "llama_stack_client_prepare_request" not in _original_methods: - return - - from llama_stack_client import LlamaStackClient - - LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"] - del _original_methods["llama_stack_client_prepare_request"] - - # Also restore OpenAI client if it was patched - if "openai_prepare_request" in _original_methods: - from openai import OpenAI - - OpenAI._prepare_request = _original_methods["openai_prepare_request"] - del _original_methods["openai_prepare_request"] - - -def get_api_recording_mode() -> APIRecordingMode: - return APIRecordingMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower()) - - -def setup_api_recording(): - """ - Returns a context manager that can be used to record or replay API requests (inference and tools). - This is to be used in tests to increase their reliability and reduce reliance on expensive, external services. - - Currently supports: - - Inference: OpenAI, Ollama, and LiteLLM clients - - Tools: Search providers (Tavily for now) - - Two environment variables are supported: - - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Options: - - 'live': Make real API calls, no recording - - 'record': Record all API interactions (overwrites existing) - - 'replay': Use recorded responses only (default) - - 'record-if-missing': Replay when possible, record when recording doesn't exist - - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'. - - The recordings are stored as JSON files. 
- """ - mode = get_api_recording_mode() - if mode == APIRecordingMode.LIVE: - return None - - storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR) - return api_recording(mode=mode, storage_dir=storage_dir) - - -def _serialize_response(response: Any) -> Any: +def _serialize_response(response: Any, request_hash: str = "") -> Any: if hasattr(response, "model_dump"): data = response.model_dump(mode="json") + # Normalize fields to reduce noise + data = _normalize_response(data, request_hash) return { "__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}", "__data__": data, @@ -308,17 +267,22 @@ def _deserialize_response(data: dict[str, Any]) -> Any: return cls.model_validate(data["__data__"]) except (ImportError, AttributeError, TypeError, ValueError) as e: - logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}") - return data["__data__"] + logger.warning(f"Failed to deserialize object of type {data['__type__']} with model_validate: {e}") + try: + return cls.model_construct(**data["__data__"]) + except Exception as e: + logger.warning(f"Failed to deserialize object of type {data['__type__']} with model_construct: {e}") + return data["__data__"] return data class ResponseStorage: - """Handles storage/retrieval for API recordings (inference and tools).""" + """Handles SQLite index + JSON file storage/retrieval for inference recordings.""" def __init__(self, base_dir: Path): self.base_dir = base_dir + # Don't create responses_dir here - determine it per-test at runtime def _get_test_dir(self) -> Path: """Get the recordings directory in the test file's parent directory. @@ -327,7 +291,6 @@ class ResponseStorage: returns "tests/integration/inference/recordings/". """ test_id = _test_context.get() - if test_id: # Extract the directory path from the test nodeid # e.g., "tests/integration/inference/test_basic.py::test_foo[params]" @@ -342,21 +305,17 @@ class ResponseStorage: # Fallback for non-test contexts return self.base_dir / "recordings" - def _ensure_directories(self) -> Path: + def _ensure_directories(self): + """Ensure test-specific directories exist.""" test_dir = self._get_test_dir() test_dir.mkdir(parents=True, exist_ok=True) return test_dir def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]): - """Store a request/response pair both in memory cache and on disk.""" - global _memory_cache - - # Store in memory cache first - _memory_cache[request_hash] = {"request": request, "response": response} - + """Store a request/response pair.""" responses_dir = self._ensure_directories() - # Generate unique response filename using full hash + # Use FULL hash (not truncated) response_file = f"{request_hash}.json" # Serialize response body if needed @@ -364,32 +323,45 @@ class ResponseStorage: if "body" in serialized_response: if isinstance(serialized_response["body"], list): # Handle streaming responses (list of chunks) - serialized_response["body"] = [_serialize_response(chunk) for chunk in serialized_response["body"]] + serialized_response["body"] = [ + _serialize_response(chunk, request_hash) for chunk in serialized_response["body"] + ] else: # Handle single response - serialized_response["body"] = _serialize_response(serialized_response["body"]) + serialized_response["body"] = _serialize_response(serialized_response["body"], request_hash) - # If this is a model-list endpoint recording, include models digest in filename to distinguish variants + # For model-list endpoints, 
include digest in filename to distinguish different model sets endpoint = request.get("endpoint") - test_id = _test_context.get() if endpoint in ("/api/tags", "/v1/models"): - test_id = None digest = _model_identifiers_digest(endpoint, response) response_file = f"models-{request_hash}-{digest}.json" response_path = responses_dir / response_file - # Save response to JSON file + # Save response to JSON file with metadata with open(response_path, "w") as f: - json.dump({"test_id": test_id, "request": request, "response": serialized_response}, f, indent=2) + json.dump( + { + "test_id": _test_context.get(), + "request": request, + "response": serialized_response, + }, + f, + indent=2, + ) f.write("\n") f.flush() def find_recording(self, request_hash: str) -> dict[str, Any] | None: - """Find a recorded response by request hash.""" + """Find a recorded response by request hash. + + Uses fallback: first checks test-specific dir, then falls back to base recordings dir. + This handles cases where recordings happen during session setup (no test context) but + are requested during tests (with test context). + """ response_file = f"{request_hash}.json" - # Check test-specific directory first + # Try test-specific directory first test_dir = self._get_test_dir() response_path = test_dir / response_file @@ -529,23 +501,18 @@ async def _patched_tool_invoke_method( # If RECORD_IF_MISSING and no recording found, fall through to record if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING): - # Check in-memory cache first (collision detection) - global _memory_cache - if request_hash in _memory_cache: - # Return the cached response instead of making a new tool call - return _memory_cache[request_hash]["response"]["body"] - - # No cached response, make the tool call and record it + # Make the tool call and record it result = await original_method(self, tool_name, kwargs) request_data = { + "test_id": _test_context.get(), "provider": provider_name, "tool_name": tool_name, "kwargs": kwargs, } response_data = {"body": result, "is_streaming": False} - # Store the recording (both in memory and on disk) + # Store the recording _current_storage.store_recording(request_hash, request_data, response_data) return result @@ -557,40 +524,15 @@ async def _patched_tool_invoke_method( _test_context.reset(test_context_token) -def _sync_test_context_from_provider_data(): - """In server mode, sync test ID from provider_data to _test_context. - - This ensures that storage operations (which read from _test_context) work correctly - in server mode where the test ID arrives via HTTP header → provider_data. - - Returns a token to reset _test_context, or None if no sync was needed. 
- """ - stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client") - - if stack_config_type != "server": - return None - - try: - from llama_stack.core.request_headers import PROVIDER_DATA_VAR - - provider_data = PROVIDER_DATA_VAR.get() - - if provider_data and "__test_id" in provider_data: - test_id = provider_data["__test_id"] - return _test_context.set(test_id) - except ImportError: - pass - - return None - - async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs): global _current_mode, _current_storage - if _current_mode == APIRecordingMode.LIVE or _current_storage is None: - # Normal operation - if client_type == "litellm": - return await original_method(*args, **kwargs) + mode = _current_mode + storage = _current_storage + + if mode == APIRecordingMode.LIVE or storage is None: + if endpoint == "/v1/models": + return original_method(self, *args, **kwargs) else: return await original_method(self, *args, **kwargs) @@ -609,30 +551,34 @@ async def _patched_inference_method(original_method, self, client_type, endpoint base_url = getattr(self, "host", "http://localhost:11434") if not base_url.startswith("http"): base_url = f"http://{base_url}" - elif client_type == "litellm": - # For LiteLLM, extract base URL from kwargs if available - base_url = kwargs.get("api_base", "https://api.openai.com") else: raise ValueError(f"Unknown client type: {client_type}") url = base_url.rstrip("/") + endpoint + # Special handling for Databricks URLs to avoid leaking workspace info + # e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com + if "cloud.databricks.com" in url: + url = "__databricks__" + url.split("cloud.databricks.com")[-1] method = "POST" headers = {} body = kwargs - request_hash = normalize_request(method, url, headers, body) + request_hash = normalize_inference_request(method, url, headers, body) - if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING): - # Special handling for model-list endpoints: return union of all responses + # Try to find existing recording for REPLAY or RECORD_IF_MISSING modes + recording = None + if mode == APIRecordingMode.REPLAY or mode == APIRecordingMode.RECORD_IF_MISSING: + # Special handling for model-list endpoints: merge all recordings with this hash if endpoint in ("/api/tags", "/v1/models"): - records = _current_storage._model_list_responses(request_hash) + records = storage._model_list_responses(request_hash) recording = _combine_model_list_responses(endpoint, records) else: - recording = _current_storage.find_recording(request_hash) + recording = storage.find_recording(request_hash) + if recording: response_body = recording["response"]["body"] - if recording["response"].get("is_streaming", False) or recording["response"].get("is_paginated", False): + if recording["response"].get("is_streaming", False): async def replay_stream(): for chunk in response_body: @@ -641,41 +587,25 @@ async def _patched_inference_method(original_method, self, client_type, endpoint return replay_stream() else: return response_body - elif _current_mode == APIRecordingMode.REPLAY: + elif mode == APIRecordingMode.REPLAY: + # REPLAY mode requires recording to exist raise RuntimeError( f"No recorded response found for request hash: {request_hash}\n" f"Request: {method} {url} {body}\n" f"Model: {body.get('model', 'unknown')}\n" f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record" ) - # If RECORD_IF_MISSING and no recording 
found, fall through to record - if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING): - # Check in-memory cache first (collision detection) - global _memory_cache - if request_hash in _memory_cache: - # Return the cached response instead of making a new API call - cached_recording = _memory_cache[request_hash] - response_body = cached_recording["response"]["body"] - - if cached_recording["response"].get("is_streaming", False) or cached_recording["response"].get( - "is_paginated", False - ): - - async def replay_cached_stream(): - for chunk in response_body: - yield chunk - - return replay_cached_stream() - else: - return response_body - - # No cached response, make the API call and record it - if client_type == "litellm": - response = await original_method(*args, **kwargs) + if mode == APIRecordingMode.RECORD or (mode == APIRecordingMode.RECORD_IF_MISSING and not recording): + if endpoint == "/v1/models": + response = original_method(self, *args, **kwargs) else: response = await original_method(self, *args, **kwargs) + # we want to store the result of the iterator, not the iterator itself + if endpoint == "/v1/models": + response = [m async for m in response] + request_data = { "method": method, "url": url, @@ -688,20 +618,16 @@ async def _patched_inference_method(original_method, self, client_type, endpoint # Determine if this is a streaming request based on request parameters is_streaming = body.get("stream", False) - # Special case: /v1/models is a paginated endpoint that returns an async iterator - is_paginated = endpoint == "/v1/models" - - if is_streaming or is_paginated: - # For streaming/paginated responses, we need to collect all chunks immediately before yielding + if is_streaming: + # For streaming responses, we need to collect all chunks immediately before yielding # This ensures the recording is saved even if the generator isn't fully consumed chunks = [] async for chunk in response: chunks.append(chunk) - # Store the recording immediately (both in memory and on disk) - # For paginated endpoints, mark as paginated rather than streaming - response_data = {"body": chunks, "is_streaming": is_streaming, "is_paginated": is_paginated} - _current_storage.store_recording(request_hash, request_data, response_data) + # Store the recording immediately + response_data = {"body": chunks, "is_streaming": True} + storage.store_recording(request_hash, request_data, response_data) # Return a generator that replays the stored chunks async def replay_recorded_stream(): @@ -711,24 +637,20 @@ async def _patched_inference_method(original_method, self, client_type, endpoint return replay_recorded_stream() else: response_data = {"body": response, "is_streaming": False} - # Store the response (both in memory and on disk) - _current_storage.store_recording(request_hash, request_data, response_data) + storage.store_recording(request_hash, request_data, response_data) return response else: - raise AssertionError(f"Invalid mode: {_current_mode}") - + raise AssertionError(f"Invalid mode: {mode}") finally: - # Reset test context if we set it in server mode - if test_context_token is not None: + if test_context_token: _test_context.reset(test_context_token) -def patch_api_clients(): - """Install monkey patches for inference clients and tool runtime methods.""" +def patch_inference_clients(): + """Install monkey patches for OpenAI client methods, Ollama AsyncClient methods, and tool runtime methods.""" global _original_methods - import litellm from ollama import AsyncClient as 
OllamaAsyncClient from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions from openai.resources.completions import AsyncCompletions @@ -736,9 +658,8 @@ def patch_api_clients(): from openai.resources.models import AsyncModels from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl - from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin - # Store original methods for OpenAI, Ollama, LiteLLM clients, and tool runtimes + # Store original methods for OpenAI, Ollama clients, and tool runtimes _original_methods = { "chat_completions_create": AsyncChatCompletions.create, "completions_create": AsyncCompletions.create, @@ -750,9 +671,6 @@ def patch_api_clients(): "ollama_ps": OllamaAsyncClient.ps, "ollama_pull": OllamaAsyncClient.pull, "ollama_list": OllamaAsyncClient.list, - "litellm_acompletion": litellm.acompletion, - "litellm_atext_completion": litellm.atext_completion, - "litellm_openai_mixin_get_api_key": LiteLLMOpenAIMixin.get_api_key, "tavily_invoke_tool": TavilySearchToolRuntimeImpl.invoke_tool, } @@ -774,18 +692,10 @@ def patch_api_clients(): def patched_models_list(self, *args, **kwargs): async def _iter(): - result = await _patched_inference_method( + for item in await _patched_inference_method( _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs - ) - # The result is either an async generator (streaming/paginated) or a list - # If it's an async generator, iterate through it - if hasattr(result, "__aiter__"): - async for item in result: - yield item - else: - # It's a list, yield each item - for item in result: - yield item + ): + yield item return _iter() @@ -834,33 +744,6 @@ def patch_api_clients(): OllamaAsyncClient.pull = patched_ollama_pull OllamaAsyncClient.list = patched_ollama_list - # Create patched methods for LiteLLM - async def patched_litellm_acompletion(*args, **kwargs): - return await _patched_inference_method( - _original_methods["litellm_acompletion"], None, "litellm", "/chat/completions", *args, **kwargs - ) - - async def patched_litellm_atext_completion(*args, **kwargs): - return await _patched_inference_method( - _original_methods["litellm_atext_completion"], None, "litellm", "/completions", *args, **kwargs - ) - - # Apply LiteLLM patches - litellm.acompletion = patched_litellm_acompletion - litellm.atext_completion = patched_litellm_atext_completion - - # Create patched method for LiteLLMOpenAIMixin.get_api_key - def patched_litellm_get_api_key(self): - global _current_mode - if _current_mode != APIRecordingMode.REPLAY: - return _original_methods["litellm_openai_mixin_get_api_key"](self) - else: - # For record/replay modes, return a fake API key to avoid exposing real credentials - return "fake-api-key-for-testing" - - # Apply LiteLLMOpenAIMixin patch - LiteLLMOpenAIMixin.get_api_key = patched_litellm_get_api_key - # Create patched methods for tool runtimes async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any]): return await _patched_tool_invoke_method( @@ -871,15 +754,14 @@ def patch_api_clients(): TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool -def unpatch_api_clients(): - """Remove monkey patches and restore original client methods and tool runtimes.""" - global _original_methods, _memory_cache +def unpatch_inference_clients(): + """Remove monkey patches and restore original OpenAI, Ollama client, and tool runtime methods.""" + global _original_methods if not 
_original_methods: return # Import here to avoid circular imports - import litellm from ollama import AsyncClient as OllamaAsyncClient from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions from openai.resources.completions import AsyncCompletions @@ -887,7 +769,6 @@ def unpatch_api_clients(): from openai.resources.models import AsyncModels from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl - from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin # Restore OpenAI client methods AsyncChatCompletions.create = _original_methods["chat_completions_create"] @@ -903,19 +784,11 @@ def unpatch_api_clients(): OllamaAsyncClient.pull = _original_methods["ollama_pull"] OllamaAsyncClient.list = _original_methods["ollama_list"] - # Restore LiteLLM methods - litellm.acompletion = _original_methods["litellm_acompletion"] - litellm.atext_completion = _original_methods["litellm_atext_completion"] - LiteLLMOpenAIMixin.get_api_key = _original_methods["litellm_openai_mixin_get_api_key"] - # Restore tool runtime methods TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"] _original_methods.clear() - # Clear memory cache to prevent memory leaks - _memory_cache.clear() - @contextmanager def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]: @@ -933,14 +806,14 @@ def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator if storage_dir is None: raise ValueError("storage_dir is required for record, replay, and record-if-missing modes") _current_storage = ResponseStorage(Path(storage_dir)) - patch_api_clients() + patch_inference_clients() yield finally: # Restore previous state if mode in ["record", "replay", "record-if-missing"]: - unpatch_api_clients() + unpatch_inference_clients() _current_mode = prev_mode _current_storage = prev_storage
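
The request matching introduced by normalize_inference_request() boils down to hashing a canonicalized (method, endpoint, body, test_id) tuple, with model-list endpoints excluded from test isolation. A minimal standalone sketch of the same idea, with illustrative URLs and bodies:

    import hashlib
    import json
    from urllib.parse import urlparse

    def request_hash(method: str, url: str, body: dict, test_id: str | None) -> str:
        parsed = urlparse(url)
        normalized = {"method": method.upper(), "endpoint": parsed.path, "body": body}
        if parsed.path not in ("/api/tags", "/v1/models"):
            normalized["test_id"] = test_id  # isolates otherwise-identical requests per test
        return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()

    # The same chat request recorded under two different tests gets two different hashes:
    h1 = request_hash("POST", "http://localhost:8321/v1/chat/completions", {"model": "m", "stream": False}, "test_a")
    h2 = request_hash("POST", "http://localhost:8321/v1/chat/completions", {"model": "m", "stream": False}, "test_b")
    assert h1 != h2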
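
Assuming the environment-variable contract described in setup_api_recording()'s docstring, a test harness could drive record/replay roughly as below; the conftest-style seeding of _test_context is a hypothetical illustration, not part of this patch:

    import os

    from llama_stack.testing.api_recorder import _test_context, setup_api_recording

    # Mode and storage location are read from the environment (defaults documented above).
    os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "record-if-missing"
    os.environ["LLAMA_STACK_TEST_RECORDING_DIR"] = "tests/integration/recordings"

    recording_ctx = setup_api_recording()  # returns None when the mode is 'live'
    if recording_ctx is not None:
        with recording_ctx:
            # Hypothetical: seed the per-test context the way a conftest fixture might.
            token = _test_context.set("tests/integration/inference/test_basic.py::test_foo[params]")
            try:
                pass  # make OpenAI / Ollama calls here; they are recorded or replayed per the mode
            finally:
                _test_context.reset(token)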