fake mode

# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-08-04 11:46:24 -07:00
parent 05cfa213b6
commit 12e46b7a4a
3 changed files with 484 additions and 8 deletions

View file

@ -17,6 +17,7 @@ from pathlib import Path
from typing import Any, Literal, cast
from llama_stack.log import get_logger
from llama_stack.testing.fake_responses import generate_fake_response, generate_fake_stream, parse_fake_config
logger = get_logger(__name__, category="testing")
@ -25,6 +26,7 @@ _current_mode: str | None = None
_current_storage: ResponseStorage | None = None
_original_methods: dict[str, Any] = {}
from openai.types.completion_choice import CompletionChoice
# update the "finish_reason" field, since its type definition is wrong (no None is accepted)
@ -36,6 +38,7 @@ class InferenceMode(StrEnum):
LIVE = "live"
RECORD = "record"
REPLAY = "replay"
FAKE = "fake"
def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
@ -52,20 +55,31 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
def get_inference_mode() -> InferenceMode:
return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower())
mode_str = os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
# Parse mode and config (e.g., "fake:response_length=50:latency_ms=100")
if ":" in mode_str:
return InferenceMode(mode_str.split(":")[0])
return InferenceMode(mode_str)
def setup_inference_recording():
"""
Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
Returns a context manager that can be used to record, replay, or fake inference requests. This is to be used in tests
to increase their reliability and reduce reliance on expensive, external services.
Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
Calls to the /models endpoint are not currently trapped. We probably need to add support for this.
Two environment variables are required:
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'.
- LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in.
Environment variables:
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'fake[:config]'.
For fake mode, configuration can be specified as 'fake:param1=value1:param2=value2'
- LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in (required for 'record' and 'replay' modes).
Configuration for 'fake' mode:
- response_length: Number of words in fake responses (default: 100)
- latency_ms: Simulated latency in milliseconds (default: 50)
Example: LLAMA_STACK_TEST_INFERENCE_MODE=fake:response_length=50:latency_ms=100
The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to
quickly find the correct recording for a given request. The JSON files are used to store the request and response
@ -74,11 +88,16 @@ def setup_inference_recording():
mode = get_inference_mode()
if mode not in InferenceMode:
raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")
raise ValueError(
f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', 'replay', or 'fake'"
)
if mode == InferenceMode.LIVE:
return None
if mode == InferenceMode.FAKE:
return inference_recording(mode=mode, storage_dir=None)
if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
storage_dir = os.environ["LLAMA_STACK_TEST_RECORDING_DIR"]
@ -220,7 +239,7 @@ class ResponseStorage:
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
global _current_mode, _current_storage
if _current_mode == InferenceMode.LIVE or _current_storage is None:
if _current_mode == InferenceMode.LIVE or (_current_mode != InferenceMode.FAKE and _current_storage is None):
# Normal operation
return await original_method(self, *args, **kwargs)
@ -266,6 +285,15 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
)
elif _current_mode == InferenceMode.FAKE:
fake_config = parse_fake_config()
fake_response = generate_fake_response(endpoint, body, fake_config)
if body.get("stream", False):
return generate_fake_stream(fake_response, endpoint, fake_config)
else:
return fake_response
elif _current_mode == InferenceMode.RECORD:
response = await original_method(self, *args, **kwargs)
@ -440,12 +468,15 @@ def inference_recording(mode: str = "live", storage_dir: str | Path | None = Non
if mode in ["record", "replay"]:
_current_storage = ResponseStorage(storage_dir_path)
patch_inference_clients()
elif mode == "fake":
_current_storage = None
patch_inference_clients()
yield
finally:
# Restore previous state
if mode in ["record", "replay"]:
if mode in ["record", "replay", "fake"]:
unpatch_inference_clients()
_current_mode = prev_mode