fake mode
# What does this PR do?

## Test Plan
This commit is contained in: parent 05cfa213b6, commit 12e46b7a4a

3 changed files with 484 additions and 8 deletions
@@ -17,6 +17,7 @@ from pathlib import Path
 from typing import Any, Literal, cast
 
 from llama_stack.log import get_logger
+from llama_stack.testing.fake_responses import generate_fake_response, generate_fake_stream, parse_fake_config
 
 logger = get_logger(__name__, category="testing")
 
@@ -25,6 +26,7 @@ _current_mode: str | None = None
 _current_storage: ResponseStorage | None = None
 _original_methods: dict[str, Any] = {}
 
+
 from openai.types.completion_choice import CompletionChoice
 
 # update the "finish_reason" field, since its type definition is wrong (no None is accepted)
@@ -36,6 +38,7 @@ class InferenceMode(StrEnum):
     LIVE = "live"
     RECORD = "record"
     REPLAY = "replay"
+    FAKE = "fake"
 
 
 def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
@@ -52,20 +55,31 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
 
 
 def get_inference_mode() -> InferenceMode:
-    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower())
+    mode_str = os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
+    # Parse mode and config (e.g., "fake:response_length=50:latency_ms=100")
+    if ":" in mode_str:
+        return InferenceMode(mode_str.split(":")[0])
+    return InferenceMode(mode_str)
 
 
 def setup_inference_recording():
     """
-    Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
+    Returns a context manager that can be used to record, replay, or fake inference requests. This is to be used in tests
     to increase their reliability and reduce reliance on expensive, external services.
 
     Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
     Calls to the /models endpoint are not currently trapped. We probably need to add support for this.
 
-    Two environment variables are required:
-    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'.
-    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in.
+    Environment variables:
+    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'fake[:config]'.
+      For fake mode, configuration can be specified as 'fake:param1=value1:param2=value2'
+    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in (required for 'record' and 'replay' modes).
+
+    Configuration for 'fake' mode:
+    - response_length: Number of words in fake responses (default: 100)
+    - latency_ms: Simulated latency in milliseconds (default: 50)
+
+    Example: LLAMA_STACK_TEST_INFERENCE_MODE=fake:response_length=50:latency_ms=100
 
     The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to
     quickly find the correct recording for a given request. The JSON files are used to store the request and response
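`parse_fake_config` comes from the new `llama_stack.testing.fake_responses` module, which is one of the other changed files and is not rendered on this page. A minimal sketch of how the `fake:key=value` mode string could be parsed, assuming only the two documented parameters and the colon/equals syntax above (the helper name and shape here are hypothetical):

```python
import os

# Hypothetical sketch only; the real parse_fake_config() in
# llama_stack.testing.fake_responses may differ in shape and behavior.
FAKE_DEFAULTS = {"response_length": 100, "latency_ms": 50}  # defaults per the docstring above


def parse_fake_config_sketch(env_value: str | None = None) -> dict[str, int]:
    raw = env_value or os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live")
    config = dict(FAKE_DEFAULTS)
    parts = raw.lower().split(":")
    if parts[0] != "fake":
        return config  # non-fake modes carry no fake config
    for part in parts[1:]:
        key, _, value = part.partition("=")
        if key in config and value.isdigit():
            config[key] = int(value)
    return config


assert parse_fake_config_sketch("fake:response_length=50:latency_ms=100") == {
    "response_length": 50,
    "latency_ms": 100,
}
```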
@@ -74,11 +88,16 @@ def setup_inference_recording():
     mode = get_inference_mode()
 
     if mode not in InferenceMode:
-        raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")
+        raise ValueError(
+            f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', 'replay', or 'fake'"
+        )
 
     if mode == InferenceMode.LIVE:
         return None
 
+    if mode == InferenceMode.FAKE:
+        return inference_recording(mode=mode, storage_dir=None)
+
     if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
         raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
     storage_dir = os.environ["LLAMA_STACK_TEST_RECORDING_DIR"]
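For context, this entry point is driven entirely through the environment variable, and fake mode needs no recording directory. A hedged usage sketch, with the import path assumed rather than taken from this diff:

```python
import os

os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "fake:response_length=25:latency_ms=10"

# Module path assumed for illustration; setup_inference_recording() is the function patched above.
from llama_stack.testing.inference_recorder import setup_inference_recording


def with_fake_inference(run_test):
    recording_ctx = setup_inference_recording()
    if recording_ctx is None:  # 'live' mode: nothing is patched
        return run_test()
    with recording_ctx:  # 'fake' mode: clients are patched, no recording dir required
        return run_test()
```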
@@ -220,7 +239,7 @@ class ResponseStorage:
 async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
     global _current_mode, _current_storage
 
-    if _current_mode == InferenceMode.LIVE or _current_storage is None:
+    if _current_mode == InferenceMode.LIVE or (_current_mode != InferenceMode.FAKE and _current_storage is None):
         # Normal operation
         return await original_method(self, *args, **kwargs)
 
@@ -266,6 +285,15 @@ async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
                 f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
             )
 
+    elif _current_mode == InferenceMode.FAKE:
+        fake_config = parse_fake_config()
+        fake_response = generate_fake_response(endpoint, body, fake_config)
+
+        if body.get("stream", False):
+            return generate_fake_stream(fake_response, endpoint, fake_config)
+        else:
+            return fake_response
+
     elif _current_mode == InferenceMode.RECORD:
         response = await original_method(self, *args, **kwargs)
 
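`generate_fake_response` and `generate_fake_stream` also live in the new `fake_responses` module. Because the patched method returns the stream object without awaiting it, the stream is presumably an async generator; a hypothetical sketch under that assumption:

```python
import asyncio
from collections.abc import AsyncIterator
from typing import Any


# Hypothetical sketch; the real generate_fake_stream() likely yields OpenAI-style
# chunk objects for the given endpoint rather than plain dicts.
async def fake_stream_sketch(text: str, latency_ms: int = 50) -> AsyncIterator[dict[str, Any]]:
    for i, word in enumerate(text.split()):
        await asyncio.sleep(latency_ms / 1000)  # simulated per-chunk latency
        yield {"index": i, "delta": word + " "}
    yield {"finish_reason": "stop"}
```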
@@ -440,12 +468,15 @@ def inference_recording(mode: str = "live", storage_dir: str | Path | None = None
         if mode in ["record", "replay"]:
             _current_storage = ResponseStorage(storage_dir_path)
             patch_inference_clients()
+        elif mode == "fake":
+            _current_storage = None
+            patch_inference_clients()
 
         yield
 
     finally:
         # Restore previous state
-        if mode in ["record", "replay"]:
+        if mode in ["record", "replay", "fake"]:
             unpatch_inference_clients()
 
         _current_mode = prev_mode
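Since `inference_recording` is itself a context manager, fake mode can also be entered directly without any recording directory: clients are patched on enter and restored in the `finally` block. A minimal usage sketch (import path assumed):

```python
# Module path assumed for illustration.
from llama_stack.testing.inference_recorder import inference_recording


def test_fake_mode_round_trip():
    # No storage_dir: fake mode only patches the OpenAI/Ollama client methods.
    with inference_recording(mode="fake"):
        ...  # exercise code paths that call the patched inference clients
    # On exit, unpatch_inference_clients() has restored the original methods.
```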