fake mode

# What does this PR do?

Adds a `fake` inference mode to the test recording/replay harness. With `LLAMA_STACK_TEST_INFERENCE_MODE=fake` (optionally configured, e.g. `fake:response_length=50:latency_ms=100`), the patched inference clients return synthetic, OpenAI-style chat completion responses (including streaming chunks) instead of calling real providers, and no recording directory is required. Fake completions, embeddings, and model listings are left as `NotImplementedError` stubs for now.
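
For reference, a minimal sketch of enabling the mode from a test. The recorder module's import path is not shown in this diff, so `llama_stack.testing.inference_recorder` below is a placeholder assumption; the environment variable and helpers are the ones added here.

```python
import os

# Placeholder import path for the recording module patched in this PR;
# only llama_stack.testing.fake_responses is named explicitly in the diff.
from llama_stack.testing.inference_recorder import setup_inference_recording

# Mode and per-request config are parsed from this variable (see parse_fake_config()).
os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "fake:response_length=50:latency_ms=100"

# In fake mode, setup_inference_recording() returns a context manager that patches
# the inference clients; inside it, chat completion calls return synthetic responses.
with setup_inference_recording():
    ...
```
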
## Test Plan

Unit tests covering `generate_fake_response` and `generate_fake_stream` are added (see the new test file at the end of this diff).

Eric Huang 2025-08-04 11:46:24 -07:00
parent 05cfa213b6
commit 12e46b7a4a
3 changed files with 484 additions and 8 deletions


@@ -0,0 +1,256 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Fake response generation for testing inference providers without making real API calls.
"""
import asyncio
import os
import time
from collections.abc import AsyncGenerator
from typing import Any
from openai.types.chat import ChatCompletion
from pydantic import BaseModel
class FakeConfig(BaseModel):
response_length: int = 100
latency_ms: int = 50
def parse_fake_config() -> FakeConfig:
"""Parse fake mode configuration from environment variable."""
mode_str = os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
config = {}
if ":" in mode_str:
parts = mode_str.split(":")
for part in parts[1:]:
if "=" in part:
key, value = part.split("=", 1)
config[key] = int(value)
return FakeConfig(**config)
def generate_fake_content(word_count: int) -> str:
"""Generate fake response content with specified word count."""
words = [
"This",
"is",
"a",
"synthetic",
"response",
"generated",
"for",
"testing",
"purposes",
"only",
"The",
"content",
"simulates",
"realistic",
"language",
"model",
"output",
"patterns",
"and",
"structures",
"It",
"includes",
"various",
"sentence",
"types",
"and",
"maintains",
"coherent",
"flow",
"throughout",
"These",
"responses",
"help",
"test",
"system",
"performance",
"without",
"requiring",
"real",
"model",
"calls",
]
return " ".join(words[i % len(words)] for i in range(word_count)) + "."
def generate_fake_chat_completion(body: dict[str, Any], config: FakeConfig) -> Any:
"""Generate fake OpenAI chat completion response."""
model = body.get("model", "gpt-3.5-turbo")
messages = body.get("messages", [])
# Calculate fake token counts based on input
prompt_tokens = 0
for msg in messages:
content = msg.get("content", "")
if isinstance(content, str):
prompt_tokens += len(content.split())
elif isinstance(content, list):
# Handle content arrays (images, etc.)
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
prompt_tokens += len(item.get("text", "").split())
response_length = config.response_length
fake_content = generate_fake_content(response_length)
completion_tokens = len(fake_content.split())
response_data = {
"id": f"chatcmpl-fake-{int(time.time() * 1000)}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": fake_content,
"function_call": None,
"tool_calls": None,
},
"finish_reason": "stop",
"logprobs": None,
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
"system_fingerprint": None,
}
time.sleep(config.latency_ms / 1000.0)
return ChatCompletion.model_validate(response_data)
def generate_fake_completion(body: dict[str, Any], config: FakeConfig) -> dict[str, Any]:
"""Generate fake OpenAI completion response."""
raise NotImplementedError("Fake completions not implemented yet")
def generate_fake_embeddings(body: dict[str, Any], config: FakeConfig) -> dict[str, Any]:
"""Generate fake OpenAI embeddings response."""
raise NotImplementedError("Fake embeddings not implemented yet")
def generate_fake_models_list(config: FakeConfig) -> dict[str, Any]:
"""Generate fake OpenAI models list response."""
raise NotImplementedError("Fake models list not implemented yet")
async def generate_fake_stream(
response_data: Any, endpoint: str, config: FakeConfig
) -> AsyncGenerator[dict[str, Any], None]:
"""Convert fake response to streaming chunks."""
latency_seconds = config.latency_ms / 1000.0
if endpoint == "/v1/chat/completions":
if hasattr(response_data, "choices"):
content = response_data.choices[0].message.content
chunk_id = response_data.id
model = response_data.model
else:
content = response_data["choices"][0]["message"]["content"]
chunk_id = response_data["id"]
model = response_data["model"]
words = content.split()
yield {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "",
"function_call": None,
"tool_calls": None,
},
"finish_reason": None,
"logprobs": None,
}
],
"system_fingerprint": None,
}
await asyncio.sleep(latency_seconds)
for i, word in enumerate(words):
chunk_content = word + (" " if i < len(words) - 1 else "")
yield {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": chunk_content,
"function_call": None,
"tool_calls": None,
},
"finish_reason": None,
"logprobs": None,
}
],
"system_fingerprint": None,
}
await asyncio.sleep(latency_seconds)
yield {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": None,
"function_call": None,
"tool_calls": None,
},
"finish_reason": "stop",
"logprobs": None,
}
],
"system_fingerprint": None,
}
elif endpoint == "/v1/completions":
raise NotImplementedError("Fake streaming completions not implemented yet")
def generate_fake_response(endpoint: str, body: dict[str, Any], config: FakeConfig) -> Any:
"""Generate fake responses based on endpoint and request."""
if endpoint == "/v1/chat/completions":
return generate_fake_chat_completion(body, config)
elif endpoint == "/v1/completions":
return generate_fake_completion(body, config)
elif endpoint == "/v1/embeddings":
return generate_fake_embeddings(body, config)
elif endpoint == "/v1/models":
return generate_fake_models_list(config)
else:
raise ValueError(f"Unsupported endpoint for fake mode: {endpoint}")


@@ -17,6 +17,7 @@ from pathlib import Path
from typing import Any, Literal, cast
from llama_stack.log import get_logger
from llama_stack.testing.fake_responses import generate_fake_response, generate_fake_stream, parse_fake_config
logger = get_logger(__name__, category="testing")
@@ -25,6 +26,7 @@ _current_mode: str | None = None
_current_storage: ResponseStorage | None = None
_original_methods: dict[str, Any] = {}
from openai.types.completion_choice import CompletionChoice
# update the "finish_reason" field, since its type definition is wrong (no None is accepted)
@@ -36,6 +38,7 @@ class InferenceMode(StrEnum):
LIVE = "live"
RECORD = "record"
REPLAY = "replay"
FAKE = "fake"
def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
@@ -52,20 +55,31 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
def get_inference_mode() -> InferenceMode:
mode_str = os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
# Parse mode and config (e.g., "fake:response_length=50:latency_ms=100")
if ":" in mode_str:
return InferenceMode(mode_str.split(":")[0])
return InferenceMode(mode_str)
def setup_inference_recording():
"""
Returns a context manager that can be used to record, replay, or fake inference requests. This is to be used in tests
to increase their reliability and reduce reliance on expensive, external services.
Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
Calls to the /models endpoint are not currently trapped. We probably need to add support for this.
Environment variables:
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'fake[:config]'.
For fake mode, configuration can be specified as 'fake:param1=value1:param2=value2'
- LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in (required for 'record' and 'replay' modes).
Configuration for 'fake' mode:
- response_length: Number of words in fake responses (default: 100)
- latency_ms: Simulated latency in milliseconds (default: 50)
Example: LLAMA_STACK_TEST_INFERENCE_MODE=fake:response_length=50:latency_ms=100
The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to
quickly find the correct recording for a given request. The JSON files are used to store the request and response
@@ -74,11 +88 @@ def setup_inference_recording():
mode = get_inference_mode()
if mode not in InferenceMode:
raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")
raise ValueError(
f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', 'replay', or 'fake'"
)
if mode == InferenceMode.LIVE:
return None
if mode == InferenceMode.FAKE:
return inference_recording(mode=mode, storage_dir=None)
if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
storage_dir = os.environ["LLAMA_STACK_TEST_RECORDING_DIR"]
@@ -220,7 +239,7 @@ class ResponseStorage:
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
global _current_mode, _current_storage
if _current_mode == InferenceMode.LIVE or (_current_mode != InferenceMode.FAKE and _current_storage is None):
# Normal operation
return await original_method(self, *args, **kwargs)
@@ -266,6 +285,15 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
)
elif _current_mode == InferenceMode.FAKE:
fake_config = parse_fake_config()
fake_response = generate_fake_response(endpoint, body, fake_config)
if body.get("stream", False):
return generate_fake_stream(fake_response, endpoint, fake_config)
else:
return fake_response
elif _current_mode == InferenceMode.RECORD:
response = await original_method(self, *args, **kwargs)
@@ -440,12 +468,15 @@ def inference_recording(mode: str = "live", storage_dir: str | Path | None = Non
if mode in ["record", "replay"]:
_current_storage = ResponseStorage(storage_dir_path)
patch_inference_clients()
elif mode == "fake":
_current_storage = None
patch_inference_clients()
yield
finally:
# Restore previous state
if mode in ["record", "replay"]:
if mode in ["record", "replay", "fake"]:
unpatch_inference_clients()
_current_mode = prev_mode
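
To show how the fake branch above is reached end to end, a hedged test sketch: the recorder's import path is a placeholder (not visible in this diff), and it assumes the patched clients include the OpenAI `AsyncOpenAI` chat completions method, as the docstring above indicates.

```python
import pytest
from openai import AsyncOpenAI

# Placeholder module path for the recorder patched in this PR.
from llama_stack.testing.inference_recorder import inference_recording


@pytest.mark.asyncio
async def test_chat_completion_uses_fake_backend(monkeypatch):
    # parse_fake_config() reads the config suffix of this variable.
    monkeypatch.setenv("LLAMA_STACK_TEST_INFERENCE_MODE", "fake:response_length=25:latency_ms=5")

    with inference_recording(mode="fake"):
        client = AsyncOpenAI(api_key="test-key", base_url="http://localhost:1")  # never contacted
        completion = await client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hello"}],
        )

    assert completion.choices[0].finish_reason == "stop"
    assert len(completion.choices[0].message.content.split()) == 25
```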


@@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.testing.fake_responses import FakeConfig, generate_fake_response, generate_fake_stream
class TestGenerateFakeResponse:
"""Test cases for generate_fake_response function."""
def test_chat_completions_basic(self):
"""Test basic chat completions generation."""
endpoint = "/v1/chat/completions"
body = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, how are you?"}]}
config = FakeConfig(response_length=10, latency_ms=50)
response = generate_fake_response(endpoint, body, config)
# Check response structure
if hasattr(response, "id"):
# OpenAI object format
assert response.id.startswith("chatcmpl-fake-")
assert response.object == "chat.completion"
assert response.model == "gpt-3.5-turbo"
assert len(response.choices) == 1
assert response.choices[0].message.role == "assistant"
assert response.choices[0].message.content is not None
assert len(response.choices[0].message.content.split()) == 10
assert response.usage.total_tokens > 0
else:
# Dict format fallback
assert response["id"].startswith("chatcmpl-fake-")
assert response["object"] == "chat.completion"
assert response["model"] == "gpt-3.5-turbo"
assert len(response["choices"]) == 1
assert response["choices"][0]["message"]["role"] == "assistant"
assert response["choices"][0]["message"]["content"] is not None
assert len(response["choices"][0]["message"]["content"].split()) == 10
assert response["usage"]["total_tokens"] > 0
def test_chat_completions_custom_model(self):
"""Test chat completions with custom model name."""
endpoint = "/v1/chat/completions"
body = {"model": "custom-model-name", "messages": [{"role": "user", "content": "Test message"}]}
config = FakeConfig(response_length=5, latency_ms=10)
response = generate_fake_response(endpoint, body, config)
# Check model name is preserved
if hasattr(response, "model"):
assert response.model == "custom-model-name"
else:
assert response["model"] == "custom-model-name"
def test_chat_completions_multiple_messages(self):
"""Test chat completions with multiple input messages."""
endpoint = "/v1/chat/completions"
body = {
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you doing today?"},
],
}
config = FakeConfig(response_length=15, latency_ms=25)
response = generate_fake_response(endpoint, body, config)
# Check token calculation includes all messages
if hasattr(response, "usage"):
assert response.usage.prompt_tokens > 0 # Should count all input messages
assert response.usage.completion_tokens == 15
else:
assert response["usage"]["prompt_tokens"] > 0
assert response["usage"]["completion_tokens"] == 15
def test_completions_not_implemented(self):
"""Test that completions endpoint raises NotImplementedError."""
endpoint = "/v1/completions"
body = {"model": "gpt-3.5-turbo-instruct", "prompt": "Test prompt"}
config = FakeConfig(response_length=10)
with pytest.raises(NotImplementedError, match="Fake completions not implemented yet"):
generate_fake_response(endpoint, body, config)
def test_embeddings_not_implemented(self):
"""Test that embeddings endpoint raises NotImplementedError."""
endpoint = "/v1/embeddings"
body = {"model": "text-embedding-ada-002", "input": "Test text"}
config = FakeConfig()
with pytest.raises(NotImplementedError, match="Fake embeddings not implemented yet"):
generate_fake_response(endpoint, body, config)
def test_models_not_implemented(self):
"""Test that models endpoint raises NotImplementedError."""
endpoint = "/v1/models"
body = {}
config = FakeConfig()
with pytest.raises(NotImplementedError, match="Fake models list not implemented yet"):
generate_fake_response(endpoint, body, config)
def test_unsupported_endpoint(self):
"""Test that unsupported endpoints raise ValueError."""
endpoint = "/v1/unknown"
body = {}
config = FakeConfig()
with pytest.raises(ValueError, match="Unsupported endpoint for fake mode: /v1/unknown"):
generate_fake_response(endpoint, body, config)
def test_content_with_arrays(self):
"""Test chat completions with content arrays (e.g., images)."""
endpoint = "/v1/chat/completions"
body = {
"model": "gpt-4-vision-preview",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
],
}
],
}
config = FakeConfig(response_length=20)
response = generate_fake_response(endpoint, body, config)
# Should handle content arrays without errors
if hasattr(response, "usage"):
assert response.usage.prompt_tokens > 0
else:
assert response["usage"]["prompt_tokens"] > 0
class TestGenerateFakeStream:
"""Test cases for generate_fake_stream function."""
@pytest.mark.asyncio
async def test_chat_completions_streaming(self):
"""Test streaming chat completions generation."""
# First generate a response
response_data = generate_fake_response(
"/v1/chat/completions",
{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]},
FakeConfig(response_length=5, latency_ms=1), # Very low latency for testing
)
# Then stream it
chunks = []
async for chunk in generate_fake_stream(response_data, "/v1/chat/completions", FakeConfig(latency_ms=1)):
chunks.append(chunk)
# Should have initial role chunk + content chunks + final chunk
assert len(chunks) >= 3
# First chunk should have role
first_chunk = chunks[0]
assert first_chunk["object"] == "chat.completion.chunk"
assert first_chunk["choices"][0]["delta"]["role"] == "assistant"
assert first_chunk["choices"][0]["delta"]["content"] == ""
# Middle chunks should have content
content_chunks = [c for c in chunks[1:-1] if c["choices"][0]["delta"].get("content")]
assert len(content_chunks) > 0
# Last chunk should have finish_reason
last_chunk = chunks[-1]
assert last_chunk["choices"][0]["finish_reason"] == "stop"
assert last_chunk["choices"][0]["delta"]["content"] is None
@pytest.mark.asyncio
async def test_completions_streaming_not_implemented(self):
"""Test that streaming completions raises NotImplementedError."""
response_data = {"id": "test", "choices": [{"text": "test content"}]}
stream = generate_fake_stream(response_data, "/v1/completions", FakeConfig())
with pytest.raises(NotImplementedError, match="Fake streaming completions not implemented yet"):
async for _ in stream:
pass
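
As a closing usage note for the streaming path exercised above, a short sketch (same imports as the tests) that reassembles the streamed text from the chunk dicts.

```python
import asyncio

from llama_stack.testing.fake_responses import FakeConfig, generate_fake_response, generate_fake_stream


async def collect_streamed_text() -> str:
    config = FakeConfig(response_length=8, latency_ms=1)
    body = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}
    completion = generate_fake_response("/v1/chat/completions", body, config)

    parts = []
    async for chunk in generate_fake_stream(completion, "/v1/chat/completions", config):
        delta = chunk["choices"][0]["delta"]
        if delta.get("content"):  # skips the role-only first chunk and the final stop chunk
            parts.append(delta["content"])
    return "".join(parts)


print(asyncio.run(collect_streamed_text()))  # eight synthetic words ending with "."
```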