# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. from __future__ import annotations # for forward references import hashlib import json import os from collections.abc import Generator from contextlib import contextmanager from enum import StrEnum from pathlib import Path from typing import Any, Literal, cast from openai import NOT_GIVEN from llama_stack.log import get_logger logger = get_logger(__name__, category="testing") # Global state for the API recording system # Note: Using module globals instead of ContextVars because the session-scoped # client initialization happens in one async context, but tests run in different # contexts, and we need the mode/storage to persist across all contexts. _current_mode: str | None = None _current_storage: ResponseStorage | None = None _original_methods: dict[str, Any] = {} _memory_cache: dict[str, dict[str, Any]] = {} # Test context uses ContextVar since it changes per-test and needs async isolation from contextvars import ContextVar _test_context: ContextVar[str | None] = ContextVar("_test_context", default=None) from openai.types.completion_choice import CompletionChoice # update the "finish_reason" field, since its type definition is wrong (no None is accepted) CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None CompletionChoice.model_rebuild() REPO_ROOT = Path(__file__).parent.parent.parent DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/common" class APIRecordingMode(StrEnum): LIVE = "live" RECORD = "record" REPLAY = "replay" RECORD_IF_MISSING = "record-if-missing" def _normalize_file_ids(obj: Any) -> Any: """Recursively replace file IDs with a canonical placeholder for consistent hashing.""" import re if isinstance(obj, dict): result = {} for k, v in obj.items(): # Normalize file IDs in attribute dictionaries if k == "document_id" and isinstance(v, str) and v.startswith("file-"): result[k] = "file-NORMALIZED" else: result[k] = _normalize_file_ids(v) return result elif isinstance(obj, list): return [_normalize_file_ids(item) for item in obj] elif isinstance(obj, str): # Replace file- patterns in strings (like in text content) return re.sub(r"file-[a-f0-9]{32}", "file-NORMALIZED", obj) else: return obj def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[str, Any]: """Normalize fields that change between recordings but don't affect functionality. This reduces noise in git diffs by making IDs deterministic and timestamps constant. """ # Only normalize ID for completion/chat responses, not for model objects # Model objects have "object": "model" and the ID is the actual model identifier if "id" in data and data.get("object") != "model": data["id"] = f"rec-{request_hash[:12]}" # Normalize timestamp to epoch (0) (for OpenAI-style responses) # But not for model objects where created timestamp might be meaningful if "created" in data and data.get("object") != "model": data["created"] = 0 # Normalize Ollama-specific timestamp fields if "created_at" in data: data["created_at"] = "1970-01-01T00:00:00.000000Z" # Normalize Ollama-specific duration fields (these vary based on system load) if "total_duration" in data and data["total_duration"] is not None: data["total_duration"] = 0 if "load_duration" in data and data["load_duration"] is not None: data["load_duration"] = 0 if "prompt_eval_duration" in data and data["prompt_eval_duration"] is not None: data["prompt_eval_duration"] = 0 if "eval_duration" in data and data["eval_duration"] is not None: data["eval_duration"] = 0 return data def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str: """Create a normalized hash of the request for consistent matching. Includes test_id from context to ensure test isolation - identical requests from different tests will have different hashes. """ # Extract just the endpoint path from urllib.parse import urlparse parsed = urlparse(url) # Normalize file IDs in the body to ensure consistent hashing across test runs normalized_body = _normalize_file_ids(body) normalized: dict[str, Any] = {"method": method.upper(), "endpoint": parsed.path, "body": normalized_body} # Include test_id for isolation, except for shared infrastructure endpoints if parsed.path not in ("/api/tags", "/v1/models"): # Server mode: test ID was synced from provider_data to _test_context # by _sync_test_context_from_provider_data() at the start of the request. # We read from _test_context because it's available in all contexts (including # when making outgoing API calls), whereas PROVIDER_DATA_VAR is only set # for incoming HTTP requests. # # Library client mode: test ID in same-process ContextVar test_id = _test_context.get() normalized["test_id"] = test_id # Create hash - sort_keys=True ensures deterministic ordering normalized_json = json.dumps(normalized, sort_keys=True) request_hash = hashlib.sha256(normalized_json.encode()).hexdigest() return request_hash def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str: """Create a normalized hash of the tool request for consistent matching.""" normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": kwargs} # Create hash - sort_keys=True ensures deterministic ordering normalized_json = json.dumps(normalized, sort_keys=True) return hashlib.sha256(normalized_json.encode()).hexdigest() @contextmanager def set_test_context(test_id: str) -> Generator[None, None, None]: """Set the test context for recording isolation. Usage: with set_test_context("test_basic_completion"): # Make API calls that will be recorded with this test_id response = client.chat.completions.create(...) """ token = _test_context.set(test_id) try: yield finally: _test_context.reset(token) def patch_httpx_for_test_id(): """Patch client _prepare_request methods to inject test ID into provider data header. Patches both LlamaStackClient and OpenAI client to ensure test ID is transported from client to server via HTTP headers in server mode. This is needed for server mode where the test ID must be transported from client to server via HTTP headers. In library_client mode, this patch is a no-op since everything runs in the same process. We use the _prepare_request hook that Stainless clients provide for mutating requests after construction but before sending. """ from llama_stack_client import LlamaStackClient if "llama_stack_client_prepare_request" in _original_methods: # Already patched return # Save original methods _original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request # Also patch OpenAI client if available (used in compat tests) try: from openai import OpenAI _original_methods["openai_prepare_request"] = OpenAI._prepare_request except ImportError: pass def patched_prepare_request(self, request): # Call original first (it's a sync method that returns None) # Determine which original to call based on client type if "llama_stack_client" in self.__class__.__module__: _original_methods["llama_stack_client_prepare_request"](self, request) elif "openai_prepare_request" in _original_methods: _original_methods["openai_prepare_request"](self, request) # Only inject test ID in server mode stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client") test_id = _test_context.get() if stack_config_type == "server" and test_id: # Get existing provider data header or create new dict provider_data_header = request.headers.get("X-LlamaStack-Provider-Data") if provider_data_header: provider_data = json.loads(provider_data_header) else: provider_data = {} # Inject test ID provider_data["__test_id"] = test_id request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data) # Sync version returns None return None # Apply patches LlamaStackClient._prepare_request = patched_prepare_request if "openai_prepare_request" in _original_methods: from openai import OpenAI OpenAI._prepare_request = patched_prepare_request def unpatch_httpx_for_test_id(): """Remove client _prepare_request patches for test ID injection.""" if "llama_stack_client_prepare_request" not in _original_methods: return from llama_stack_client import LlamaStackClient LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"] del _original_methods["llama_stack_client_prepare_request"] # Also restore OpenAI client if it was patched if "openai_prepare_request" in _original_methods: from openai import OpenAI OpenAI._prepare_request = _original_methods["openai_prepare_request"] del _original_methods["openai_prepare_request"] def get_api_recording_mode() -> APIRecordingMode: return APIRecordingMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower()) def setup_api_recording(): """ Returns a context manager that can be used to record or replay API requests (inference and tools). This is to be used in tests to increase their reliability and reduce reliance on expensive, external services. Currently supports: - Inference: OpenAI, Ollama, and LiteLLM clients - Tools: Search providers (Tavily for now) Two environment variables are supported: - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Options: - 'live': Make real API calls, no recording - 'record': Record all API interactions (overwrites existing) - 'replay': Use recorded responses only (default) - 'record-if-missing': Replay when possible, record when recording doesn't exist - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'. The recordings are stored as JSON files. """ mode = get_api_recording_mode() if mode == APIRecordingMode.LIVE: return None storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR) return api_recording(mode=mode, storage_dir=storage_dir) def _serialize_response(response: Any) -> Any: if hasattr(response, "model_dump"): data = response.model_dump(mode="json") return { "__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}", "__data__": data, } elif hasattr(response, "__dict__"): return dict(response.__dict__) else: return response def _deserialize_response(data: dict[str, Any]) -> Any: # Check if this is a serialized Pydantic model with type information if isinstance(data, dict) and "__type__" in data and "__data__" in data: try: # Import the original class and reconstruct the object module_path, class_name = data["__type__"].rsplit(".", 1) module = __import__(module_path, fromlist=[class_name]) cls = getattr(module, class_name) if not hasattr(cls, "model_validate"): raise ValueError(f"Pydantic class {cls} does not support model_validate?") return cls.model_validate(data["__data__"]) except (ImportError, AttributeError, TypeError, ValueError) as e: logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}") return data["__data__"] return data class ResponseStorage: """Handles storage/retrieval for API recordings (inference and tools).""" def __init__(self, base_dir: Path): self.base_dir = base_dir def _get_test_dir(self) -> Path: """Get the recordings directory in the test file's parent directory. For test at "tests/integration/inference/test_foo.py::test_bar", returns "tests/integration/inference/recordings/". """ test_id = _test_context.get() if test_id: # Extract the directory path from the test nodeid # e.g., "tests/integration/inference/test_basic.py::test_foo[params]" # -> get "tests/integration/inference" test_file = test_id.split("::")[0] # Remove test function part test_dir = Path(test_file).parent # Get parent directory # Put recordings in a "recordings" subdirectory of the test's parent dir # e.g., "tests/integration/inference" -> "tests/integration/inference/recordings" return test_dir / "recordings" else: # Fallback for non-test contexts return self.base_dir / "recordings" def _ensure_directories(self) -> Path: test_dir = self._get_test_dir() test_dir.mkdir(parents=True, exist_ok=True) return test_dir def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]): """Store a request/response pair both in memory cache and on disk.""" global _memory_cache # Store in memory cache first _memory_cache[request_hash] = {"request": request, "response": response} responses_dir = self._ensure_directories() # Generate unique response filename using full hash response_file = f"{request_hash}.json" # Serialize response body if needed serialized_response = dict(response) if "body" in serialized_response: if isinstance(serialized_response["body"], list): # Handle streaming responses (list of chunks) serialized_response["body"] = [_serialize_response(chunk) for chunk in serialized_response["body"]] else: # Handle single response serialized_response["body"] = _serialize_response(serialized_response["body"]) # If this is a model-list endpoint recording, include models digest in filename to distinguish variants endpoint = request.get("endpoint") test_id = _test_context.get() if endpoint in ("/api/tags", "/v1/models"): test_id = None digest = _model_identifiers_digest(endpoint, response) response_file = f"models-{request_hash}-{digest}.json" response_path = responses_dir / response_file # Save response to JSON file with open(response_path, "w") as f: json.dump({"test_id": test_id, "request": request, "response": serialized_response}, f, indent=2) f.write("\n") f.flush() def find_recording(self, request_hash: str) -> dict[str, Any] | None: """Find a recorded response by request hash.""" response_file = f"{request_hash}.json" # Check test-specific directory first test_dir = self._get_test_dir() response_path = test_dir / response_file if response_path.exists(): return _recording_from_file(response_path) # Fallback to base recordings directory (for session-level recordings) fallback_dir = self.base_dir / "recordings" fallback_path = fallback_dir / response_file if fallback_path.exists(): return _recording_from_file(fallback_path) return None def _model_list_responses(self, request_hash: str) -> list[dict[str, Any]]: """Find all model-list recordings with the given hash (different digests).""" results: list[dict[str, Any]] = [] # Check test-specific directory first test_dir = self._get_test_dir() if test_dir.exists(): for path in test_dir.glob(f"models-{request_hash}-*.json"): data = _recording_from_file(path) results.append(data) # Also check fallback directory fallback_dir = self.base_dir / "recordings" if fallback_dir.exists(): for path in fallback_dir.glob(f"models-{request_hash}-*.json"): data = _recording_from_file(path) results.append(data) return results def _recording_from_file(response_path) -> dict[str, Any]: with open(response_path) as f: data = json.load(f) # Deserialize response body if needed if "response" in data and "body" in data["response"]: if isinstance(data["response"]["body"], list): # Handle streaming responses data["response"]["body"] = [_deserialize_response(chunk) for chunk in data["response"]["body"]] else: # Handle single response data["response"]["body"] = _deserialize_response(data["response"]["body"]) return cast(dict[str, Any], data) def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str: """Generate a digest from model identifiers for distinguishing different model sets.""" def _extract_model_identifiers(): """Extract a stable set of identifiers for model-list endpoints. Supported endpoints: - '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ] - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ] Returns a list of unique identifiers or None if structure doesn't match. """ if "models" in response["body"]: # ollama items = response["body"]["models"] else: # openai items = response["body"] idents = [m.model if endpoint == "/api/tags" else m.id for m in items] return sorted(set(idents)) identifiers = _extract_model_identifiers() return hashlib.sha256(("|".join(identifiers)).encode("utf-8")).hexdigest()[:8] def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]]) -> dict[str, Any] | None: """Return a single, unioned recording for supported model-list endpoints. Merges multiple recordings with different model sets (from different servers) into a single response containing all models. """ if not records: return None seen: dict[str, dict[str, Any]] = {} for rec in records: body = rec["response"]["body"] if endpoint == "/v1/models": for m in body: key = m.id seen[key] = m elif endpoint == "/api/tags": for m in body.models: key = m.model seen[key] = m ordered = [seen[k] for k in sorted(seen.keys())] canonical = records[0] canonical_req = canonical.get("request", {}) if isinstance(canonical_req, dict): canonical_req["endpoint"] = endpoint body = ordered if endpoint == "/api/tags": from ollama import ListResponse body = ListResponse(models=ordered) return {"request": canonical_req, "response": {"body": body, "is_streaming": False}} async def _patched_tool_invoke_method( original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any] ): """Patched version of tool runtime invoke_tool method for recording/replay.""" global _current_mode, _current_storage if _current_mode == APIRecordingMode.LIVE or _current_storage is None: # Normal operation return await original_method(self, tool_name, kwargs) # In server mode, sync test ID from provider_data to _test_context for storage operations test_context_token = _sync_test_context_from_provider_data() try: request_hash = normalize_tool_request(provider_name, tool_name, kwargs) if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING): recording = _current_storage.find_recording(request_hash) if recording: return recording["response"]["body"] elif _current_mode == APIRecordingMode.REPLAY: raise RuntimeError( f"No recorded tool result found for {provider_name}.{tool_name}\n" f"Request: {kwargs}\n" f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record" ) # If RECORD_IF_MISSING and no recording found, fall through to record if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING): # Check in-memory cache first (collision detection) global _memory_cache if request_hash in _memory_cache: # Return the cached response instead of making a new tool call return _memory_cache[request_hash]["response"]["body"] # No cached response, make the tool call and record it result = await original_method(self, tool_name, kwargs) request_data = { "provider": provider_name, "tool_name": tool_name, "kwargs": kwargs, } response_data = {"body": result, "is_streaming": False} # Store the recording (both in memory and on disk) _current_storage.store_recording(request_hash, request_data, response_data) return result else: raise AssertionError(f"Invalid mode: {_current_mode}") finally: # Reset test context if we set it in server mode if test_context_token is not None: _test_context.reset(test_context_token) def _sync_test_context_from_provider_data(): """In server mode, sync test ID from provider_data to _test_context. This ensures that storage operations (which read from _test_context) work correctly in server mode where the test ID arrives via HTTP header → provider_data. Returns a token to reset _test_context, or None if no sync was needed. """ stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client") if stack_config_type != "server": return None try: from llama_stack.core.request_headers import PROVIDER_DATA_VAR provider_data = PROVIDER_DATA_VAR.get() if provider_data and "__test_id" in provider_data: test_id = provider_data["__test_id"] return _test_context.set(test_id) except ImportError: pass return None async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs): global _current_mode, _current_storage if _current_mode == APIRecordingMode.LIVE or _current_storage is None: # Normal operation if client_type == "litellm": return await original_method(*args, **kwargs) else: return await original_method(self, *args, **kwargs) # In server mode, sync test ID from provider_data to _test_context for storage operations test_context_token = _sync_test_context_from_provider_data() try: # Get base URL based on client type if client_type == "openai": base_url = str(self._client.base_url) # the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN} elif client_type == "ollama": # Get base URL from the client (Ollama client uses host attribute) base_url = getattr(self, "host", "http://localhost:11434") if not base_url.startswith("http"): base_url = f"http://{base_url}" elif client_type == "litellm": # For LiteLLM, extract base URL from kwargs if available base_url = kwargs.get("api_base", "https://api.openai.com") else: raise ValueError(f"Unknown client type: {client_type}") url = base_url.rstrip("/") + endpoint method = "POST" headers = {} body = kwargs request_hash = normalize_request(method, url, headers, body) if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING): # Special handling for model-list endpoints: return union of all responses if endpoint in ("/api/tags", "/v1/models"): records = _current_storage._model_list_responses(request_hash) recording = _combine_model_list_responses(endpoint, records) else: recording = _current_storage.find_recording(request_hash) if recording: response_body = recording["response"]["body"] if recording["response"].get("is_streaming", False) or recording["response"].get("is_paginated", False): async def replay_stream(): for chunk in response_body: yield chunk return replay_stream() else: return response_body elif _current_mode == APIRecordingMode.REPLAY: raise RuntimeError( f"No recorded response found for request hash: {request_hash}\n" f"Request: {method} {url} {body}\n" f"Model: {body.get('model', 'unknown')}\n" f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record" ) # If RECORD_IF_MISSING and no recording found, fall through to record if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING): # Check in-memory cache first (collision detection) global _memory_cache if request_hash in _memory_cache: # Return the cached response instead of making a new API call cached_recording = _memory_cache[request_hash] response_body = cached_recording["response"]["body"] if cached_recording["response"].get("is_streaming", False) or cached_recording["response"].get( "is_paginated", False ): async def replay_cached_stream(): for chunk in response_body: yield chunk return replay_cached_stream() else: return response_body # No cached response, make the API call and record it if client_type == "litellm": response = await original_method(*args, **kwargs) else: response = await original_method(self, *args, **kwargs) request_data = { "method": method, "url": url, "headers": headers, "body": body, "endpoint": endpoint, "model": body.get("model", ""), } # Determine if this is a streaming request based on request parameters is_streaming = body.get("stream", False) # Special case: /v1/models is a paginated endpoint that returns an async iterator is_paginated = endpoint == "/v1/models" if is_streaming or is_paginated: # For streaming/paginated responses, we need to collect all chunks immediately before yielding # This ensures the recording is saved even if the generator isn't fully consumed chunks = [] async for chunk in response: chunks.append(chunk) # Store the recording immediately (both in memory and on disk) # For paginated endpoints, mark as paginated rather than streaming response_data = {"body": chunks, "is_streaming": is_streaming, "is_paginated": is_paginated} _current_storage.store_recording(request_hash, request_data, response_data) # Return a generator that replays the stored chunks async def replay_recorded_stream(): for chunk in chunks: yield chunk return replay_recorded_stream() else: response_data = {"body": response, "is_streaming": False} # Store the response (both in memory and on disk) _current_storage.store_recording(request_hash, request_data, response_data) return response else: raise AssertionError(f"Invalid mode: {_current_mode}") finally: # Reset test context if we set it in server mode if test_context_token is not None: _test_context.reset(test_context_token) def patch_api_clients(): """Install monkey patches for inference clients and tool runtime methods.""" global _original_methods import litellm from ollama import AsyncClient as OllamaAsyncClient from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions from openai.resources.completions import AsyncCompletions from openai.resources.embeddings import AsyncEmbeddings from openai.resources.models import AsyncModels from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin # Store original methods for OpenAI, Ollama, LiteLLM clients, and tool runtimes _original_methods = { "chat_completions_create": AsyncChatCompletions.create, "completions_create": AsyncCompletions.create, "embeddings_create": AsyncEmbeddings.create, "models_list": AsyncModels.list, "ollama_generate": OllamaAsyncClient.generate, "ollama_chat": OllamaAsyncClient.chat, "ollama_embed": OllamaAsyncClient.embed, "ollama_ps": OllamaAsyncClient.ps, "ollama_pull": OllamaAsyncClient.pull, "ollama_list": OllamaAsyncClient.list, "litellm_acompletion": litellm.acompletion, "litellm_atext_completion": litellm.atext_completion, "litellm_openai_mixin_get_api_key": LiteLLMOpenAIMixin.get_api_key, "tavily_invoke_tool": TavilySearchToolRuntimeImpl.invoke_tool, } # Create patched methods for OpenAI client async def patched_chat_completions_create(self, *args, **kwargs): return await _patched_inference_method( _original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs ) async def patched_completions_create(self, *args, **kwargs): return await _patched_inference_method( _original_methods["completions_create"], self, "openai", "/v1/completions", *args, **kwargs ) async def patched_embeddings_create(self, *args, **kwargs): return await _patched_inference_method( _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs ) def patched_models_list(self, *args, **kwargs): async def _iter(): result = await _patched_inference_method( _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs ) # The result is either an async generator (streaming/paginated) or a list # If it's an async generator, iterate through it if hasattr(result, "__aiter__"): async for item in result: yield item else: # It's a list, yield each item for item in result: yield item return _iter() # Apply OpenAI patches AsyncChatCompletions.create = patched_chat_completions_create AsyncCompletions.create = patched_completions_create AsyncEmbeddings.create = patched_embeddings_create AsyncModels.list = patched_models_list # Create patched methods for Ollama client async def patched_ollama_generate(self, *args, **kwargs): return await _patched_inference_method( _original_methods["ollama_generate"], self, "ollama", "/api/generate", *args, **kwargs ) async def patched_ollama_chat(self, *args, **kwargs): return await _patched_inference_method( _original_methods["ollama_chat"], self, "ollama", "/api/chat", *args, **kwargs ) async def patched_ollama_embed(self, *args, **kwargs): return await _patched_inference_method( _original_methods["ollama_embed"], self, "ollama", "/api/embeddings", *args, **kwargs ) async def patched_ollama_ps(self, *args, **kwargs): return await _patched_inference_method( _original_methods["ollama_ps"], self, "ollama", "/api/ps", *args, **kwargs ) async def patched_ollama_pull(self, *args, **kwargs): return await _patched_inference_method( _original_methods["ollama_pull"], self, "ollama", "/api/pull", *args, **kwargs ) async def patched_ollama_list(self, *args, **kwargs): return await _patched_inference_method( _original_methods["ollama_list"], self, "ollama", "/api/tags", *args, **kwargs ) # Apply Ollama patches OllamaAsyncClient.generate = patched_ollama_generate OllamaAsyncClient.chat = patched_ollama_chat OllamaAsyncClient.embed = patched_ollama_embed OllamaAsyncClient.ps = patched_ollama_ps OllamaAsyncClient.pull = patched_ollama_pull OllamaAsyncClient.list = patched_ollama_list # Create patched methods for LiteLLM async def patched_litellm_acompletion(*args, **kwargs): return await _patched_inference_method( _original_methods["litellm_acompletion"], None, "litellm", "/chat/completions", *args, **kwargs ) async def patched_litellm_atext_completion(*args, **kwargs): return await _patched_inference_method( _original_methods["litellm_atext_completion"], None, "litellm", "/completions", *args, **kwargs ) # Apply LiteLLM patches litellm.acompletion = patched_litellm_acompletion litellm.atext_completion = patched_litellm_atext_completion # Create patched method for LiteLLMOpenAIMixin.get_api_key def patched_litellm_get_api_key(self): global _current_mode if _current_mode != APIRecordingMode.REPLAY: return _original_methods["litellm_openai_mixin_get_api_key"](self) else: # For record/replay modes, return a fake API key to avoid exposing real credentials return "fake-api-key-for-testing" # Apply LiteLLMOpenAIMixin patch LiteLLMOpenAIMixin.get_api_key = patched_litellm_get_api_key # Create patched methods for tool runtimes async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any]): return await _patched_tool_invoke_method( _original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs ) # Apply tool runtime patches TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool def unpatch_api_clients(): """Remove monkey patches and restore original client methods and tool runtimes.""" global _original_methods, _memory_cache if not _original_methods: return # Import here to avoid circular imports import litellm from ollama import AsyncClient as OllamaAsyncClient from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions from openai.resources.completions import AsyncCompletions from openai.resources.embeddings import AsyncEmbeddings from openai.resources.models import AsyncModels from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin # Restore OpenAI client methods AsyncChatCompletions.create = _original_methods["chat_completions_create"] AsyncCompletions.create = _original_methods["completions_create"] AsyncEmbeddings.create = _original_methods["embeddings_create"] AsyncModels.list = _original_methods["models_list"] # Restore Ollama client methods if they were patched OllamaAsyncClient.generate = _original_methods["ollama_generate"] OllamaAsyncClient.chat = _original_methods["ollama_chat"] OllamaAsyncClient.embed = _original_methods["ollama_embed"] OllamaAsyncClient.ps = _original_methods["ollama_ps"] OllamaAsyncClient.pull = _original_methods["ollama_pull"] OllamaAsyncClient.list = _original_methods["ollama_list"] # Restore LiteLLM methods litellm.acompletion = _original_methods["litellm_acompletion"] litellm.atext_completion = _original_methods["litellm_atext_completion"] LiteLLMOpenAIMixin.get_api_key = _original_methods["litellm_openai_mixin_get_api_key"] # Restore tool runtime methods TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"] _original_methods.clear() # Clear memory cache to prevent memory leaks _memory_cache.clear() @contextmanager def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]: """Context manager for API recording/replaying (inference and tools).""" global _current_mode, _current_storage # Store previous state prev_mode = _current_mode prev_storage = _current_storage try: _current_mode = mode if mode in ["record", "replay", "record-if-missing"]: if storage_dir is None: raise ValueError("storage_dir is required for record, replay, and record-if-missing modes") _current_storage = ResponseStorage(Path(storage_dir)) patch_api_clients() yield finally: # Restore previous state if mode in ["record", "replay", "record-if-missing"]: unpatch_api_clients() _current_mode = prev_mode _current_storage = prev_storage