fake mode

# What does this PR do?

Adds a `fake` inference mode alongside the existing `live`, `record`, and `replay` test modes. When `LLAMA_STACK_TEST_INFERENCE_MODE=fake[:config]` is set, patched OpenAI/Ollama inference calls are answered with synthetic chat-completion responses (optionally streamed, with configurable response length and simulated latency) instead of contacting a real provider or replaying a recording.

## Test Plan
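
No test plan was written down here. As a rough sketch of how the new mode could be exercised (the recorder import path and the test body are assumptions for illustration, not part of this change):

```python
# Hypothetical pytest sketch; the module path and assertions are assumed, not from this PR.
import pytest

from llama_stack.testing.inference_recorder import setup_inference_recording


def test_fake_mode_returns_synthetic_responses(monkeypatch: pytest.MonkeyPatch) -> None:
    # Select fake mode with a short response and low simulated latency.
    monkeypatch.setenv("LLAMA_STACK_TEST_INFERENCE_MODE", "fake:response_length=25:latency_ms=10")

    recording = setup_inference_recording()
    assert recording is not None  # live mode returns None; fake mode returns a context manager

    with recording:
        # Any OpenAI/Ollama inference call made here is intercepted and answered synthetically.
        ...
```
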
Eric Huang 2025-08-04 11:46:24 -07:00
parent 05cfa213b6
commit 12e46b7a4a
3 changed files with 484 additions and 8 deletions

llama_stack/testing/fake_responses.py (new file):

@@ -0,0 +1,256 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Fake response generation for testing inference providers without making real API calls.
"""
import asyncio
import os
import time
from collections.abc import AsyncGenerator
from typing import Any
from openai.types.chat import ChatCompletion
from pydantic import BaseModel
class FakeConfig(BaseModel):
response_length: int = 100
latency_ms: int = 50
def parse_fake_config() -> FakeConfig:
"""Parse fake mode configuration from environment variable."""
mode_str = os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
config = {}
if ":" in mode_str:
parts = mode_str.split(":")
for part in parts[1:]:
if "=" in part:
key, value = part.split("=", 1)
config[key] = int(value)
return FakeConfig(**config)
def generate_fake_content(word_count: int) -> str:
"""Generate fake response content with specified word count."""
words = [
"This",
"is",
"a",
"synthetic",
"response",
"generated",
"for",
"testing",
"purposes",
"only",
"The",
"content",
"simulates",
"realistic",
"language",
"model",
"output",
"patterns",
"and",
"structures",
"It",
"includes",
"various",
"sentence",
"types",
"and",
"maintains",
"coherent",
"flow",
"throughout",
"These",
"responses",
"help",
"test",
"system",
"performance",
"without",
"requiring",
"real",
"model",
"calls",
]
return " ".join(words[i % len(words)] for i in range(word_count)) + "."
def generate_fake_chat_completion(body: dict[str, Any], config: FakeConfig) -> Any:
"""Generate fake OpenAI chat completion response."""
model = body.get("model", "gpt-3.5-turbo")
messages = body.get("messages", [])
# Calculate fake token counts based on input
prompt_tokens = 0
for msg in messages:
content = msg.get("content", "")
if isinstance(content, str):
prompt_tokens += len(content.split())
elif isinstance(content, list):
# Handle content arrays (images, etc.)
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
prompt_tokens += len(item.get("text", "").split())
response_length = config.response_length
fake_content = generate_fake_content(response_length)
completion_tokens = len(fake_content.split())
response_data = {
"id": f"chatcmpl-fake-{int(time.time() * 1000)}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": fake_content,
"function_call": None,
"tool_calls": None,
},
"finish_reason": "stop",
"logprobs": None,
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
"system_fingerprint": None,
}
time.sleep(config.latency_ms / 1000.0)
return ChatCompletion.model_validate(response_data)
def generate_fake_completion(body: dict[str, Any], config: FakeConfig) -> dict[str, Any]:
"""Generate fake OpenAI completion response."""
raise NotImplementedError("Fake completions not implemented yet")
def generate_fake_embeddings(body: dict[str, Any], config: FakeConfig) -> dict[str, Any]:
"""Generate fake OpenAI embeddings response."""
raise NotImplementedError("Fake embeddings not implemented yet")
def generate_fake_models_list(config: FakeConfig) -> dict[str, Any]:
"""Generate fake OpenAI models list response."""
raise NotImplementedError("Fake models list not implemented yet")
async def generate_fake_stream(
response_data: Any, endpoint: str, config: FakeConfig
) -> AsyncGenerator[dict[str, Any], None]:
"""Convert fake response to streaming chunks."""
latency_seconds = config.latency_ms / 1000.0
if endpoint == "/v1/chat/completions":
if hasattr(response_data, "choices"):
content = response_data.choices[0].message.content
chunk_id = response_data.id
model = response_data.model
else:
content = response_data["choices"][0]["message"]["content"]
chunk_id = response_data["id"]
model = response_data["model"]
words = content.split()
yield {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "",
"function_call": None,
"tool_calls": None,
},
"finish_reason": None,
"logprobs": None,
}
],
"system_fingerprint": None,
}
await asyncio.sleep(latency_seconds)
for i, word in enumerate(words):
chunk_content = word + (" " if i < len(words) - 1 else "")
yield {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": chunk_content,
"function_call": None,
"tool_calls": None,
},
"finish_reason": None,
"logprobs": None,
}
],
"system_fingerprint": None,
}
await asyncio.sleep(latency_seconds)
yield {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": None,
"function_call": None,
"tool_calls": None,
},
"finish_reason": "stop",
"logprobs": None,
}
],
"system_fingerprint": None,
}
elif endpoint == "/v1/completions":
raise NotImplementedError("Fake streaming completions not implemented yet")
def generate_fake_response(endpoint: str, body: dict[str, Any], config: FakeConfig) -> Any:
"""Generate fake responses based on endpoint and request."""
if endpoint == "/v1/chat/completions":
return generate_fake_chat_completion(body, config)
elif endpoint == "/v1/completions":
return generate_fake_completion(body, config)
elif endpoint == "/v1/embeddings":
return generate_fake_embeddings(body, config)
elif endpoint == "/v1/models":
return generate_fake_models_list(config)
else:
raise ValueError(f"Unsupported endpoint for fake mode: {endpoint}")

Inference recorder module (modified):

@@ -17,6 +17,7 @@ from pathlib import Path
from typing import Any, Literal, cast
from llama_stack.log import get_logger
+from llama_stack.testing.fake_responses import generate_fake_response, generate_fake_stream, parse_fake_config
logger = get_logger(__name__, category="testing")
@@ -25,6 +26,7 @@ _current_mode: str | None = None
_current_storage: ResponseStorage | None = None
_original_methods: dict[str, Any] = {}
from openai.types.completion_choice import CompletionChoice
# update the "finish_reason" field, since its type definition is wrong (no None is accepted)
@@ -36,6 +38,7 @@ class InferenceMode(StrEnum):
LIVE = "live"
RECORD = "record"
REPLAY = "replay"
+    FAKE = "fake"
def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
@@ -52,20 +55,31 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
def get_inference_mode() -> InferenceMode:
-    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower())
+    mode_str = os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
+    # Parse mode and config (e.g., "fake:response_length=50:latency_ms=100")
+    if ":" in mode_str:
+        return InferenceMode(mode_str.split(":")[0])
+    return InferenceMode(mode_str)
def setup_inference_recording():
"""
-    Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
+    Returns a context manager that can be used to record, replay, or fake inference requests. This is to be used in tests
to increase their reliability and reduce reliance on expensive, external services.
Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
Calls to the /models endpoint are not currently trapped. We probably need to add support for this.
-    Two environment variables are required:
-    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'.
-    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in.
+    Environment variables:
+    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'fake[:config]'.
+      For fake mode, configuration can be specified as 'fake:param1=value1:param2=value2'
+    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in (required for 'record' and 'replay' modes).
+    Configuration for 'fake' mode:
+    - response_length: Number of words in fake responses (default: 100)
+    - latency_ms: Simulated latency in milliseconds (default: 50)
+    Example: LLAMA_STACK_TEST_INFERENCE_MODE=fake:response_length=50:latency_ms=100
The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to
quickly find the correct recording for a given request. The JSON files are used to store the request and response
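
For reference, the split on the first ':' above pairs with parse_fake_config from the new module; a small illustration (the environment value shown is an example):

```python
# Illustration: how a combined mode string is interpreted under fake mode.
import os

os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "fake:response_length=50:latency_ms=100"

from llama_stack.testing.fake_responses import parse_fake_config

# get_inference_mode() above would yield InferenceMode.FAKE (everything before the first ':').
config = parse_fake_config()
assert config.response_length == 50
assert config.latency_ms == 100
```
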
@@ -74,11 +88,16 @@ def setup_inference_recording():
mode = get_inference_mode()
if mode not in InferenceMode:
-        raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")
+        raise ValueError(
+            f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', 'replay', or 'fake'"
+        )
if mode == InferenceMode.LIVE:
return None
+    if mode == InferenceMode.FAKE:
+        return inference_recording(mode=mode, storage_dir=None)
if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
storage_dir = os.environ["LLAMA_STACK_TEST_RECORDING_DIR"]
@@ -220,7 +239,7 @@ class ResponseStorage:
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
global _current_mode, _current_storage
-    if _current_mode == InferenceMode.LIVE or _current_storage is None:
+    if _current_mode == InferenceMode.LIVE or (_current_mode != InferenceMode.FAKE and _current_storage is None):
# Normal operation
return await original_method(self, *args, **kwargs)
@@ -266,6 +285,15 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
)
+    elif _current_mode == InferenceMode.FAKE:
+        fake_config = parse_fake_config()
+        fake_response = generate_fake_response(endpoint, body, fake_config)
+        if body.get("stream", False):
+            return generate_fake_stream(fake_response, endpoint, fake_config)
+        else:
+            return fake_response
elif _current_mode == InferenceMode.RECORD:
response = await original_method(self, *args, **kwargs)
@@ -440,12 +468,15 @@ def inference_recording(mode: str = "live", storage_dir: str | Path | None = None
if mode in ["record", "replay"]:
_current_storage = ResponseStorage(storage_dir_path)
patch_inference_clients()
+        elif mode == "fake":
+            _current_storage = None
+            patch_inference_clients()
yield
finally:
# Restore previous state
-        if mode in ["record", "replay"]:
+        if mode in ["record", "replay", "fake"]:
unpatch_inference_clients()
_current_mode = prev_mode
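
End to end, once the clients are patched in fake mode, an OpenAI client call never leaves the process. A rough sketch of what that enables (the recorder import path, base_url, api_key, and model name are placeholders, not from this PR):

```python
# End-to-end sketch under fake mode; the import path and client values are placeholders.
import asyncio

from openai import AsyncOpenAI

from llama_stack.testing.inference_recorder import inference_recording


async def main() -> None:
    with inference_recording(mode="fake"):
        client = AsyncOpenAI(base_url="http://localhost:9999/v1", api_key="unused")
        # The patched create() returns a synthetic ChatCompletion; no server is contacted.
        resp = await client.chat.completions.create(
            model="placeholder-model",
            messages=[{"role": "user", "content": "ping"}],
        )
        print(resp.usage.total_tokens, resp.choices[0].message.content[:40])


asyncio.run(main())
```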