mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-14 04:02:38 +00:00

feat(tests): make inference_recorder into api_recorder (include tool_invoke)

This commit is contained in:
parent b96640eca3
commit 9205731cd6

19 changed files with 849 additions and 666 deletions
.github/workflows/integration-tests.yml (vendored, 4 changes)

@@ -54,14 +54,14 @@ jobs:
       # Define (setup, suite) pairs - they are always matched and cannot be independent
       # Weekly schedule (Sun 1 AM): vllm+base
       # Input test-setup=ollama-vision: ollama-vision+vision
-      # Default (including test-setup=ollama): both ollama+base and ollama-vision+vision
+      # Default (including test-setup=ollama): ollama+base, ollama-vision+vision, gpt+responses
       config: >-
         ${{
           github.event.schedule == '1 0 * * 0'
           && fromJSON('[{"setup": "vllm", "suite": "base"}]')
           || github.event.inputs.test-setup == 'ollama-vision'
           && fromJSON('[{"setup": "ollama-vision", "suite": "vision"}]')
-          || fromJSON('[{"setup": "ollama", "suite": "base"}, {"setup": "ollama-vision", "suite": "vision"}]')
+          || fromJSON('[{"setup": "ollama", "suite": "base"}, {"setup": "ollama-vision", "suite": "vision"}, {"setup": "gpt", "suite": "responses"}]')
         }}

     steps:
@@ -61,6 +61,9 @@ jobs:

       - name: Run and record tests
         uses: ./.github/actions/run-and-record-tests
+        env:
+          # Set OPENAI_API_KEY if using gpt setup
+          OPENAI_API_KEY: ${{ inputs.test-setup == 'gpt' && secrets.OPENAI_API_KEY || '' }}
         with:
           stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
           setup: ${{ inputs.test-setup || 'ollama' }}
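For orientation, the `config` expression above resolves to one of three (setup, suite) matrices. The sketch below is illustrative only and not part of the diff; GitHub evaluates the real expression with `fromJSON` inside `${{ ... }}`:

```python
# Rough Python mirror of the (setup, suite) matrix selection in the workflow above.
def resolve_matrix(schedule: str | None, test_setup: str | None) -> list[dict[str, str]]:
    if schedule == "1 0 * * 0":                  # weekly schedule: vllm+base
        return [{"setup": "vllm", "suite": "base"}]
    if test_setup == "ollama-vision":            # explicit vision-only run
        return [{"setup": "ollama-vision", "suite": "vision"}]
    # Default (including test-setup=ollama): ollama+base, ollama-vision+vision, gpt+responses
    return [
        {"setup": "ollama", "suite": "base"},
        {"setup": "ollama-vision", "suite": "vision"},
        {"setup": "gpt", "suite": "responses"},
    ]

assert [m["setup"] for m in resolve_matrix(None, "ollama")] == ["ollama", "ollama-vision", "gpt"]
```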
@@ -68,7 +68,9 @@ recordings/
 Direct API calls with no recording or replay:

 ```python
-with inference_recording(mode=InferenceMode.LIVE):
+from llama_stack.testing.api_recorder import api_recording, APIRecordingMode
+
+with api_recording(mode=APIRecordingMode.LIVE):
     response = await client.chat.completions.create(...)
 ```

@@ -79,7 +81,7 @@ Use for initial development and debugging against real APIs.
 Captures API interactions while passing through real responses:

 ```python
-with inference_recording(mode=InferenceMode.RECORD, storage_dir="./recordings"):
+with api_recording(mode=APIRecordingMode.RECORD, storage_dir="./recordings"):
     response = await client.chat.completions.create(...)
     # Real API call made, response captured AND returned
 ```

@@ -96,7 +98,7 @@ The recording process:
 Returns stored responses instead of making API calls:

 ```python
-with inference_recording(mode=InferenceMode.REPLAY, storage_dir="./recordings"):
+with api_recording(mode=APIRecordingMode.REPLAY, storage_dir="./recordings"):
     response = await client.chat.completions.create(...)
     # No API call made, cached response returned instantly
 ```
@@ -316,13 +316,13 @@ class Stack:
     # asked for in the run config.
     async def initialize(self):
         if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
-            from llama_stack.testing.inference_recorder import setup_inference_recording
+            from llama_stack.testing.api_recorder import setup_api_recording

             global TEST_RECORDING_CONTEXT
-            TEST_RECORDING_CONTEXT = setup_inference_recording()
+            TEST_RECORDING_CONTEXT = setup_api_recording()
             if TEST_RECORDING_CONTEXT:
                 TEST_RECORDING_CONTEXT.__enter__()
-                logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
+                logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")

         dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
         policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
@@ -381,7 +381,7 @@ class Stack:
             try:
                 TEST_RECORDING_CONTEXT.__exit__(None, None, None)
             except Exception as e:
-                logger.error(f"Error during inference recording cleanup: {e}")
+                logger.error(f"Error during API recording cleanup: {e}")

         global REGISTRY_REFRESH_TASK
         if REGISTRY_REFRESH_TASK:
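`Stack.initialize()` only enters the recording context when the environment variable is set. A minimal sketch of that environment contract, assuming the variables are exported before the stack or test process starts (values here are just examples):

```python
import os

# These two variables are the whole interface read by get_api_recording_mode()
# and setup_api_recording() in the recorder module further down this diff.
os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "record-if-missing"  # live | record | replay | record-if-missing
os.environ["LLAMA_STACK_TEST_RECORDING_DIR"] = "tests/integration/recordings"
```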
@@ -108,7 +108,7 @@ class OpenAIResponsesImpl:
             # Use stored messages directly and convert only new input
             message_adapter = TypeAdapter(list[OpenAIMessageParam])
             messages = message_adapter.validate_python(previous_response.messages)
-            new_messages = await convert_response_input_to_chat_messages(input)
+            new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
             messages.extend(new_messages)
         else:
             # Backward compatibility: reconstruct from inputs
@@ -103,9 +103,13 @@ async def convert_response_content_to_chat_content(

 async def convert_response_input_to_chat_messages(
     input: str | list[OpenAIResponseInput],
+    previous_messages: list[OpenAIMessageParam] | None = None,
 ) -> list[OpenAIMessageParam]:
     """
     Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.

+    :param input: The input to convert
+    :param previous_messages: Optional previous messages to check for function_call references
     """
     messages: list[OpenAIMessageParam] = []
     if isinstance(input, list):
@@ -169,16 +173,53 @@ async def convert_response_input_to_chat_messages(
                 raise ValueError(
                     f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
                 )
+                # Skip user messages that duplicate the last user message in previous_messages
+                # This handles cases where input includes context for function_call_outputs
+                if previous_messages and input_item.role == "user":
+                    last_user_msg = None
+                    for msg in reversed(previous_messages):
+                        if isinstance(msg, OpenAIUserMessageParam):
+                            last_user_msg = msg
+                            break
+                    if last_user_msg:
+                        last_user_content = getattr(last_user_msg, "content", None)
+                        if last_user_content == content:
+                            continue  # Skip duplicate user message
                 messages.append(message_type(content=content))
         if len(tool_call_results):
-            raise ValueError(
-                f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
-            )
+            # Check if unpaired function_call_outputs reference function_calls from previous messages
+            if previous_messages:
+                previous_call_ids = _extract_tool_call_ids(previous_messages)
+                for call_id in list(tool_call_results.keys()):
+                    if call_id in previous_call_ids:
+                        # Valid: this output references a call from previous messages
+                        # Add the tool message
+                        messages.append(tool_call_results[call_id])
+                        del tool_call_results[call_id]
+
+            # If still have unpaired outputs, error
+            if len(tool_call_results):
+                raise ValueError(
+                    f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
+                )
     else:
         messages.append(OpenAIUserMessageParam(content=input))
     return messages


+def _extract_tool_call_ids(messages: list[OpenAIMessageParam]) -> set[str]:
+    """Extract all tool_call IDs from messages."""
+    call_ids = set()
+    for msg in messages:
+        if isinstance(msg, OpenAIAssistantMessageParam):
+            tool_calls = getattr(msg, "tool_calls", None)
+            if tool_calls:
+                for tool_call in tool_calls:
+                    # tool_call is a Pydantic model, use attribute access
+                    call_ids.add(tool_call.id)
+    return call_ids


 async def convert_response_text_to_chat_response_format(
     text: OpenAIResponseText,
 ) -> OpenAIResponseFormatParam:
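The pairing rule in the hunk above: a `function_call_output` is only accepted when its `call_id` matches a tool-call id found in `previous_messages`. A self-contained sketch of that rule, with plain dicts standing in for the Pydantic message models (all names and values below are illustrative, not from the commit):

```python
# previous_call_ids would come from _extract_tool_call_ids(previous_messages)
previous_call_ids = {"call_1"}

# unpaired function_call_output items, keyed by call_id (stand-ins for tool messages)
tool_call_results = {"call_1": {"role": "tool", "tool_call_id": "call_1", "content": "42"}}

messages: list[dict] = []
for call_id in list(tool_call_results.keys()):
    if call_id in previous_call_ids:
        # Valid: pair the output with a function_call from a previous response
        messages.append(tool_call_results[call_id])
        del tool_call_results[call_id]

if tool_call_results:
    # Anything left has no matching function_call anywhere -> reject the request
    raise ValueError(f"Unpaired function_call_output(s): {set(tool_call_results)}")

assert messages[0]["tool_call_id"] == "call_1"
```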
@@ -95,7 +95,9 @@ class LocalfsFilesImpl(Files):
             raise RuntimeError("Files provider not initialized")

         if expires_after is not None:
-            raise NotImplementedError("File expiration is not supported by this provider")
+            logger.warning(
+                f"File expiration is not supported by this provider, ignoring expires_after: {expires_after}"
+            )

         file_id = self._generate_file_id()
         file_path = self._get_file_path(file_id)
@@ -15,19 +15,20 @@ from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast

-from openai import NOT_GIVEN, OpenAI
+from openai import NOT_GIVEN

 from llama_stack.log import get_logger

 logger = get_logger(__name__, category="testing")

-# Global state for the recording system
+# Global state for the API recording system
 # Note: Using module globals instead of ContextVars because the session-scoped
 # client initialization happens in one async context, but tests run in different
 # contexts, and we need the mode/storage to persist across all contexts.
 _current_mode: str | None = None
 _current_storage: ResponseStorage | None = None
 _original_methods: dict[str, Any] = {}
+_memory_cache: dict[str, dict[str, Any]] = {}

 # Test context uses ContextVar since it changes per-test and needs async isolation
 from contextvars import ContextVar
@@ -44,158 +45,33 @@ REPO_ROOT = Path(__file__).parent.parent.parent
 DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/common"


-class InferenceMode(StrEnum):
+class APIRecordingMode(StrEnum):
     LIVE = "live"
     RECORD = "record"
     REPLAY = "replay"
     RECORD_IF_MISSING = "record-if-missing"


-def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
-    """Create a normalized hash of the request for consistent matching.
-
-    Includes test_id from context to ensure test isolation - identical requests
-    from different tests will have different hashes.
-
-    Exception: Model list endpoints (/v1/models, /api/tags) exclude test_id since
-    they are infrastructure/shared and need to work across session setup and tests.
-    """
-    # Extract just the endpoint path
-    from urllib.parse import urlparse
-
-    parsed = urlparse(url)
-    normalized: dict[str, Any] = {
-        "method": method.upper(),
-        "endpoint": parsed.path,
-        "body": body,
-    }
-
-    # Include test_id for isolation, except for shared infrastructure endpoints
-    if parsed.path not in ("/api/tags", "/v1/models"):
-        normalized["test_id"] = _test_context.get()
-
-    # Create hash - sort_keys=True ensures deterministic ordering
-    normalized_json = json.dumps(normalized, sort_keys=True)
-    return hashlib.sha256(normalized_json.encode()).hexdigest()
-
-
-def _sync_test_context_from_provider_data():
-    """In server mode, sync test ID from provider_data to _test_context.
-
-    This ensures that storage operations (which read from _test_context) work correctly
-    in server mode where the test ID arrives via HTTP header → provider_data.
-
-    Returns a token to reset _test_context, or None if no sync was needed.
-    """
-    stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
-    if stack_config_type != "server":
-        return None
-
-    try:
-        from llama_stack.core.request_headers import PROVIDER_DATA_VAR
-
-        provider_data = PROVIDER_DATA_VAR.get()
-        if provider_data and "__test_id" in provider_data:
-            test_id = provider_data["__test_id"]
-            return _test_context.set(test_id)
-    except ImportError:
-        pass
-
-    return None
-
-
-def patch_httpx_for_test_id():
-    """Patch client _prepare_request methods to inject test ID into provider data header.
-
-    This is needed for server mode where the test ID must be transported from
-    client to server via HTTP headers. In library_client mode, this patch is a no-op
-    since everything runs in the same process.
-
-    We use the _prepare_request hook that Stainless clients provide for mutating
-    requests after construction but before sending.
-    """
-    from llama_stack_client import LlamaStackClient
-
-    if "llama_stack_client_prepare_request" in _original_methods:
-        return
-
-    _original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request
-    _original_methods["openai_prepare_request"] = OpenAI._prepare_request
-
-    def patched_prepare_request(self, request):
-        # Call original first (it's a sync method that returns None)
-        # Determine which original to call based on client type
-        if "llama_stack_client" in self.__class__.__module__:
-            _original_methods["llama_stack_client_prepare_request"](self, request)
-        else:
-            _original_methods["openai_prepare_request"](self, request)
-
-        # Only inject test ID in server mode
-        stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
-        test_id = _test_context.get()
-        if stack_config_type == "server" and test_id:
-            provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
-            if provider_data_header:
-                provider_data = json.loads(provider_data_header)
-            else:
-                provider_data = {}
-            provider_data["__test_id"] = test_id
-            request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
-
-        return None
-
-    LlamaStackClient._prepare_request = patched_prepare_request
-    OpenAI._prepare_request = patched_prepare_request
-
-
-# currently, unpatch is never called
-def unpatch_httpx_for_test_id():
-    """Remove client _prepare_request patches for test ID injection."""
-    if "llama_stack_client_prepare_request" not in _original_methods:
-        return
-
-    from llama_stack_client import LlamaStackClient
-
-    LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"]
-    del _original_methods["llama_stack_client_prepare_request"]
-
-    # Also restore OpenAI client if it was patched
-    if "openai_prepare_request" in _original_methods:
-        OpenAI._prepare_request = _original_methods["openai_prepare_request"]
-        del _original_methods["openai_prepare_request"]
-
-
-def get_inference_mode() -> InferenceMode:
-    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
-
-
-def setup_inference_recording():
-    """
-    Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
-    to increase their reliability and reduce reliance on expensive, external services.
-
-    Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
-
-    Two environment variables are supported:
-    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', 'replay', or 'record-if-missing'. Default is 'replay'.
-      - 'live': Make all requests live without recording
-      - 'record': Record all requests (overwrites existing recordings)
-      - 'replay': Use only recorded responses (fails if recording not found)
-      - 'record-if-missing': Use recorded responses when available, record new ones when not found
-    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'.
-
-    The recordings are stored as JSON files.
-    """
-    mode = get_inference_mode()
-    if mode == InferenceMode.LIVE:
-        return None
-
-    storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR)
-    return inference_recording(mode=mode, storage_dir=storage_dir)
+def _normalize_file_ids(obj: Any) -> Any:
+    """Recursively replace file IDs with a canonical placeholder for consistent hashing."""
+    import re
+
+    if isinstance(obj, dict):
+        result = {}
+        for k, v in obj.items():
+            # Normalize file IDs in attribute dictionaries
+            if k == "document_id" and isinstance(v, str) and v.startswith("file-"):
+                result[k] = "file-NORMALIZED"
+            else:
+                result[k] = _normalize_file_ids(v)
+        return result
+    elif isinstance(obj, list):
+        return [_normalize_file_ids(item) for item in obj]
+    elif isinstance(obj, str):
+        # Replace file-<uuid> patterns in strings (like in text content)
+        return re.sub(r"file-[a-f0-9]{32}", "file-NORMALIZED", obj)
+    else:
+        return obj


 def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[str, Any]:
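Why `_normalize_file_ids` exists: a freshly uploaded file gets a new `file-<32 hex>` ID on every run, which would otherwise change the request hash and break replay. A tiny standalone demo of the string normalization (the ID below is made up), not part of the diff:

```python
import re

doc_ref = "summarize file-0123456789abcdef0123456789abcdef for me"
normalized = re.sub(r"file-[a-f0-9]{32}", "file-NORMALIZED", doc_ref)
assert normalized == "summarize file-NORMALIZED for me"
```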
@@ -230,11 +106,184 @@ def _normalize_response_data(data: dict[str, Any], request_hash: str) -> dict[str, Any]:
     return data


-def _serialize_response(response: Any, request_hash: str = "") -> Any:
+def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
+    """Create a normalized hash of the request for consistent matching.
+
+    Includes test_id from context to ensure test isolation - identical requests
+    from different tests will have different hashes.
+    """
+    # Extract just the endpoint path
+    from urllib.parse import urlparse
+
+    parsed = urlparse(url)
+
+    # Normalize file IDs in the body to ensure consistent hashing across test runs
+    normalized_body = _normalize_file_ids(body)
+
+    normalized: dict[str, Any] = {"method": method.upper(), "endpoint": parsed.path, "body": normalized_body}
+
+    # Include test_id for isolation, except for shared infrastructure endpoints
+    if parsed.path not in ("/api/tags", "/v1/models"):
+        # Server mode: test ID was synced from provider_data to _test_context
+        # by _sync_test_context_from_provider_data() at the start of the request.
+        # We read from _test_context because it's available in all contexts (including
+        # when making outgoing API calls), whereas PROVIDER_DATA_VAR is only set
+        # for incoming HTTP requests.
+        #
+        # Library client mode: test ID in same-process ContextVar
+        test_id = _test_context.get()
+        normalized["test_id"] = test_id
+
+    # Create hash - sort_keys=True ensures deterministic ordering
+    normalized_json = json.dumps(normalized, sort_keys=True)
+    request_hash = hashlib.sha256(normalized_json.encode()).hexdigest()
+
+    return request_hash
+
+
+def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
+    """Create a normalized hash of the tool request for consistent matching."""
+    normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": kwargs}
+
+    # Create hash - sort_keys=True ensures deterministic ordering
+    normalized_json = json.dumps(normalized, sort_keys=True)
+    return hashlib.sha256(normalized_json.encode()).hexdigest()
+
+
+@contextmanager
+def set_test_context(test_id: str) -> Generator[None, None, None]:
+    """Set the test context for recording isolation.
+
+    Usage:
+        with set_test_context("test_basic_completion"):
+            # Make API calls that will be recorded with this test_id
+            response = client.chat.completions.create(...)
+    """
+    token = _test_context.set(test_id)
+    try:
+        yield
+    finally:
+        _test_context.reset(token)
+
+
+def patch_httpx_for_test_id():
+    """Patch client _prepare_request methods to inject test ID into provider data header.
+
+    Patches both LlamaStackClient and OpenAI client to ensure test ID is transported
+    from client to server via HTTP headers in server mode.
+
+    This is needed for server mode where the test ID must be transported from
+    client to server via HTTP headers. In library_client mode, this patch is a no-op
+    since everything runs in the same process.
+
+    We use the _prepare_request hook that Stainless clients provide for mutating
+    requests after construction but before sending.
+    """
+    from llama_stack_client import LlamaStackClient
+
+    if "llama_stack_client_prepare_request" in _original_methods:
+        # Already patched
+        return
+
+    # Save original methods
+    _original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request
+
+    # Also patch OpenAI client if available (used in compat tests)
+    try:
+        from openai import OpenAI
+
+        _original_methods["openai_prepare_request"] = OpenAI._prepare_request
+    except ImportError:
+        pass
+
+    def patched_prepare_request(self, request):
+        # Call original first (it's a sync method that returns None)
+        # Determine which original to call based on client type
+        if "llama_stack_client" in self.__class__.__module__:
+            _original_methods["llama_stack_client_prepare_request"](self, request)
+        elif "openai_prepare_request" in _original_methods:
+            _original_methods["openai_prepare_request"](self, request)
+
+        # Only inject test ID in server mode
+        stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
+        test_id = _test_context.get()
+
+        if stack_config_type == "server" and test_id:
+            # Get existing provider data header or create new dict
+            provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
+
+            if provider_data_header:
+                provider_data = json.loads(provider_data_header)
+            else:
+                provider_data = {}
+
+            # Inject test ID
+            provider_data["__test_id"] = test_id
+            request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
+
+        # Sync version returns None
+        return None
+
+    # Apply patches
+    LlamaStackClient._prepare_request = patched_prepare_request
+    if "openai_prepare_request" in _original_methods:
+        from openai import OpenAI
+
+        OpenAI._prepare_request = patched_prepare_request
+
+
+def unpatch_httpx_for_test_id():
+    """Remove client _prepare_request patches for test ID injection."""
+    if "llama_stack_client_prepare_request" not in _original_methods:
+        return
+
+    from llama_stack_client import LlamaStackClient
+
+    LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"]
+    del _original_methods["llama_stack_client_prepare_request"]
+
+    # Also restore OpenAI client if it was patched
+    if "openai_prepare_request" in _original_methods:
+        from openai import OpenAI
+
+        OpenAI._prepare_request = _original_methods["openai_prepare_request"]
+        del _original_methods["openai_prepare_request"]
+
+
+def get_api_recording_mode() -> APIRecordingMode:
+    return APIRecordingMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
+
+
+def setup_api_recording():
+    """
+    Returns a context manager that can be used to record or replay API requests (inference and tools).
+    This is to be used in tests to increase their reliability and reduce reliance on expensive, external services.
+
+    Currently supports:
+    - Inference: OpenAI, Ollama, and LiteLLM clients
+    - Tools: Search providers (Tavily for now)
+
+    Two environment variables are supported:
+    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Options:
+      - 'live': Make real API calls, no recording
+      - 'record': Record all API interactions (overwrites existing)
+      - 'replay': Use recorded responses only (default)
+      - 'record-if-missing': Replay when possible, record when recording doesn't exist
+    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'.
+
+    The recordings are stored as JSON files.
+    """
+    mode = get_api_recording_mode()
+    if mode == APIRecordingMode.LIVE:
+        return None
+
+    storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR)
+    return api_recording(mode=mode, storage_dir=storage_dir)
+
+
+def _serialize_response(response: Any) -> Any:
     if hasattr(response, "model_dump"):
         data = response.model_dump(mode="json")
-        # Normalize fields to reduce noise
-        data = _normalize_response_data(data, request_hash)
         return {
             "__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}",
             "__data__": data,
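Both `normalize_request` and `normalize_tool_request` reduce a request to a canonical JSON document and hash it, so identical requests always map to the same recording. A standalone sketch of the tool-request variant, mirroring the code above with pure stdlib (not part of the diff):

```python
import hashlib
import json

def tool_request_hash(provider_name: str, tool_name: str, kwargs: dict) -> str:
    normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": kwargs}
    # sort_keys=True makes the serialization, and therefore the hash, deterministic
    return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()

h1 = tool_request_hash("tavily", "web_search", {"query": "llama stack"})
h2 = tool_request_hash("tavily", "web_search", {"query": "llama stack"})
assert h1 == h2  # stable across runs, so replay lookups keep working
```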
@@ -259,22 +308,17 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
         return cls.model_validate(data["__data__"])
     except (ImportError, AttributeError, TypeError, ValueError) as e:
-        logger.warning(f"Failed to deserialize object of type {data['__type__']} with model_validate: {e}")
-        try:
-            return cls.model_construct(**data["__data__"])
-        except Exception as e:
-            logger.warning(f"Failed to deserialize object of type {data['__type__']} with model_construct: {e}")
-            return data["__data__"]
+        logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}")
+        return data["__data__"]

     return data


 class ResponseStorage:
-    """Handles SQLite index + JSON file storage/retrieval for inference recordings."""
+    """Handles storage/retrieval for API recordings (inference and tools)."""

     def __init__(self, base_dir: Path):
         self.base_dir = base_dir
-        # Don't create responses_dir here - determine it per-test at runtime

     def _get_test_dir(self) -> Path:
         """Get the recordings directory in the test file's parent directory.
@@ -283,6 +327,7 @@ class ResponseStorage:
         returns "tests/integration/inference/recordings/".
         """
         test_id = _test_context.get()
+
         if test_id:
             # Extract the directory path from the test nodeid
             # e.g., "tests/integration/inference/test_basic.py::test_foo[params]"
@@ -297,17 +342,21 @@ class ResponseStorage:
         # Fallback for non-test contexts
         return self.base_dir / "recordings"

-    def _ensure_directories(self):
-        """Ensure test-specific directories exist."""
+    def _ensure_directories(self) -> Path:
         test_dir = self._get_test_dir()
         test_dir.mkdir(parents=True, exist_ok=True)
         return test_dir

     def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
-        """Store a request/response pair."""
+        """Store a request/response pair both in memory cache and on disk."""
+        global _memory_cache
+
+        # Store in memory cache first
+        _memory_cache[request_hash] = {"request": request, "response": response}
+
         responses_dir = self._ensure_directories()

-        # Use FULL hash (not truncated)
+        # Generate unique response filename using full hash
         response_file = f"{request_hash}.json"

         # Serialize response body if needed
@@ -315,45 +364,32 @@ class ResponseStorage:
         if "body" in serialized_response:
             if isinstance(serialized_response["body"], list):
                 # Handle streaming responses (list of chunks)
-                serialized_response["body"] = [
-                    _serialize_response(chunk, request_hash) for chunk in serialized_response["body"]
-                ]
+                serialized_response["body"] = [_serialize_response(chunk) for chunk in serialized_response["body"]]
             else:
                 # Handle single response
-                serialized_response["body"] = _serialize_response(serialized_response["body"], request_hash)
+                serialized_response["body"] = _serialize_response(serialized_response["body"])

-        # For model-list endpoints, include digest in filename to distinguish different model sets
+        # If this is a model-list endpoint recording, include models digest in filename to distinguish variants
         endpoint = request.get("endpoint")
+        test_id = _test_context.get()
         if endpoint in ("/api/tags", "/v1/models"):
+            test_id = None
             digest = _model_identifiers_digest(endpoint, response)
             response_file = f"models-{request_hash}-{digest}.json"

         response_path = responses_dir / response_file

-        # Save response to JSON file with metadata
+        # Save response to JSON file
         with open(response_path, "w") as f:
-            json.dump(
-                {
-                    "test_id": _test_context.get(),
-                    "request": request,
-                    "response": serialized_response,
-                },
-                f,
-                indent=2,
-            )
+            json.dump({"test_id": test_id, "request": request, "response": serialized_response}, f, indent=2)
             f.write("\n")
             f.flush()

     def find_recording(self, request_hash: str) -> dict[str, Any] | None:
-        """Find a recorded response by request hash.
-
-        Uses fallback: first checks test-specific dir, then falls back to base recordings dir.
-        This handles cases where recordings happen during session setup (no test context) but
-        are requested during tests (with test context).
-        """
+        """Find a recorded response by request hash."""
         response_file = f"{request_hash}.json"

-        # Try test-specific directory first
+        # Check test-specific directory first
         test_dir = self._get_test_dir()
         response_path = test_dir / response_file
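After this change each recording is a small JSON file named by the request hash (with a `models-...-<digest>.json` name for model-list endpoints). Roughly, the payload `store_recording` writes has the shape below; every value here is invented for illustration and is not taken from the repository's recordings:

```python
recording = {
    "test_id": "tests/integration/tool_runtime/test_builtin_tools.py::test_web_search[...]",
    "request": {"provider": "tavily", "tool_name": "web_search", "kwargs": {"query": "llama stack"}},
    "response": {"body": {"results": ["..."]}, "is_streaming": False},
}
```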
@@ -464,15 +500,97 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     return {"request": canonical_req, "response": {"body": body, "is_streaming": False}}


+async def _patched_tool_invoke_method(
+    original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
+):
+    """Patched version of tool runtime invoke_tool method for recording/replay."""
+    global _current_mode, _current_storage
+
+    if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
+        # Normal operation
+        return await original_method(self, tool_name, kwargs)
+
+    # In server mode, sync test ID from provider_data to _test_context for storage operations
+    test_context_token = _sync_test_context_from_provider_data()
+
+    try:
+        request_hash = normalize_tool_request(provider_name, tool_name, kwargs)
+
+        if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING):
+            recording = _current_storage.find_recording(request_hash)
+            if recording:
+                return recording["response"]["body"]
+            elif _current_mode == APIRecordingMode.REPLAY:
+                raise RuntimeError(
+                    f"No recorded tool result found for {provider_name}.{tool_name}\n"
+                    f"Request: {kwargs}\n"
+                    f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
+                )
+            # If RECORD_IF_MISSING and no recording found, fall through to record
+
+        if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
+            # Check in-memory cache first (collision detection)
+            global _memory_cache
+            if request_hash in _memory_cache:
+                # Return the cached response instead of making a new tool call
+                return _memory_cache[request_hash]["response"]["body"]
+
+            # No cached response, make the tool call and record it
+            result = await original_method(self, tool_name, kwargs)
+
+            request_data = {
+                "provider": provider_name,
+                "tool_name": tool_name,
+                "kwargs": kwargs,
+            }
+            response_data = {"body": result, "is_streaming": False}
+
+            # Store the recording (both in memory and on disk)
+            _current_storage.store_recording(request_hash, request_data, response_data)
+            return result
+
+        else:
+            raise AssertionError(f"Invalid mode: {_current_mode}")
+    finally:
+        # Reset test context if we set it in server mode
+        if test_context_token is not None:
+            _test_context.reset(test_context_token)
+
+
+def _sync_test_context_from_provider_data():
+    """In server mode, sync test ID from provider_data to _test_context.
+
+    This ensures that storage operations (which read from _test_context) work correctly
+    in server mode where the test ID arrives via HTTP header → provider_data.
+
+    Returns a token to reset _test_context, or None if no sync was needed.
+    """
+    stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
+    if stack_config_type != "server":
+        return None
+
+    try:
+        from llama_stack.core.request_headers import PROVIDER_DATA_VAR
+
+        provider_data = PROVIDER_DATA_VAR.get()
+        if provider_data and "__test_id" in provider_data:
+            test_id = provider_data["__test_id"]
+            return _test_context.set(test_id)
+    except ImportError:
+        pass
+
+    return None
+
+
 async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
     global _current_mode, _current_storage

-    mode = _current_mode
-    storage = _current_storage
-
-    if mode == InferenceMode.LIVE or storage is None:
-        if endpoint == "/v1/models":
-            return original_method(self, *args, **kwargs)
+    if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
+        # Normal operation
+        if client_type == "litellm":
+            return await original_method(*args, **kwargs)
         else:
             return await original_method(self, *args, **kwargs)
@@ -491,34 +609,30 @@ async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
             base_url = getattr(self, "host", "http://localhost:11434")
             if not base_url.startswith("http"):
                 base_url = f"http://{base_url}"
+        elif client_type == "litellm":
+            # For LiteLLM, extract base URL from kwargs if available
+            base_url = kwargs.get("api_base", "https://api.openai.com")
         else:
             raise ValueError(f"Unknown client type: {client_type}")

         url = base_url.rstrip("/") + endpoint
-        # Special handling for Databricks URLs to avoid leaking workspace info
-        # e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com
-        if "cloud.databricks.com" in url:
-            url = "__databricks__" + url.split("cloud.databricks.com")[-1]
         method = "POST"
         headers = {}
         body = kwargs

         request_hash = normalize_request(method, url, headers, body)

-        # Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
-        recording = None
-        if mode == InferenceMode.REPLAY or mode == InferenceMode.RECORD_IF_MISSING:
-            # Special handling for model-list endpoints: merge all recordings with this hash
+        if _current_mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING):
+            # Special handling for model-list endpoints: return union of all responses
             if endpoint in ("/api/tags", "/v1/models"):
-                records = storage._model_list_responses(request_hash)
+                records = _current_storage._model_list_responses(request_hash)
                 recording = _combine_model_list_responses(endpoint, records)
             else:
-                recording = storage.find_recording(request_hash)
+                recording = _current_storage.find_recording(request_hash)

             if recording:
                 response_body = recording["response"]["body"]

-                if recording["response"].get("is_streaming", False):
+                if recording["response"].get("is_streaming", False) or recording["response"].get("is_paginated", False):

                     async def replay_stream():
                         for chunk in response_body:
@@ -527,25 +641,41 @@ async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
                     return replay_stream()
                 else:
                     return response_body
-            elif mode == InferenceMode.REPLAY:
-                # REPLAY mode requires recording to exist
+            elif _current_mode == APIRecordingMode.REPLAY:
                 raise RuntimeError(
                     f"No recorded response found for request hash: {request_hash}\n"
                     f"Request: {method} {url} {body}\n"
                     f"Model: {body.get('model', 'unknown')}\n"
                     f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
                 )
+            # If RECORD_IF_MISSING and no recording found, fall through to record

-        if mode == InferenceMode.RECORD or (mode == InferenceMode.RECORD_IF_MISSING and not recording):
-            if endpoint == "/v1/models":
-                response = original_method(self, *args, **kwargs)
+        if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
+            # Check in-memory cache first (collision detection)
+            global _memory_cache
+            if request_hash in _memory_cache:
+                # Return the cached response instead of making a new API call
+                cached_recording = _memory_cache[request_hash]
+                response_body = cached_recording["response"]["body"]
+
+                if cached_recording["response"].get("is_streaming", False) or cached_recording["response"].get(
+                    "is_paginated", False
+                ):
+
+                    async def replay_cached_stream():
+                        for chunk in response_body:
+                            yield chunk
+
+                    return replay_cached_stream()
+                else:
+                    return response_body
+
+            # No cached response, make the API call and record it
+            if client_type == "litellm":
+                response = await original_method(*args, **kwargs)
             else:
                 response = await original_method(self, *args, **kwargs)

-            # we want to store the result of the iterator, not the iterator itself
-            if endpoint == "/v1/models":
-                response = [m async for m in response]
-
             request_data = {
                 "method": method,
                 "url": url,
@@ -558,16 +688,20 @@ async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
             # Determine if this is a streaming request based on request parameters
             is_streaming = body.get("stream", False)

-            if is_streaming:
-                # For streaming responses, we need to collect all chunks immediately before yielding
+            # Special case: /v1/models is a paginated endpoint that returns an async iterator
+            is_paginated = endpoint == "/v1/models"
+
+            if is_streaming or is_paginated:
+                # For streaming/paginated responses, we need to collect all chunks immediately before yielding
                 # This ensures the recording is saved even if the generator isn't fully consumed
                 chunks = []
                 async for chunk in response:
                     chunks.append(chunk)

-                # Store the recording immediately
-                response_data = {"body": chunks, "is_streaming": True}
-                storage.store_recording(request_hash, request_data, response_data)
+                # Store the recording immediately (both in memory and on disk)
+                # For paginated endpoints, mark as paginated rather than streaming
+                response_data = {"body": chunks, "is_streaming": is_streaming, "is_paginated": is_paginated}
+                _current_storage.store_recording(request_hash, request_data, response_data)

                 # Return a generator that replays the stored chunks
                 async def replay_recorded_stream():
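On the streaming/paginated path above, the response is drained into a list up front and then served back through a fresh async generator, so the recording exists even if the caller never finishes iterating. A minimal standalone demo of that collect-then-replay pattern (illustrative, not part of the diff):

```python
import asyncio

async def fake_stream():
    for chunk in ("a", "b", "c"):   # stand-in for SSE chunks or paginated model entries
        yield chunk

async def main():
    chunks = [c async for c in fake_stream()]        # collect eagerly, like the recorder does

    async def replay_recorded_stream():
        for c in chunks:                             # serve the recorded chunks back as a stream
            yield c

    assert [c async for c in replay_recorded_stream()] == ["a", "b", "c"]

asyncio.run(main())
```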
@@ -577,27 +711,34 @@ async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
                    return replay_recorded_stream()
             else:
                 response_data = {"body": response, "is_streaming": False}
-                storage.store_recording(request_hash, request_data, response_data)
+                # Store the response (both in memory and on disk)
+                _current_storage.store_recording(request_hash, request_data, response_data)
                 return response

         else:
-            raise AssertionError(f"Invalid mode: {mode}")
+            raise AssertionError(f"Invalid mode: {_current_mode}")

     finally:
-        if test_context_token:
+        # Reset test context if we set it in server mode
+        if test_context_token is not None:
             _test_context.reset(test_context_token)


-def patch_inference_clients():
-    """Install monkey patches for OpenAI client methods and Ollama AsyncClient methods."""
+def patch_api_clients():
+    """Install monkey patches for inference clients and tool runtime methods."""
     global _original_methods

+    import litellm
     from ollama import AsyncClient as OllamaAsyncClient
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
     from openai.resources.models import AsyncModels

-    # Store original methods for both OpenAI and Ollama clients
+    from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
+    from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+
+    # Store original methods for OpenAI, Ollama, LiteLLM clients, and tool runtimes
     _original_methods = {
         "chat_completions_create": AsyncChatCompletions.create,
         "completions_create": AsyncCompletions.create,
@@ -609,6 +750,10 @@ def patch_inference_clients():
         "ollama_ps": OllamaAsyncClient.ps,
         "ollama_pull": OllamaAsyncClient.pull,
         "ollama_list": OllamaAsyncClient.list,
+        "litellm_acompletion": litellm.acompletion,
+        "litellm_atext_completion": litellm.atext_completion,
+        "litellm_openai_mixin_get_api_key": LiteLLMOpenAIMixin.get_api_key,
+        "tavily_invoke_tool": TavilySearchToolRuntimeImpl.invoke_tool,
     }

     # Create patched methods for OpenAI client
@@ -629,10 +774,18 @@ def patch_inference_clients():

     def patched_models_list(self, *args, **kwargs):
         async def _iter():
-            for item in await _patched_inference_method(
+            result = await _patched_inference_method(
                 _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
-            ):
-                yield item
+            )
+            # The result is either an async generator (streaming/paginated) or a list
+            # If it's an async generator, iterate through it
+            if hasattr(result, "__aiter__"):
+                async for item in result:
+                    yield item
+            else:
+                # It's a list, yield each item
+                for item in result:
+                    yield item

         return _iter()
@@ -681,21 +834,61 @@ def patch_inference_clients():
     OllamaAsyncClient.pull = patched_ollama_pull
     OllamaAsyncClient.list = patched_ollama_list

-def unpatch_inference_clients():
-    """Remove monkey patches and restore original OpenAI and Ollama client methods."""
-    global _original_methods
+    # Create patched methods for LiteLLM
+    async def patched_litellm_acompletion(*args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["litellm_acompletion"], None, "litellm", "/chat/completions", *args, **kwargs
+        )
+
+    async def patched_litellm_atext_completion(*args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["litellm_atext_completion"], None, "litellm", "/completions", *args, **kwargs
+        )
+
+    # Apply LiteLLM patches
+    litellm.acompletion = patched_litellm_acompletion
+    litellm.atext_completion = patched_litellm_atext_completion
+
+    # Create patched method for LiteLLMOpenAIMixin.get_api_key
+    def patched_litellm_get_api_key(self):
+        global _current_mode
+        if _current_mode != APIRecordingMode.REPLAY:
+            return _original_methods["litellm_openai_mixin_get_api_key"](self)
+        else:
+            # For record/replay modes, return a fake API key to avoid exposing real credentials
+            return "fake-api-key-for-testing"
+
+    # Apply LiteLLMOpenAIMixin patch
+    LiteLLMOpenAIMixin.get_api_key = patched_litellm_get_api_key
+
+    # Create patched methods for tool runtimes
+    async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any]):
+        return await _patched_tool_invoke_method(
+            _original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs
+        )
+
+    # Apply tool runtime patches
+    TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool
+
+
+def unpatch_api_clients():
+    """Remove monkey patches and restore original client methods and tool runtimes."""
+    global _original_methods, _memory_cache

     if not _original_methods:
         return

     # Import here to avoid circular imports
+    import litellm
     from ollama import AsyncClient as OllamaAsyncClient
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
     from openai.resources.models import AsyncModels

+    from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
+    from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+
     # Restore OpenAI client methods
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
@ -710,12 +903,23 @@ def unpatch_inference_clients():
OllamaAsyncClient.pull = _original_methods["ollama_pull"]
OllamaAsyncClient.list = _original_methods["ollama_list"]

# Restore LiteLLM methods
litellm.acompletion = _original_methods["litellm_acompletion"]
litellm.atext_completion = _original_methods["litellm_atext_completion"]
LiteLLMOpenAIMixin.get_api_key = _original_methods["litellm_openai_mixin_get_api_key"]

# Restore tool runtime methods
TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"]

_original_methods.clear()

# Clear memory cache to prevent memory leaks
_memory_cache.clear()


@contextmanager
def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
"""Context manager for inference recording/replaying."""
"""Context manager for API recording/replaying (inference and tools)."""
global _current_mode, _current_storage

# Store previous state

@ -729,14 +933,14 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
if storage_dir is None:
raise ValueError("storage_dir is required for record, replay, and record-if-missing modes")
_current_storage = ResponseStorage(Path(storage_dir))
patch_inference_clients()
patch_api_clients()

yield

finally:
# Restore previous state
if mode in ["record", "replay", "record-if-missing"]:
unpatch_inference_clients()
unpatch_api_clients()

_current_mode = prev_mode
_current_storage = prev_storage
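The renamed `api_recording` context manager keeps the same shape as before: save the current mode and storage, patch, `yield`, then restore in `finally`. A stripped-down sketch of that shape, with a placeholder global instead of the real module state:

```python
from collections.abc import Generator
from contextlib import contextmanager

_current_mode: str = "live"


@contextmanager
def recording(mode: str) -> Generator[None, None, None]:
    global _current_mode
    prev_mode = _current_mode  # store previous state
    _current_mode = mode       # patch in the requested mode
    try:
        yield
    finally:
        _current_mode = prev_mode  # always restore, even if the body raised


with recording("replay"):
    assert _current_mode == "replay"
assert _current_mode == "live"
```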
@ -29,7 +29,7 @@ Options:
--stack-config STRING Stack configuration to use (required)
--suite STRING Test suite to run (default: 'base')
--setup STRING Test setup (models, env) to use (e.g., 'ollama', 'ollama-vision', 'gpt', 'vllm')
--inference-mode STRING Inference mode: record or replay (default: replay)
--inference-mode STRING Inference mode: replay, record-if-missing or record (default: replay)
--subdirs STRING Comma-separated list of test subdirectories to run (overrides suite)
--pattern STRING Regex pattern to pass to pytest -k
--help Show this help message

@ -102,7 +102,7 @@ if [[ -z "$STACK_CONFIG" ]]; then
fi

if [[ -z "$TEST_SETUP" && -n "$TEST_SUBDIRS" ]]; then
echo "Error: --test-setup is required when --test-subdirs is provided"
echo "Error: --test-setup is required when --test-subdirs is not provided"
usage
exit 1
fi
@ -159,7 +159,6 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal
import threading
import time

import httpx
import uvicorn
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport

@ -171,6 +170,11 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal

server = FastMCP("FastMCP Test Server", log_level="WARNING")

# Silence verbose MCP server logs
import logging  # allow-direct-logging

logging.getLogger("mcp.server.lowlevel.server").setLevel(logging.WARNING)

tools = tools or default_tools()

# Register all tools with the server

@ -234,29 +238,25 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal
logger.debug(f"Starting MCP server thread on port {port}")
server_thread.start()

# Polling until the server is ready
# Wait for the server thread to be running
timeout = 10
# Note: We can't use a simple HTTP GET health check on /sse because it's an SSE endpoint
# that expects a long-lived connection, not a simple request/response
timeout = 2
start_time = time.time()

server_url = f"http://localhost:{port}/sse"
logger.debug(f"Waiting for MCP server to be ready at {server_url}")
logger.debug(f"Waiting for MCP server thread to start on port {port}")

while time.time() - start_time < timeout:
try:
if server_thread.is_alive():
response = httpx.get(server_url)
# Give the server a moment to bind to the port
if response.status_code in [200, 401]:
time.sleep(0.1)
logger.debug(f"MCP server is ready on port {port} (status: {response.status_code})")
logger.debug(f"MCP server is ready on port {port}")
break
except httpx.RequestError as e:
time.sleep(0.05)
logger.debug(f"Server not ready yet, retrying... ({e})")
pass
time.sleep(0.1)
else:
# If we exit the loop due to timeout
logger.error(f"MCP server failed to start within {timeout} seconds on port {port}")
logger.error(f"MCP server thread failed to start within {timeout} seconds on port {port}")
logger.error(f"Thread alive: {server_thread.is_alive()}")
if server_thread.is_alive():
logger.error("Server thread is still running but not responding to HTTP requests")

try:
yield {"server_url": server_url}
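The readiness check above swaps the HTTP probe for a `while ... else` poll on the server thread. A self-contained sketch of that control flow, with an arbitrary worker and timeout just to show that the `else` branch fires only when the loop exhausts the timeout without a `break`:

```python
import threading
import time


def worker() -> None:
    time.sleep(0.2)  # pretend to bind a port, then serve


server_thread = threading.Thread(target=worker, daemon=True)
server_thread.start()

timeout = 2
start_time = time.time()

while time.time() - start_time < timeout:
    if server_thread.is_alive():
        time.sleep(0.1)  # give the thread a moment to finish binding
        print("server thread is running")
        break
    time.sleep(0.05)
else:
    # Reached only if the loop ran out of time without hitting `break`.
    print(f"server thread failed to start within {timeout} seconds")
```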
@ -6,6 +6,7 @@
import inspect
import itertools
import os
import tempfile
import textwrap
import time
from pathlib import Path

@ -14,6 +15,7 @@ import pytest
from dotenv import load_dotenv

from llama_stack.log import get_logger
from llama_stack.testing.api_recorder import patch_httpx_for_test_id

from .suites import SETUP_DEFINITIONS, SUITE_DEFINITIONS


@ -35,6 +37,10 @@ def pytest_sessionstart(session):
if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"

if "SQLITE_STORE_DIR" not in os.environ:
os.environ["SQLITE_STORE_DIR"] = tempfile.mkdtemp()

# Set test stack config type for api_recorder test isolation
stack_config = session.config.getoption("--stack-config", default=None)
if stack_config and stack_config.startswith("server:"):
os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "server"

@ -43,8 +49,6 @@ def pytest_sessionstart(session):
os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "library_client"
logger.info(f"Test stack config type: library_client (stack_config={stack_config})")

from llama_stack.testing.inference_recorder import patch_httpx_for_test_id

patch_httpx_for_test_id()
@ -55,7 +59,7 @@ def _track_test_context(request):
This fixture runs for every test and stores the test's nodeid in a contextvar
that the recording system can access to determine which subdirectory to use.
"""
from llama_stack.testing.inference_recorder import _test_context
from llama_stack.testing.api_recorder import _test_context

# Store the test nodeid (e.g., "tests/integration/responses/test_basic.py::test_foo[params]")
token = _test_context.set(request.node.nodeid)

@ -121,9 +125,13 @@ def pytest_configure(config):
# Apply defaults if not provided explicitly
for dest, value in setup_obj.defaults.items():
current = getattr(config.option, dest, None)
if not current:
if current is None:
setattr(config.option, dest, value)

# Apply global fallback for embedding_dimension if still not set
if getattr(config.option, "embedding_dimension", None) is None:
config.option.embedding_dimension = 384


def pytest_addoption(parser):
parser.addoption(
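The switch from `if not current:` to `if current is None:` matters once defaults can be integers: a legitimate `0` or empty string supplied by the user is falsy and would be silently replaced by the setup default under the old check. A tiny illustration of the difference:

```python
def apply_default_falsy(current, default):
    return default if not current else current       # old behaviour


def apply_default_is_none(current, default):
    return default if current is None else current   # new behaviour


# A caller explicitly passed 0; only the `is None` check respects it.
assert apply_default_falsy(0, 384) == 384
assert apply_default_is_none(0, 384) == 0
```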
@ -161,8 +169,8 @@ def pytest_addoption(parser):
parser.addoption(
"--embedding-dimension",
type=int,
default=384,
default=None,
help="Output dimensionality of the embedding model to use for testing. Default: 384",
help="Output dimensionality of the embedding model to use for testing. Default: 384 (or setup-specific)",
)

parser.addoption(

@ -236,7 +244,9 @@ def pytest_generate_tests(metafunc):
continue

params.append(fixture_name)
val = metafunc.config.getoption(option)
# Use getattr on config.option to see values set by pytest_configure fallbacks
dest = option.lstrip("-").replace("-", "_")
val = getattr(metafunc.config.option, dest, None)

values = [v.strip() for v in str(val).split(",")] if val else [None]
param_values[fixture_name] = values
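Reading the value via `getattr(metafunc.config.option, dest, None)` requires turning the CLI flag back into its argparse-style attribute name, which is what the `lstrip`/`replace` line does. A quick check of that conversion (the flag names below are just examples):

```python
def option_to_dest(option: str) -> str:
    """Mirror of the flag-to-attribute conversion used above."""
    return option.lstrip("-").replace("-", "_")


assert option_to_dest("--embedding-dimension") == "embedding_dimension"
assert option_to_dest("--text-model") == "text_model"
```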
@ -183,6 +183,12 @@ def llama_stack_client(request):
# would be forced to use llama_stack_client, which is not what we want.
print("\ninstantiating llama_stack_client")
start_time = time.time()

# Patch httpx to inject test ID for server-mode test isolation
from llama_stack.testing.api_recorder import patch_httpx_for_test_id

patch_httpx_for_test_id()

client = instantiate_llama_stack_client(request.session)
print(f"llama_stack_client instantiated in {time.time() - start_time:.3f}s")
return client
@ -7,7 +7,7 @@
import time


def new_vector_store(openai_client, name):
def new_vector_store(openai_client, name, embedding_model, embedding_dimension):
"""Create a new vector store, cleaning up any existing one with the same name."""
# Ensure we don't reuse an existing vector store
vector_stores = openai_client.vector_stores.list()

@ -16,7 +16,21 @@ def new_vector_store(openai_client, name):
openai_client.vector_stores.delete(vector_store_id=vector_store.id)

# Create a new vector store
vector_store = openai_client.vector_stores.create(name=name)
# OpenAI SDK client uses extra_body for non-standard parameters
from openai import OpenAI

if isinstance(openai_client, OpenAI):
# OpenAI SDK client - use extra_body
vector_store = openai_client.vector_stores.create(
name=name,
extra_body={"embedding_model": embedding_model, "embedding_dimension": embedding_dimension},
)
else:
# LlamaStack client - direct parameter
vector_store = openai_client.vector_stores.create(
name=name, embedding_model=embedding_model, embedding_dimension=embedding_dimension
)

return vector_store
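With the new signature, callers pass the embedding settings explicitly and the helper decides between `extra_body` (OpenAI SDK client) and direct keyword arguments (LlamaStack client). A hedged usage sketch; the import path, base URL, model name, and dimension below are placeholders drawn from the test setups, not requirements:

```python
from openai import OpenAI

# from tests.integration.responses.helpers import new_vector_store  # import path assumed

# Assumed to point at a locally running, OpenAI-compatible Llama Stack server.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

vector_store = new_vector_store(
    client,
    "my_test_store",
    embedding_model="openai/text-embedding-3-small",
    embedding_dimension=1536,
)
print(vector_store.id)
```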
@ -16,6 +16,7 @@ import pytest
from llama_stack_client import APIStatusError


@pytest.mark.xfail(reason="Shields are not yet implemented inside responses")
def test_shields_via_extra_body(compat_client, text_model_id):
"""Test that shields parameter is received by the server and raises NotImplementedError."""
@ -47,12 +47,14 @@ def test_response_text_format(compat_client, text_model_id, text_format):


@pytest.fixture
def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_factory):
def vector_store_with_filtered_files(compat_client, embedding_model_id, embedding_dimension, tmp_path_factory):
"""Create a vector store with multiple files that have different attributes for filtering tests."""
# """Create a vector store with multiple files that have different attributes for filtering tests."""
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
pytest.skip("upload_file() is not yet supported in library client somehow?")

vector_store = new_vector_store(compat_client, "test_vector_store_with_filters")
vector_store = new_vector_store(
compat_client, "test_vector_store_with_filters", embedding_model_id, embedding_dimension
)
tmp_path = tmp_path_factory.mktemp("filter_test_files")

# Create multiple files with different attributes
@ -46,11 +46,13 @@ def test_response_non_streaming_web_search(compat_client, text_model_id, case):


@pytest.mark.parametrize("case", file_search_test_cases)
def test_response_non_streaming_file_search(compat_client, text_model_id, tmp_path, case):
def test_response_non_streaming_file_search(
compat_client, text_model_id, embedding_model_id, embedding_dimension, tmp_path, case
):
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")

vector_store = new_vector_store(compat_client, "test_vector_store")
vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)

if case.file_content:
file_name = "test_response_non_streaming_file_search.txt"

@ -101,11 +103,13 @@ def test_response_non_streaming_file_search(compat_client, text_model_id, tmp_pa
assert case.expected.lower() in response.output_text.lower().strip()


def test_response_non_streaming_file_search_empty_vector_store(compat_client, text_model_id):
def test_response_non_streaming_file_search_empty_vector_store(
compat_client, text_model_id, embedding_model_id, embedding_dimension
):
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")

vector_store = new_vector_store(compat_client, "test_vector_store")
vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)

# Create the response request, which should query our vector store
response = compat_client.responses.create(

@ -127,12 +131,14 @@ def test_response_non_streaming_file_search_empty_vector_store(compat_client, te
assert response.output_text


def test_response_sequential_file_search(compat_client, text_model_id, tmp_path):
def test_response_sequential_file_search(
compat_client, text_model_id, embedding_model_id, embedding_dimension, tmp_path
):
"""Test file search with sequential responses using previous_response_id."""
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")

vector_store = new_vector_store(compat_client, "test_vector_store")
vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)

# Create a test file with content
file_content = "The Llama 4 Maverick model has 128 experts in its mixture of experts architecture."
@ -39,7 +39,7 @@ class Setup(BaseModel):

name: str
description: str
defaults: dict[str, str] = Field(default_factory=dict)
defaults: dict[str, str | int] = Field(default_factory=dict)
env: dict[str, str] = Field(default_factory=dict)


@ -88,6 +88,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
defaults={
"text_model": "openai/gpt-4o",
"embedding_model": "openai/text-embedding-3-small",
"embedding_dimension": 1536,
},
),
"tgi": Setup(
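Widening `defaults` to `dict[str, str | int]` is what lets a setup carry a numeric `embedding_dimension` alongside model names. A minimal sketch of a Setup-like model exercising that; the `example-gpt` setup below is hypothetical, only the field names mirror the diff:

```python
from pydantic import BaseModel, Field


class Setup(BaseModel):
    """Trimmed-down stand-in for the Setup model shown in the diff."""

    name: str
    description: str
    defaults: dict[str, str | int] = Field(default_factory=dict)
    env: dict[str, str] = Field(default_factory=dict)


example = Setup(
    name="example-gpt",
    description="illustrative setup only",
    defaults={
        "text_model": "openai/gpt-4o",
        "embedding_model": "openai/text-embedding-3-small",
        "embedding_dimension": 1536,
    },
)
print(example.defaults["embedding_dimension"])
```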
273
tests/unit/distribution/test_api_recordings.py
Normal file

@ -0,0 +1,273 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import tempfile
from pathlib import Path
from unittest.mock import patch

import pytest
from openai import AsyncOpenAI

# Import the real Pydantic response types instead of using Mocks
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.testing.api_recorder import (
APIRecordingMode,
ResponseStorage,
api_recording,
normalize_request,
)


@pytest.fixture
def temp_storage_dir():
"""Create a temporary directory for test recordings."""
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)


@pytest.fixture
def real_openai_chat_response():
"""Real OpenAI chat completion response using proper Pydantic objects."""
return OpenAIChatCompletion(
id="chatcmpl-test123",
choices=[
OpenAIChoice(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant", content="Hello! I'm doing well, thank you for asking."
),
finish_reason="stop",
)
],
created=1234567890,
model="llama3.2:3b",
)


@pytest.fixture
def real_embeddings_response():
"""Real OpenAI embeddings response using proper Pydantic objects."""
return OpenAIEmbeddingsResponse(
object="list",
data=[
OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
],
model="nomic-embed-text",
usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
)


class TestInferenceRecording:
"""Test the inference recording system."""

def test_request_normalization(self):
"""Test that request normalization produces consistent hashes."""
# Test basic normalization
hash1 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)

# Same request should produce same hash
hash2 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)

assert hash1 == hash2

# Different content should produce different hash
hash3 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{
"model": "llama3.2:3b",
"messages": [{"role": "user", "content": "Different message"}],
"temperature": 0.7,
},
)

assert hash1 != hash3

def test_request_normalization_edge_cases(self):
"""Test request normalization is precise about request content."""
# Test that different whitespace produces different hashes (no normalization)
hash1 = normalize_request(
"POST",
"http://test/v1/chat/completions",
{},
{"messages": [{"role": "user", "content": "Hello world\n\n"}]},
)
hash2 = normalize_request(
"POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
)
assert hash1 != hash2  # Different whitespace should produce different hashes

# Test that different float precision produces different hashes (no rounding)
hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
assert hash3 != hash4  # Different precision should produce different hashes

def test_response_storage(self, temp_storage_dir):
"""Test the ResponseStorage class."""
temp_storage_dir = temp_storage_dir / "test_response_storage"
storage = ResponseStorage(temp_storage_dir)

# Test storing and retrieving a recording
request_hash = "test_hash_123"
request_data = {
"method": "POST",
"url": "http://localhost:11434/v1/chat/completions",
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b",
}
response_data = {"body": {"content": "test response"}, "is_streaming": False}

storage.store_recording(request_hash, request_data, response_data)

# Verify file storage and retrieval
retrieved = storage.find_recording(request_hash)
assert retrieved is not None
assert retrieved["request"]["model"] == "llama3.2:3b"
assert retrieved["response"]["body"]["content"] == "test response"

async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that recording mode captures and stores responses."""

async def mock_create(*args, **kwargs):
return real_openai_chat_response

temp_storage_dir = temp_storage_dir / "test_recording_mode"
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)

# Verify the response was returned correctly
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."

# Verify recording was stored
storage = ResponseStorage(temp_storage_dir)
assert storage._get_test_dir().exists()

async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that replay mode returns stored responses without making real calls."""

async def mock_create(*args, **kwargs):
return real_openai_chat_response

temp_storage_dir = temp_storage_dir / "test_replay_mode"
# First, record a response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)

# Now test replay mode - should not call the original method
with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)

# Verify we got the recorded response
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."

# Verify the original method was NOT called
mock_create_patch.assert_not_called()

async def test_replay_missing_recording(self, temp_storage_dir):
"""Test that replay mode fails when no recording is found."""
temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
with patch("openai.resources.chat.completions.AsyncCompletions.create"):
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

with pytest.raises(RuntimeError, match="No recorded response found"):
await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
)

async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
"""Test recording and replay of embeddings calls."""

async def mock_create(*args, **kwargs):
return real_embeddings_response

temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
# Record
with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
with api_recording(mode=APIRecordingMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

response = await client.embeddings.create(
model="nomic-embed-text", input=["Hello world", "Test embedding"]
)

assert len(response.data) == 2

# Replay
with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
with api_recording(mode=APIRecordingMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

response = await client.embeddings.create(
model="nomic-embed-text", input=["Hello world", "Test embedding"]
)

# Verify we got the recorded response
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]

# Verify original method was not called
mock_create_patch.assert_not_called()

async def test_live_mode(self, real_openai_chat_response):
"""Test that live mode passes through to original methods."""

async def mock_create(*args, **kwargs):
return real_openai_chat_response

with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with api_recording(mode=APIRecordingMode.LIVE, storage_dir="foo"):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

response = await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
)

# Verify the response was returned
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
@ -1,382 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch

import pytest
from openai import NOT_GIVEN, AsyncOpenAI
from openai.types.model import Model as OpenAIModel

# Import the real Pydantic response types instead of using Mocks
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.testing.inference_recorder import (
InferenceMode,
ResponseStorage,
inference_recording,
normalize_request,
)


@pytest.fixture
def temp_storage_dir():
"""Create a temporary directory for test recordings."""
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)


@pytest.fixture
def real_openai_chat_response():
"""Real OpenAI chat completion response using proper Pydantic objects."""
return OpenAIChatCompletion(
id="chatcmpl-test123",
choices=[
OpenAIChoice(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant", content="Hello! I'm doing well, thank you for asking."
),
finish_reason="stop",
)
],
created=1234567890,
model="llama3.2:3b",
)


@pytest.fixture
def real_embeddings_response():
"""Real OpenAI embeddings response using proper Pydantic objects."""
return OpenAIEmbeddingsResponse(
object="list",
data=[
OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
],
model="nomic-embed-text",
usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
)


class TestInferenceRecording:
"""Test the inference recording system."""

def test_request_normalization(self):
"""Test that request normalization produces consistent hashes."""
# Test basic normalization
hash1 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)

# Same request should produce same hash
hash2 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)

assert hash1 == hash2

# Different content should produce different hash
hash3 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{
"model": "llama3.2:3b",
"messages": [{"role": "user", "content": "Different message"}],
"temperature": 0.7,
},
)

assert hash1 != hash3

def test_request_normalization_edge_cases(self):
"""Test request normalization is precise about request content."""
# Test that different whitespace produces different hashes (no normalization)
hash1 = normalize_request(
"POST",
"http://test/v1/chat/completions",
{},
{"messages": [{"role": "user", "content": "Hello world\n\n"}]},
)
hash2 = normalize_request(
"POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
)
assert hash1 != hash2  # Different whitespace should produce different hashes

# Test that different float precision produces different hashes (no rounding)
hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
assert hash3 != hash4  # Different precision should produce different hashes

def test_response_storage(self, temp_storage_dir):
"""Test the ResponseStorage class."""
temp_storage_dir = temp_storage_dir / "test_response_storage"
storage = ResponseStorage(temp_storage_dir)

# Test storing and retrieving a recording
request_hash = "test_hash_123"
request_data = {
"method": "POST",
"url": "http://localhost:11434/v1/chat/completions",
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b",
}
response_data = {"body": {"content": "test response"}, "is_streaming": False}

storage.store_recording(request_hash, request_data, response_data)

# Verify file storage and retrieval
retrieved = storage.find_recording(request_hash)
assert retrieved is not None
assert retrieved["request"]["model"] == "llama3.2:3b"
assert retrieved["response"]["body"]["content"] == "test response"

async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that recording mode captures and stores responses."""
temp_storage_dir = temp_storage_dir / "test_recording_mode"
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)

response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)

# Verify the response was returned correctly
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
client.chat.completions._post.assert_called_once()

# Verify recording was stored
storage = ResponseStorage(temp_storage_dir)
dir = storage._get_test_dir()
assert dir.exists()

async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that replay mode returns stored responses without making real calls."""
temp_storage_dir = temp_storage_dir / "test_replay_mode"
# First, record a response
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)

response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
client.chat.completions._post.assert_called_once()

# Now test replay mode - should not call the original method
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)

response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)

# Verify we got the recorded response
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."

# Verify the original method was NOT called
client.chat.completions._post.assert_not_called()

async def test_replay_mode_models(self, temp_storage_dir):
"""Test that replay mode returns stored responses without making real model listing calls."""

async def _async_iterator(models):
for model in models:
yield model

models = [
OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
OpenAIModel(id="bar", created=2, object="model", owned_by="test"),
]

expected_ids = {m.id for m in models}

temp_storage_dir = temp_storage_dir / "test_replay_mode_models"

# baseline - mock works without recording
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.models._get_api_list = Mock(return_value=_async_iterator(models))
assert {m.id async for m in client.models.list()} == expected_ids
client.models._get_api_list.assert_called_once()

# record the call
with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.models._get_api_list = Mock(return_value=_async_iterator(models))
assert {m.id async for m in client.models.list()} == expected_ids
client.models._get_api_list.assert_called_once()

# replay the call
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.models._get_api_list = Mock(return_value=_async_iterator(models))
assert {m.id async for m in client.models.list()} == expected_ids
client.models._get_api_list.assert_not_called()

async def test_replay_missing_recording(self, temp_storage_dir):
"""Test that replay mode fails when no recording is found."""
temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
with patch("openai.resources.chat.completions.AsyncCompletions.create"):
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

with pytest.raises(RuntimeError, match="No recorded response found"):
await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
)

async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
"""Test recording and replay of embeddings calls."""

# baseline - mock works without recording
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
response = await client.embeddings.create(
model=real_embeddings_response.model,
input=["Hello world", "Test embedding"],
encoding_format=NOT_GIVEN,
)
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
client.embeddings._post.assert_called_once()

temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
# Record
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)

response = await client.embeddings.create(
model=real_embeddings_response.model,
input=["Hello world", "Test embedding"],
encoding_format=NOT_GIVEN,
dimensions=NOT_GIVEN,
user=NOT_GIVEN,
)

assert len(response.data) == 2

# Replay
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)

response = await client.embeddings.create(
model=real_embeddings_response.model,
input=["Hello world", "Test embedding"],
)

# Verify we got the recorded response
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]

# Verify original method was not called
client.embeddings._post.assert_not_called()

async def test_completions_recording(self, temp_storage_dir):
real_completions_response = OpenAICompletion(
id="test_completion",
object="text_completion",
created=1234567890,
model="llama3.2:3b",
choices=[
{
"text": "Hello! I'm doing well, thank you for asking.",
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}
],
)

temp_storage_dir = temp_storage_dir / "test_completions_recording"

# baseline - mock works without recording
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.completions._post = AsyncMock(return_value=real_completions_response)
response = await client.completions.create(
model=real_completions_response.model,
prompt="Hello, how are you?",
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)
assert response.choices[0].text == real_completions_response.choices[0].text
client.completions._post.assert_called_once()

# Record
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.completions._post = AsyncMock(return_value=real_completions_response)

response = await client.completions.create(
model=real_completions_response.model,
prompt="Hello, how are you?",
temperature=0.7,
max_tokens=50,
user=NOT_GIVEN,
)

assert response.choices[0].text == real_completions_response.choices[0].text
client.completions._post.assert_called_once()

# Replay
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
client.completions._post = AsyncMock(return_value=real_completions_response)
response = await client.completions.create(
model=real_completions_response.model,
prompt="Hello, how are you?",
temperature=0.7,
max_tokens=50,
)
assert response.choices[0].text == real_completions_response.choices[0].text
client.completions._post.assert_not_called()

async def test_live_mode(self, real_openai_chat_response):
"""Test that live mode passes through to original methods."""

async def mock_create(*args, **kwargs):
return real_openai_chat_response

with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.LIVE, storage_dir="foo"):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

response = await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
)

# Verify the response was returned
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."