diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index 63345e04a..1dbcbb7fa 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -308,14 +308,14 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
 async def construct_stack(
     run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
 ) -> dict[Api, Any]:
-    inference_mode = os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
-    if inference_mode in ["record", "replay"]:
-        global TEST_RECORDING_CONTEXT
+    if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
         from llama_stack.testing.inference_recorder import setup_inference_recording

+        global TEST_RECORDING_CONTEXT
         TEST_RECORDING_CONTEXT = setup_inference_recording()
-        TEST_RECORDING_CONTEXT.__enter__()
-        logger.info(f"Inference recording enabled: mode={inference_mode}")
+        if TEST_RECORDING_CONTEXT:
+            TEST_RECORDING_CONTEXT.__enter__()
+            logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")

     dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
     policy = run_config.server.auth.access_policy if run_config.server.auth else []
diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py
index f3e79605d..abfefa0ce 100644
--- a/llama_stack/testing/inference_recorder.py
+++ b/llama_stack/testing/inference_recorder.py
@@ -12,6 +12,7 @@ import os
 import sqlite3
 from collections.abc import Generator
 from contextlib import contextmanager
+from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast

@@ -31,6 +32,12 @@ CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "len
 CompletionChoice.model_rebuild()


+class InferenceMode(StrEnum):
+    LIVE = "live"
+    RECORD = "record"
+    REPLAY = "replay"
+
+
 def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
     """Create a normalized hash of the request for consistent matching."""
     # Extract just the endpoint path
@@ -44,23 +51,33 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
     return hashlib.sha256(normalized_json.encode()).hexdigest()


-def get_inference_mode() -> str:
-    return os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
+def get_inference_mode() -> InferenceMode:
+    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower())


 def setup_inference_recording():
+    """
+    Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
+    to increase their reliability and reduce reliance on expensive, external services.
+
+    Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
+    Calls to the /models endpoint are not currently trapped. We probably need to add support for this.
+
+    Two environment variables are required:
+    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'.
+    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in.
+
+    The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to
+    quickly find the correct recording for a given request. The JSON files are used to store the request and response
+    bodies.
+    """
     mode = get_inference_mode()
-    if mode not in ["live", "record", "replay"]:
+    if mode not in InferenceMode:
         raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")

-    if mode == "live":
-        # Return a no-op context manager for live mode
-        @contextmanager
-        def live_mode():
-            yield
-
-        return live_mode()
+    if mode == InferenceMode.LIVE:
+        return None

     if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
         raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
@@ -203,7 +220,7 @@ class ResponseStorage:
 async def _patched_inference_method(original_method, self, client_type, method_name=None, *args, **kwargs):
     global _current_mode, _current_storage

-    if _current_mode == "live" or _current_storage is None:
+    if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
         return await original_method(self, *args, **kwargs)

@@ -261,7 +278,7 @@ async def _patched_inference_method(original_method, self, client_type, method_n
     request_hash = normalize_request(method, url, headers, body)

-    if _current_mode == "replay":
+    if _current_mode == InferenceMode.REPLAY:
         recording = _current_storage.find_recording(request_hash)
         if recording:
             response_body = recording["response"]["body"]
@@ -283,7 +300,7 @@ async def _patched_inference_method(original_method, self, client_type, method_n
                 f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
             )

-    elif _current_mode == "record":
+    elif _current_mode == InferenceMode.RECORD:
         response = await original_method(self, *args, **kwargs)

         request_data = {
diff --git a/llama_stack/ui/package.json b/llama_stack/ui/package.json
index 4ca94a64e..742c6f7c7 100644
--- a/llama_stack/ui/package.json
+++ b/llama_stack/ui/package.json
@@ -20,7 +20,7 @@
     "@radix-ui/react-tooltip": "^1.2.6",
     "class-variance-authority": "^0.7.1",
     "clsx": "^2.1.1",
-    "llama-stack-client": ""0.2.16",
+    "llama-stack-client": "0.2.16",
     "lucide-react": "^0.510.0",
     "next": "15.3.3",
     "next-auth": "^4.24.11",
diff --git a/tests/integration/test_inference_recordings.py b/tests/unit/distribution/test_inference_recordings.py
similarity index 94%
rename from tests/integration/test_inference_recordings.py
rename to tests/unit/distribution/test_inference_recordings.py
index fe3e61858..1dbd14540 100644
--- a/tests/integration/test_inference_recordings.py
+++ b/tests/unit/distribution/test_inference_recordings.py
@@ -22,6 +22,7 @@ from llama_stack.apis.inference import (
     OpenAIEmbeddingUsage,
 )
 from llama_stack.testing.inference_recorder import (
+    InferenceMode,
     ResponseStorage,
     inference_recording,
     normalize_request,
@@ -169,7 +170,7 @@ class TestInferenceRecording:
         temp_storage_dir = temp_storage_dir / "test_recording_mode"

         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode="record", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.chat.completions.create(
@@ -198,7 +199,7 @@ class TestInferenceRecording:
         temp_storage_dir = temp_storage_dir / "test_replay_mode"
         # First, record a response
         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode="record", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.chat.completions.create(
@@ -210,7 +211,7 @@ class TestInferenceRecording:

         # Now test replay mode - should not call the original method
         with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode="replay", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.chat.completions.create(
@@ -230,7 +231,7 @@ class TestInferenceRecording:
         """Test that replay mode fails when no recording is found."""
         temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
         with patch("openai.resources.chat.completions.AsyncCompletions.create"):
-            with inference_recording(mode="replay", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 with pytest.raises(RuntimeError, match="No recorded response found"):
@@ -247,7 +248,7 @@ class TestInferenceRecording:
         temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
         # Record
         with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
-            with inference_recording(mode="record", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.embeddings.create(
@@ -258,7 +259,7 @@ class TestInferenceRecording:

         # Replay
         with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
-            with inference_recording(mode="replay", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.embeddings.create(
@@ -279,7 +280,7 @@ class TestInferenceRecording:
             return real_openai_chat_response

         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode="live"):
+            with inference_recording(mode=InferenceMode.LIVE):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.chat.completions.create(
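
For reviewers, a rough usage sketch of the recorder outside the unit tests, mirroring the patched calls above. The storage path, model name, and local Ollama-style endpoint are illustrative assumptions, not part of this patch.

```python
# Minimal sketch: record a chat completion once, then replay it offline.
# Assumes an OpenAI-compatible server (e.g. Ollama) at localhost:11434;
# "/tmp/inference-recordings" and "llama3.2:3b" are hypothetical.
import asyncio

from openai import AsyncOpenAI

from llama_stack.testing.inference_recorder import InferenceMode, inference_recording


async def main() -> None:
    storage_dir = "/tmp/inference-recordings"  # hypothetical path
    request = dict(
        model="llama3.2:3b",  # illustrative model name
        messages=[{"role": "user", "content": "Hello"}],
    )

    # First pass: hit the live endpoint and persist the request/response pair
    # (SQLite index plus a JSON file per request, per the docstring above).
    with inference_recording(mode=InferenceMode.RECORD, storage_dir=storage_dir):
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        await client.chat.completions.create(**request)

    # Second pass: the identical request is served from the recording, so no
    # call reaches the provider.
    with inference_recording(mode=InferenceMode.REPLAY, storage_dir=storage_dir):
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        await client.chat.completions.create(**request)


if __name__ == "__main__":
    asyncio.run(main())
```

In the server path (construct_stack in stack.py), the same machinery is driven purely by LLAMA_STACK_TEST_INFERENCE_MODE and LLAMA_STACK_TEST_RECORDING_DIR, so a test run can switch between live, record, and replay without code changes.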