feat(tests): make inference_recorder into api_recorder (include tool_invoke)

Ashwin Bharambe 2025-09-09 17:14:35 -07:00
parent dd1f946b3e
commit 8da9ff352c
45 changed files with 25144 additions and 41 deletions


@@ -321,10 +321,10 @@ async def construct_stack(
     run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
 ) -> dict[Api, Any]:
     if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
-        from llama_stack.testing.inference_recorder import setup_inference_recording
+        from llama_stack.testing.api_recorder import setup_api_recording

         global TEST_RECORDING_CONTEXT
-        TEST_RECORDING_CONTEXT = setup_inference_recording()
+        TEST_RECORDING_CONTEXT = setup_api_recording()
         if TEST_RECORDING_CONTEXT:
             TEST_RECORDING_CONTEXT.__enter__()
             logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")


@@ -19,7 +19,7 @@ from llama_stack.log import get_logger
 logger = get_logger(__name__, category="testing")

-# Global state for the recording system
+# Global state for the API recording system
 _current_mode: str | None = None
 _current_storage: ResponseStorage | None = None
 _original_methods: dict[str, Any] = {}
@@ -34,7 +34,7 @@ REPO_ROOT = Path(__file__).parent.parent.parent
 DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/recordings"


-class InferenceMode(StrEnum):
+class APIRecordingMode(StrEnum):
     LIVE = "live"
     RECORD = "record"
     REPLAY = "replay"
@@ -53,16 +53,27 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
     return hashlib.sha256(normalized_json.encode()).hexdigest()


-def get_inference_mode() -> InferenceMode:
-    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
+def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
+    """Create a normalized hash of the tool request for consistent matching."""
+    normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": kwargs}
+
+    # Create hash - sort_keys=True ensures deterministic ordering
+    normalized_json = json.dumps(normalized, sort_keys=True)
+    return hashlib.sha256(normalized_json.encode()).hexdigest()


-def setup_inference_recording():
+def get_api_recording_mode() -> APIRecordingMode:
+    return APIRecordingMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
+
+
+def setup_api_recording():
     """
-    Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
-    to increase their reliability and reduce reliance on expensive, external services.
+    Returns a context manager that can be used to record or replay API requests (inference and tools).
+    This is to be used in tests to increase their reliability and reduce reliance on expensive, external services.

-    Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
+    Currently supports:
+    - Inference: OpenAI, Ollama, and LiteLLM clients
+    - Tools: Search providers (Tavily for now)

     Two environment variables are supported:
     - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'. Default is 'replay'.
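
A quick illustration of the hashing introduced above (a sketch, assuming the module is importable as llama_stack.testing.api_recorder): json.dumps(..., sort_keys=True) canonicalizes key order, so equivalent tool requests always map to the same recording key:

    from llama_stack.testing.api_recorder import normalize_tool_request

    h1 = normalize_tool_request("tavily", "web_search", {"query": "llama", "max_results": 3})
    h2 = normalize_tool_request("tavily", "web_search", {"max_results": 3, "query": "llama"})
    assert h1 == h2  # kwargs ordering does not change the hash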
@@ -70,12 +81,12 @@ def setup_inference_recording():
     The recordings are stored as JSON files.
     """
-    mode = get_inference_mode()
-    if mode == InferenceMode.LIVE:
+    mode = get_api_recording_mode()
+    if mode == APIRecordingMode.LIVE:
         return None

     storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR)
-    return inference_recording(mode=mode, storage_dir=storage_dir)
+    return api_recording(mode=mode, storage_dir=storage_dir)


 def _serialize_response(response: Any) -> Any:
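
The contract above, shown as a hedged usage sketch (the directory value is illustrative): callers treat the return value as an optional context manager, with None meaning live mode:

    import os
    from llama_stack.testing.api_recorder import setup_api_recording

    os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "record"
    os.environ["LLAMA_STACK_TEST_RECORDING_DIR"] = "tests/integration/recordings"

    ctx = setup_api_recording()  # None when mode is "live"
    if ctx is not None:
        with ctx:
            ...  # inference and tool calls made here are recorded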
@@ -112,7 +123,7 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
 class ResponseStorage:
-    """Handles SQLite index + JSON file storage/retrieval for inference recordings."""
+    """Handles storage/retrieval for API recordings (inference and tools)."""

     def __init__(self, test_dir: Path):
         self.test_dir = test_dir
@@ -243,12 +254,55 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     return {"request": canonical_req, "response": {"body": body, "is_streaming": False}}


+async def _patched_tool_invoke_method(
+    original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
+):
+    """Patched version of tool runtime invoke_tool method for recording/replay."""
+    global _current_mode, _current_storage
+
+    if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
+        # Normal operation
+        return await original_method(self, tool_name, kwargs)
+
+    request_hash = normalize_tool_request(provider_name, tool_name, kwargs)
+
+    if _current_mode == APIRecordingMode.REPLAY:
+        recording = _current_storage.find_recording(request_hash)
+        if recording:
+            return recording["response"]["body"]
+        else:
+            raise RuntimeError(
+                f"No recorded tool result found for {provider_name}.{tool_name}\n"
+                f"Request: {kwargs}\n"
+                f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
+            )
+    elif _current_mode == APIRecordingMode.RECORD:
+        result = await original_method(self, tool_name, kwargs)
+
+        request_data = {
+            "provider": provider_name,
+            "tool_name": tool_name,
+            "kwargs": kwargs,
+        }
+        response_data = {"body": result, "is_streaming": False}
+        _current_storage.store_recording(request_hash, request_data, response_data)
+        return result
+    else:
+        raise AssertionError(f"Invalid mode: {_current_mode}")
+
+
 async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
     global _current_mode, _current_storage

-    if _current_mode == InferenceMode.LIVE or _current_storage is None:
+    if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
         # Normal operation
-        return await original_method(self, *args, **kwargs)
+        if client_type == "litellm":
+            return await original_method(*args, **kwargs)
+        else:
+            return await original_method(self, *args, **kwargs)

     # Get base URL based on client type
     if client_type == "openai":
@@ -258,6 +312,9 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         base_url = getattr(self, "host", "http://localhost:11434")
         if not base_url.startswith("http"):
             base_url = f"http://{base_url}"
+    elif client_type == "litellm":
+        # For LiteLLM, extract base URL from kwargs if available
+        base_url = kwargs.get("api_base", "https://api.openai.com")
     else:
         raise ValueError(f"Unknown client type: {client_type}")
@@ -268,7 +325,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     request_hash = normalize_request(method, url, headers, body)

-    if _current_mode == InferenceMode.REPLAY:
+    if _current_mode == APIRecordingMode.REPLAY:
         # Special handling for model-list endpoints: return union of all responses
         if endpoint in ("/api/tags", "/v1/models"):
             records = _current_storage._model_list_responses(request_hash[:12])
@@ -295,8 +352,11 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
                 f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
             )
-    elif _current_mode == InferenceMode.RECORD:
-        response = await original_method(self, *args, **kwargs)
+    elif _current_mode == APIRecordingMode.RECORD:
+        if client_type == "litellm":
+            response = await original_method(*args, **kwargs)
+        else:
+            response = await original_method(self, *args, **kwargs)

         request_data = {
             "method": method,
@@ -336,17 +396,21 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         raise AssertionError(f"Invalid mode: {_current_mode}")


-def patch_inference_clients():
-    """Install monkey patches for OpenAI client methods and Ollama AsyncClient methods."""
+def patch_api_clients():
+    """Install monkey patches for inference clients and tool runtime methods."""
     global _original_methods

+    import litellm
     from ollama import AsyncClient as OllamaAsyncClient
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
     from openai.resources.models import AsyncModels

-    # Store original methods for both OpenAI and Ollama clients
+    from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
+    from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+
+    # Store original methods for OpenAI, Ollama, LiteLLM clients, and tool runtimes
     _original_methods = {
         "chat_completions_create": AsyncChatCompletions.create,
         "completions_create": AsyncCompletions.create,
@@ -358,6 +422,10 @@ def patch_inference_clients():
         "ollama_ps": OllamaAsyncClient.ps,
         "ollama_pull": OllamaAsyncClient.pull,
         "ollama_list": OllamaAsyncClient.list,
+        "litellm_acompletion": litellm.acompletion,
+        "litellm_atext_completion": litellm.atext_completion,
+        "litellm_openai_mixin_get_api_key": LiteLLMOpenAIMixin.get_api_key,
+        "tavily_invoke_tool": TavilySearchToolRuntimeImpl.invoke_tool,
     }

     # Create patched methods for OpenAI client
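
The dict above follows the usual save-then-patch idiom: capture a reference to each original before installing wrappers, so teardown can restore it. A generic, self-contained sketch with hypothetical names:

    class SomeClient:  # stand-in for the OpenAI/Ollama/LiteLLM attributes above
        async def method(self):
            return "real"

    _originals = {"method": SomeClient.method}

    async def patched(self):
        return await _originals["method"](self)  # wrapper delegates to the original

    SomeClient.method = patched                  # install
    SomeClient.method = _originals["method"]     # teardown restores behavior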
@@ -426,21 +494,61 @@ def patch_inference_clients():
     OllamaAsyncClient.pull = patched_ollama_pull
     OllamaAsyncClient.list = patched_ollama_list

+    # Create patched methods for LiteLLM
+    async def patched_litellm_acompletion(*args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["litellm_acompletion"], None, "litellm", "/chat/completions", *args, **kwargs
+        )

-def unpatch_inference_clients():
-    """Remove monkey patches and restore original OpenAI and Ollama client methods."""
+    async def patched_litellm_atext_completion(*args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["litellm_atext_completion"], None, "litellm", "/completions", *args, **kwargs
+        )
+
+    # Apply LiteLLM patches
+    litellm.acompletion = patched_litellm_acompletion
+    litellm.atext_completion = patched_litellm_atext_completion
+
+    # Create patched method for LiteLLMOpenAIMixin.get_api_key
+    def patched_litellm_get_api_key(self):
+        global _current_mode
+        if _current_mode != APIRecordingMode.REPLAY:
+            return _original_methods["litellm_openai_mixin_get_api_key"](self)
+        else:
+            # In replay mode, return a fake API key so tests never need real credentials
+            return "fake-api-key-for-testing"
+
+    # Apply LiteLLMOpenAIMixin patch
+    LiteLLMOpenAIMixin.get_api_key = patched_litellm_get_api_key
+
+    # Create patched methods for tool runtimes
+    async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any]):
+        return await _patched_tool_invoke_method(
+            _original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs
+        )
+
+    # Apply tool runtime patches
+    TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool
+
+
+def unpatch_api_clients():
+    """Remove monkey patches and restore original client methods and tool runtimes."""
     global _original_methods

     if not _original_methods:
         return

     # Import here to avoid circular imports
+    import litellm
     from ollama import AsyncClient as OllamaAsyncClient
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
     from openai.resources.models import AsyncModels

+    from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
+    from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+
     # Restore OpenAI client methods
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
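
A hedged sanity check of the patch/unpatch pair (illustrative; assumes litellm and the provider modules are installed):

    import litellm
    from llama_stack.testing.api_recorder import patch_api_clients, unpatch_api_clients

    orig = litellm.acompletion
    patch_api_clients()
    assert litellm.acompletion is not orig  # wrapper installed
    unpatch_api_clients()
    assert litellm.acompletion is orig      # original restored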
@@ -455,12 +563,20 @@ def unpatch_inference_clients():
     OllamaAsyncClient.pull = _original_methods["ollama_pull"]
     OllamaAsyncClient.list = _original_methods["ollama_list"]

+    # Restore LiteLLM methods
+    litellm.acompletion = _original_methods["litellm_acompletion"]
+    litellm.atext_completion = _original_methods["litellm_atext_completion"]
+    LiteLLMOpenAIMixin.get_api_key = _original_methods["litellm_openai_mixin_get_api_key"]
+
+    # Restore tool runtime methods
+    TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"]
+
     _original_methods.clear()


 @contextmanager
-def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
-    """Context manager for inference recording/replaying."""
+def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
+    """Context manager for API recording/replaying (inference and tools)."""
     global _current_mode, _current_storage

     # Store previous state
@@ -474,14 +590,14 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
             if storage_dir is None:
                 raise ValueError("storage_dir is required for record and replay modes")
             _current_storage = ResponseStorage(Path(storage_dir))
-            patch_inference_clients()
+            patch_api_clients()

         yield
     finally:
         # Restore previous state
         if mode in ["record", "replay"]:
-            unpatch_inference_clients()
+            unpatch_api_clients()

         _current_mode = prev_mode
         _current_storage = prev_storage
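
Putting it together, the renamed context manager can also be used directly (a sketch; the recordings path is illustrative):

    from llama_stack.testing.api_recorder import api_recording

    # Replay previously recorded inference and tool responses; a request with
    # no recording raises with instructions to re-run in record mode.
    with api_recording(mode="replay", storage_dir="tests/integration/recordings"):
        ...  # test body: chat completions, tool invocations, etc.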