Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)
fake mode
# What does this PR do?

## Test Plan
parent 05cfa213b6
commit 12e46b7a4a
3 changed files with 484 additions and 8 deletions
256  llama_stack/testing/fake_responses.py  Normal file

@@ -0,0 +1,256 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Fake response generation for testing inference providers without making real API calls.
"""

import asyncio
import os
import time
from collections.abc import AsyncGenerator
from typing import Any

from openai.types.chat import ChatCompletion
from pydantic import BaseModel


class FakeConfig(BaseModel):
    response_length: int = 100
    latency_ms: int = 50


def parse_fake_config() -> FakeConfig:
    """Parse fake mode configuration from environment variable."""
    mode_str = os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
    config = {}

    if ":" in mode_str:
        parts = mode_str.split(":")
        for part in parts[1:]:
            if "=" in part:
                key, value = part.split("=", 1)
                config[key] = int(value)
    return FakeConfig(**config)


def generate_fake_content(word_count: int) -> str:
    """Generate fake response content with specified word count."""
    words = [
        "This", "is", "a", "synthetic", "response", "generated", "for", "testing",
        "purposes", "only", "The", "content", "simulates", "realistic", "language",
        "model", "output", "patterns", "and", "structures", "It", "includes",
        "various", "sentence", "types", "and", "maintains", "coherent", "flow",
        "throughout", "These", "responses", "help", "test", "system", "performance",
        "without", "requiring", "real", "model", "calls",
    ]

    return " ".join(words[i % len(words)] for i in range(word_count)) + "."


def generate_fake_chat_completion(body: dict[str, Any], config: FakeConfig) -> Any:
    """Generate fake OpenAI chat completion response."""
    model = body.get("model", "gpt-3.5-turbo")
    messages = body.get("messages", [])

    # Calculate fake token counts based on input
    prompt_tokens = 0
    for msg in messages:
        content = msg.get("content", "")
        if isinstance(content, str):
            prompt_tokens += len(content.split())
        elif isinstance(content, list):
            # Handle content arrays (images, etc.)
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    prompt_tokens += len(item.get("text", "").split())

    response_length = config.response_length
    fake_content = generate_fake_content(response_length)
    completion_tokens = len(fake_content.split())

    response_data = {
        "id": f"chatcmpl-fake-{int(time.time() * 1000)}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": fake_content,
                    "function_call": None,
                    "tool_calls": None,
                },
                "finish_reason": "stop",
                "logprobs": None,
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
        "system_fingerprint": None,
    }
    time.sleep(config.latency_ms / 1000.0)

    return ChatCompletion.model_validate(response_data)


def generate_fake_completion(body: dict[str, Any], config: FakeConfig) -> dict[str, Any]:
    """Generate fake OpenAI completion response."""
    raise NotImplementedError("Fake completions not implemented yet")


def generate_fake_embeddings(body: dict[str, Any], config: FakeConfig) -> dict[str, Any]:
    """Generate fake OpenAI embeddings response."""
    raise NotImplementedError("Fake embeddings not implemented yet")


def generate_fake_models_list(config: FakeConfig) -> dict[str, Any]:
    """Generate fake OpenAI models list response."""
    raise NotImplementedError("Fake models list not implemented yet")


async def generate_fake_stream(
    response_data: Any, endpoint: str, config: FakeConfig
) -> AsyncGenerator[dict[str, Any], None]:
    """Convert fake response to streaming chunks."""
    latency_seconds = config.latency_ms / 1000.0

    if endpoint == "/v1/chat/completions":
        if hasattr(response_data, "choices"):
            content = response_data.choices[0].message.content
            chunk_id = response_data.id
            model = response_data.model
        else:
            content = response_data["choices"][0]["message"]["content"]
            chunk_id = response_data["id"]
            model = response_data["model"]

        words = content.split()

        yield {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "role": "assistant",
                        "content": "",
                        "function_call": None,
                        "tool_calls": None,
                    },
                    "finish_reason": None,
                    "logprobs": None,
                }
            ],
            "system_fingerprint": None,
        }

        await asyncio.sleep(latency_seconds)

        for i, word in enumerate(words):
            chunk_content = word + (" " if i < len(words) - 1 else "")

            yield {
                "id": chunk_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model,
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "content": chunk_content,
                            "function_call": None,
                            "tool_calls": None,
                        },
                        "finish_reason": None,
                        "logprobs": None,
                    }
                ],
                "system_fingerprint": None,
            }

            await asyncio.sleep(latency_seconds)

        yield {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "content": None,
                        "function_call": None,
                        "tool_calls": None,
                    },
                    "finish_reason": "stop",
                    "logprobs": None,
                }
            ],
            "system_fingerprint": None,
        }

    elif endpoint == "/v1/completions":
        raise NotImplementedError("Fake streaming completions not implemented yet")


def generate_fake_response(endpoint: str, body: dict[str, Any], config: FakeConfig) -> Any:
    """Generate fake responses based on endpoint and request."""
    if endpoint == "/v1/chat/completions":
        return generate_fake_chat_completion(body, config)
    elif endpoint == "/v1/completions":
        return generate_fake_completion(body, config)
    elif endpoint == "/v1/embeddings":
        return generate_fake_embeddings(body, config)
    elif endpoint == "/v1/models":
        return generate_fake_models_list(config)
    else:
        raise ValueError(f"Unsupported endpoint for fake mode: {endpoint}")
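For orientation, a minimal usage sketch of the new helpers (not part of this commit; only the llama_stack.testing.fake_responses import path is taken from the diff above, the rest is illustrative):

# Illustrative sketch: exercise the fake response helpers directly,
# outside of the client-patching machinery added elsewhere in this commit.
import asyncio

from llama_stack.testing.fake_responses import FakeConfig, generate_fake_response, generate_fake_stream


async def main() -> None:
    config = FakeConfig(response_length=8, latency_ms=5)
    body = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}

    # Non-streaming: returns an openai ChatCompletion built from synthetic content.
    completion = generate_fake_response("/v1/chat/completions", body, config)
    print(completion.choices[0].message.content)

    # Streaming: the same response replayed as chat.completion.chunk dicts.
    async for chunk in generate_fake_stream(completion, "/v1/chat/completions", config):
        print(chunk["choices"][0]["delta"])


asyncio.run(main())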
@@ -17,6 +17,7 @@ from pathlib import Path
 from typing import Any, Literal, cast

 from llama_stack.log import get_logger
+from llama_stack.testing.fake_responses import generate_fake_response, generate_fake_stream, parse_fake_config

 logger = get_logger(__name__, category="testing")

@@ -25,6 +26,7 @@ _current_mode: str | None = None
 _current_storage: ResponseStorage | None = None
 _original_methods: dict[str, Any] = {}

+
 from openai.types.completion_choice import CompletionChoice

 # update the "finish_reason" field, since its type definition is wrong (no None is accepted)
@@ -36,6 +38,7 @@ class InferenceMode(StrEnum):
     LIVE = "live"
     RECORD = "record"
     REPLAY = "replay"
+    FAKE = "fake"


 def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
@@ -52,20 +55,31 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:


 def get_inference_mode() -> InferenceMode:
-    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower())
+    mode_str = os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
+    # Parse mode and config (e.g., "fake:response_length=50:latency_ms=100")
+    if ":" in mode_str:
+        return InferenceMode(mode_str.split(":")[0])
+    return InferenceMode(mode_str)


 def setup_inference_recording():
     """
-    Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
+    Returns a context manager that can be used to record, replay, or fake inference requests. This is to be used in tests
     to increase their reliability and reduce reliance on expensive, external services.

     Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
     Calls to the /models endpoint are not currently trapped. We probably need to add support for this.

-    Two environment variables are required:
-    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'.
-    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in.
+    Environment variables:
+    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'fake[:config]'.
+      For fake mode, configuration can be specified as 'fake:param1=value1:param2=value2'
+    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in (required for 'record' and 'replay' modes).
+
+    Configuration for 'fake' mode:
+    - response_length: Number of words in fake responses (default: 100)
+    - latency_ms: Simulated latency in milliseconds (default: 50)
+
+    Example: LLAMA_STACK_TEST_INFERENCE_MODE=fake:response_length=50:latency_ms=100

     The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to
     quickly find the correct recording for a given request. The JSON files are used to store the request and response
@@ -74,11 +88,16 @@ def setup_inference_recording():
     mode = get_inference_mode()

     if mode not in InferenceMode:
-        raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")
+        raise ValueError(
+            f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', 'replay', or 'fake'"
+        )

     if mode == InferenceMode.LIVE:
         return None

+    if mode == InferenceMode.FAKE:
+        return inference_recording(mode=mode, storage_dir=None)
+
     if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
         raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
     storage_dir = os.environ["LLAMA_STACK_TEST_RECORDING_DIR"]
@@ -220,7 +239,7 @@ class ResponseStorage:
 async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
     global _current_mode, _current_storage

-    if _current_mode == InferenceMode.LIVE or _current_storage is None:
+    if _current_mode == InferenceMode.LIVE or (_current_mode != InferenceMode.FAKE and _current_storage is None):
         # Normal operation
         return await original_method(self, *args, **kwargs)

@@ -266,6 +285,15 @@ async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
                 f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
             )

+    elif _current_mode == InferenceMode.FAKE:
+        fake_config = parse_fake_config()
+        fake_response = generate_fake_response(endpoint, body, fake_config)
+
+        if body.get("stream", False):
+            return generate_fake_stream(fake_response, endpoint, fake_config)
+        else:
+            return fake_response
+
     elif _current_mode == InferenceMode.RECORD:
         response = await original_method(self, *args, **kwargs)

@@ -440,12 +468,15 @@ def inference_recording(mode: str = "live", storage_dir: str | Path | None = None
         if mode in ["record", "replay"]:
             _current_storage = ResponseStorage(storage_dir_path)
             patch_inference_clients()
+        elif mode == "fake":
+            _current_storage = None
+            patch_inference_clients()

         yield

     finally:
         # Restore previous state
-        if mode in ["record", "replay"]:
+        if mode in ["record", "replay", "fake"]:
             unpatch_inference_clients()

         _current_mode = prev_mode
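As a small illustration of the mode-string convention the patched recorder and parse_fake_config share (not part of the commit; everything after the first ':' becomes key=value fake-mode configuration):

# Illustrative only: the fake-mode configuration string as parsed by
# llama_stack.testing.fake_responses.parse_fake_config (added in this commit).
import os

from llama_stack.testing.fake_responses import parse_fake_config

os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "fake:response_length=50:latency_ms=100"

config = parse_fake_config()
assert config.response_length == 50
assert config.latency_ms == 100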
189  tests/unit/testing/test_fake_responses.py  Normal file

@@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from llama_stack.testing.fake_responses import FakeConfig, generate_fake_response, generate_fake_stream


class TestGenerateFakeResponse:
    """Test cases for generate_fake_response function."""

    def test_chat_completions_basic(self):
        """Test basic chat completions generation."""
        endpoint = "/v1/chat/completions"
        body = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, how are you?"}]}
        config = FakeConfig(response_length=10, latency_ms=50)

        response = generate_fake_response(endpoint, body, config)

        # Check response structure
        if hasattr(response, "id"):
            # OpenAI object format
            assert response.id.startswith("chatcmpl-fake-")
            assert response.object == "chat.completion"
            assert response.model == "gpt-3.5-turbo"
            assert len(response.choices) == 1
            assert response.choices[0].message.role == "assistant"
            assert response.choices[0].message.content is not None
            assert len(response.choices[0].message.content.split()) == 10
            assert response.usage.total_tokens > 0
        else:
            # Dict format fallback
            assert response["id"].startswith("chatcmpl-fake-")
            assert response["object"] == "chat.completion"
            assert response["model"] == "gpt-3.5-turbo"
            assert len(response["choices"]) == 1
            assert response["choices"][0]["message"]["role"] == "assistant"
            assert response["choices"][0]["message"]["content"] is not None
            assert len(response["choices"][0]["message"]["content"].split()) == 10
            assert response["usage"]["total_tokens"] > 0

    def test_chat_completions_custom_model(self):
        """Test chat completions with custom model name."""
        endpoint = "/v1/chat/completions"
        body = {"model": "custom-model-name", "messages": [{"role": "user", "content": "Test message"}]}
        config = FakeConfig(response_length=5, latency_ms=10)

        response = generate_fake_response(endpoint, body, config)

        # Check model name is preserved
        if hasattr(response, "model"):
            assert response.model == "custom-model-name"
        else:
            assert response["model"] == "custom-model-name"

    def test_chat_completions_multiple_messages(self):
        """Test chat completions with multiple input messages."""
        endpoint = "/v1/chat/completions"
        body = {
            "model": "gpt-4",
            "messages": [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there!"},
                {"role": "user", "content": "How are you doing today?"},
            ],
        }
        config = FakeConfig(response_length=15, latency_ms=25)

        response = generate_fake_response(endpoint, body, config)

        # Check token calculation includes all messages
        if hasattr(response, "usage"):
            assert response.usage.prompt_tokens > 0  # Should count all input messages
            assert response.usage.completion_tokens == 15
        else:
            assert response["usage"]["prompt_tokens"] > 0
            assert response["usage"]["completion_tokens"] == 15

    def test_completions_not_implemented(self):
        """Test that completions endpoint raises NotImplementedError."""
        endpoint = "/v1/completions"
        body = {"model": "gpt-3.5-turbo-instruct", "prompt": "Test prompt"}
        config = FakeConfig(response_length=10)

        with pytest.raises(NotImplementedError, match="Fake completions not implemented yet"):
            generate_fake_response(endpoint, body, config)

    def test_embeddings_not_implemented(self):
        """Test that embeddings endpoint raises NotImplementedError."""
        endpoint = "/v1/embeddings"
        body = {"model": "text-embedding-ada-002", "input": "Test text"}
        config = FakeConfig()

        with pytest.raises(NotImplementedError, match="Fake embeddings not implemented yet"):
            generate_fake_response(endpoint, body, config)

    def test_models_not_implemented(self):
        """Test that models endpoint raises NotImplementedError."""
        endpoint = "/v1/models"
        body = {}
        config = FakeConfig()

        with pytest.raises(NotImplementedError, match="Fake models list not implemented yet"):
            generate_fake_response(endpoint, body, config)

    def test_unsupported_endpoint(self):
        """Test that unsupported endpoints raise ValueError."""
        endpoint = "/v1/unknown"
        body = {}
        config = FakeConfig()

        with pytest.raises(ValueError, match="Unsupported endpoint for fake mode: /v1/unknown"):
            generate_fake_response(endpoint, body, config)

    def test_content_with_arrays(self):
        """Test chat completions with content arrays (e.g., images)."""
        endpoint = "/v1/chat/completions"
        body = {
            "model": "gpt-4-vision-preview",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What's in this image?"},
                        {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
                    ],
                }
            ],
        }
        config = FakeConfig(response_length=20)

        response = generate_fake_response(endpoint, body, config)

        # Should handle content arrays without errors
        if hasattr(response, "usage"):
            assert response.usage.prompt_tokens > 0
        else:
            assert response["usage"]["prompt_tokens"] > 0


class TestGenerateFakeStream:
    """Test cases for generate_fake_stream function."""

    @pytest.mark.asyncio
    async def test_chat_completions_streaming(self):
        """Test streaming chat completions generation."""
        # First generate a response
        response_data = generate_fake_response(
            "/v1/chat/completions",
            {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]},
            FakeConfig(response_length=5, latency_ms=1),  # Very low latency for testing
        )

        # Then stream it
        chunks = []
        async for chunk in generate_fake_stream(response_data, "/v1/chat/completions", FakeConfig(latency_ms=1)):
            chunks.append(chunk)

        # Should have initial role chunk + content chunks + final chunk
        assert len(chunks) >= 3

        # First chunk should have role
        first_chunk = chunks[0]
        assert first_chunk["object"] == "chat.completion.chunk"
        assert first_chunk["choices"][0]["delta"]["role"] == "assistant"
        assert first_chunk["choices"][0]["delta"]["content"] == ""

        # Middle chunks should have content
        content_chunks = [c for c in chunks[1:-1] if c["choices"][0]["delta"].get("content")]
        assert len(content_chunks) > 0

        # Last chunk should have finish_reason
        last_chunk = chunks[-1]
        assert last_chunk["choices"][0]["finish_reason"] == "stop"
        assert last_chunk["choices"][0]["delta"]["content"] is None

    @pytest.mark.asyncio
    async def test_completions_streaming_not_implemented(self):
        """Test that streaming completions raises NotImplementedError."""
        response_data = {"id": "test", "choices": [{"text": "test content"}]}

        stream = generate_fake_stream(response_data, "/v1/completions", FakeConfig())

        with pytest.raises(NotImplementedError, match="Fake streaming completions not implemented yet"):
            async for _ in stream:
                pass
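A convenience sketch for running only this new test module (not part of the commit; it assumes pytest and pytest-asyncio are available in the dev environment, since the streaming tests are marked with @pytest.mark.asyncio):

# Hypothetical runner script, not part of this commit.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["-v", "tests/unit/testing/test_fake_responses.py"]))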