Merge branch 'main' into chroma

This commit is contained in:
Bwook (Byoungwook) Kim 2025-08-18 16:11:36 +09:00 committed by GitHub
commit c66ebae9b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
207 changed files with 15490 additions and 7927 deletions

View file

@ -1,9 +1,115 @@
# Llama Stack Tests
There are two obvious types of tests:
Llama Stack has multiple layers of testing done to ensure continuous functionality and prevent regressions to the codebase.
| Type | Location | Purpose |
|------|----------|---------|
| **Unit** | [`tests/unit/`](unit/README.md) | Fast, isolated component testing |
| **Integration** | [`tests/integration/`](integration/README.md) | End-to-end workflows with record-replay |
| Testing Type | Details |
|--------------|---------|
| Unit | [unit/README.md](unit/README.md) |
| Integration | [integration/README.md](integration/README.md) |
| Verification | [verifications/README.md](verifications/README.md) |
Both have their place. For unit tests, it is important to create minimal mocks and instead rely more on "fakes". Mocks are too brittle. In either case, tests must be very fast and reliable.
### Record-replay for integration tests
Testing AI applications end-to-end creates some challenges:
- **API costs** accumulate quickly during development and CI
- **Non-deterministic responses** make tests unreliable
- **Multiple providers** require testing the same logic across different APIs
Our solution: **Record real API responses once, replay them for fast, deterministic tests.** This is better than mocking because AI APIs have complex response structures and streaming behavior. Mocks can miss edge cases that real APIs exhibit. A single test can exercise underlying APIs in multiple complex ways making it really hard to mock.
This gives you:
- Cost control - No repeated API calls during development
- Speed - Instant test execution with cached responses
- Reliability - Consistent results regardless of external service state
- Provider coverage - Same tests work across OpenAI, Anthropic, local models, etc.
### Testing Quick Start
You can run the unit tests with:
```bash
uv run --group unit pytest -sv tests/unit/
```
For running integration tests, you must provide a few things:
- A stack config. This is a pointer to a stack. You have a few ways to point to a stack:
- **`server:<config>`** - automatically start a server with the given config (e.g., `server:starter`). This provides one-step testing by auto-starting the server if the port is available, or reusing an existing server if already running.
- **`server:<config>:<port>`** - same as above but with a custom port (e.g., `server:starter:8322`)
- a URL which points to a Llama Stack distribution server
- a distribution name (e.g., `starter`) or a path to a `run.yaml` file
- a comma-separated list of api=provider pairs, e.g. `inference=fireworks,safety=llama-guard,agents=meta-reference`. This is most useful for testing a single API surface.
- Whether you are using replay or live mode for inference. This is specified with the LLAMA_STACK_TEST_INFERENCE_MODE environment variable. The default mode currently is "live" -- that is certainly surprising, but we will fix this soon.
- Any API keys you need to use should be set in the environment, or can be passed in with the --env option.
You can run the integration tests in replay mode with:
```bash
# Run all tests with existing recordings
LLAMA_STACK_TEST_INFERENCE_MODE=replay \
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
uv run --group test \
pytest -sv tests/integration/ --stack-config=starter
```
If you don't specify LLAMA_STACK_TEST_INFERENCE_MODE, by default it will be in "live" mode -- that is, it will make real API calls.
```bash
# Test against live APIs
FIREWORKS_API_KEY=your_key pytest -sv tests/integration/inference --stack-config=starter
```
### Re-recording tests
#### Local Re-recording (Manual Setup Required)
If you want to re-record tests locally, you can do so with:
```bash
LLAMA_STACK_TEST_INFERENCE_MODE=record \
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
uv run --group test \
pytest -sv tests/integration/ --stack-config=starter -k "<appropriate test name>"
```
This will record new API responses and overwrite the existing recordings.
```{warning}
You must be careful when re-recording. CI workflows assume a specific setup for running the replay-mode tests. You must re-record the tests in the same way as the CI workflows. This means
- you need Ollama running and serving some specific models.
- you are using the `starter` distribution.
```
#### Remote Re-recording (Recommended)
**For easier re-recording without local setup**, use the automated recording workflow:
```bash
# Record tests for specific test subdirectories
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents,inference"
# Record with vision tests enabled
./scripts/github/schedule-record-workflow.sh --test-subdirs "inference" --run-vision-tests
# Record with specific provider
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents" --test-provider vllm
```
This script:
- 🚀 **Runs in GitHub Actions** - no local Ollama setup required
- 🔍 **Auto-detects your branch** and associated PR
- 🍴 **Works from forks** - handles repository context automatically
- ✅ **Commits recordings back** to your branch
**Prerequisites:**
- GitHub CLI: `brew install gh && gh auth login`
- jq: `brew install jq`
- Your branch pushed to a remote
**Supported providers:** `vllm`, `ollama`
### Next Steps
- [Integration Testing Guide](integration/README.md) - Detailed usage and configuration
- [Unit Testing Guide](unit/README.md) - Fast component testing

View file

@ -16,13 +16,10 @@ MCP_TOOLGROUP_ID = "mcp::localmcp"
def default_tools():
"""Default tools for backward compatibility."""
from mcp import types
from mcp.server.fastmcp import Context
async def greet_everyone(
url: str, ctx: Context
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
return [types.TextContent(type="text", text="Hello, world!")]
async def greet_everyone(url: str, ctx: Context) -> str:
return "Hello, world!"
async def get_boiling_point(liquid_name: str, celsius: bool = True) -> int:
"""
@ -45,7 +42,6 @@ def default_tools():
def dependency_tools():
"""Tools with natural dependencies for multi-turn testing."""
from mcp import types
from mcp.server.fastmcp import Context
async def get_user_id(username: str, ctx: Context) -> str:
@ -106,7 +102,7 @@ def dependency_tools():
else:
access = "no"
return [types.TextContent(type="text", text=access)]
return access
async def get_experiment_id(experiment_name: str, ctx: Context) -> str:
"""
@ -245,7 +241,6 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal
try:
yield {"server_url": server_url}
finally:
print("Telling SSE server to exit")
server_instance.should_exit = True
time.sleep(0.5)
@ -269,4 +264,3 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal
AppStatus.should_exit = False
AppStatus.should_exit_event = None
print("SSE server exited")

View file

@ -3,7 +3,7 @@ name = "llama-stack-api-weather"
version = "0.1.0"
description = "Weather API for Llama Stack"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.12"
dependencies = ["llama-stack", "pydantic"]
[build-system]

View file

@ -3,7 +3,7 @@ name = "llama-stack-provider-kaze"
version = "0.1.0"
description = "Kaze weather provider for Llama Stack"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.12"
dependencies = ["llama-stack", "pydantic", "aiohttp"]
[build-system]

View file

@ -1,6 +1,20 @@
# Llama Stack Integration Tests
# Integration Testing Guide
We use `pytest` for parameterizing and running tests. You can see all options with:
Integration tests verify complete workflows across different providers using Llama Stack's record-replay system.
## Quick Start
```bash
# Run all integration tests with existing recordings
LLAMA_STACK_TEST_INFERENCE_MODE=replay \
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
uv run --group test \
pytest -sv tests/integration/ --stack-config=starter
```
## Configuration Options
You can see all options with:
```bash
cd tests/integration
@ -10,11 +24,11 @@ pytest --help
Here are the most important options:
- `--stack-config`: specify the stack config to use. You have four ways to point to a stack:
- **`server:<config>`** - automatically start a server with the given config (e.g., `server:fireworks`). This provides one-step testing by auto-starting the server if the port is available, or reusing an existing server if already running.
- **`server:<config>:<port>`** - same as above but with a custom port (e.g., `server:together:8322`)
- **`server:<config>`** - automatically start a server with the given config (e.g., `server:starter`). This provides one-step testing by auto-starting the server if the port is available, or reusing an existing server if already running.
- **`server:<config>:<port>`** - same as above but with a custom port (e.g., `server:starter:8322`)
- a URL which points to a Llama Stack distribution server
- a template (e.g., `starter`) or a path to a `run.yaml` file
- a comma-separated list of api=provider pairs, e.g. `inference=fireworks,safety=llama-guard,agents=meta-reference`. This is most useful for testing a single API surface.
- a distribution name (e.g., `starter`) or a path to a `run.yaml` file
- a comma-separated list of api=provider pairs, e.g. `inference=ollama,safety=llama-guard,agents=meta-reference`. This is most useful for testing a single API surface.
- `--env`: set environment variables, e.g. --env KEY=value. this is a utility option to set environment variables required by various providers.
Model parameters can be influenced by the following options:
@ -32,85 +46,139 @@ if no model is specified.
### Testing against a Server
Run all text inference tests by auto-starting a server with the `fireworks` config:
Run all text inference tests by auto-starting a server with the `starter` config:
```bash
pytest -s -v tests/integration/inference/test_text_inference.py \
--stack-config=server:fireworks \
--text-model=meta-llama/Llama-3.1-8B-Instruct
OLLAMA_URL=http://localhost:11434 \
pytest -s -v tests/integration/inference/test_text_inference.py \
--stack-config=server:starter \
--text-model=ollama/llama3.2:3b-instruct-fp16 \
--embedding-model=sentence-transformers/all-MiniLM-L6-v2
```
Run tests with auto-server startup on a custom port:
```bash
pytest -s -v tests/integration/inference/ \
--stack-config=server:together:8322 \
--text-model=meta-llama/Llama-3.1-8B-Instruct
```
Run multiple test suites with auto-server (eliminates manual server management):
```bash
# Auto-start server and run all integration tests
export FIREWORKS_API_KEY=<your_key>
pytest -s -v tests/integration/inference/ tests/integration/safety/ tests/integration/agents/ \
--stack-config=server:fireworks \
--text-model=meta-llama/Llama-3.1-8B-Instruct
OLLAMA_URL=http://localhost:11434 \
pytest -s -v tests/integration/inference/ \
--stack-config=server:starter:8322 \
--text-model=ollama/llama3.2:3b-instruct-fp16 \
--embedding-model=sentence-transformers/all-MiniLM-L6-v2
```
### Testing with Library Client
Run all text inference tests with the `starter` distribution using the `together` provider:
The library client constructs the Stack "in-process" instead of using a server. This is useful during the iterative development process since you don't need to constantly start and stop servers.
You can do this by simply using `--stack-config=starter` instead of `--stack-config=server:starter`.
### Using ad-hoc distributions
Sometimes, you may want to make up a distribution on the fly. This is useful for testing a single provider or a single API or a small combination of providers. You can do so by specifying a comma-separated list of api=provider pairs to the `--stack-config` option, e.g. `inference=remote::ollama,safety=inline::llama-guard,agents=inline::meta-reference`.
```bash
ENABLE_TOGETHER=together pytest -s -v tests/integration/inference/test_text_inference.py \
--stack-config=starter \
--text-model=meta-llama/Llama-3.1-8B-Instruct
```
Run all text inference tests with the `starter` distribution using the `together` provider and `meta-llama/Llama-3.1-8B-Instruct`:
```bash
ENABLE_TOGETHER=together pytest -s -v tests/integration/inference/test_text_inference.py \
--stack-config=starter \
--text-model=meta-llama/Llama-3.1-8B-Instruct
```
Running all inference tests for a number of models using the `together` provider:
```bash
TEXT_MODELS=meta-llama/Llama-3.1-8B-Instruct,meta-llama/Llama-3.1-70B-Instruct
VISION_MODELS=meta-llama/Llama-3.2-11B-Vision-Instruct
EMBEDDING_MODELS=all-MiniLM-L6-v2
ENABLE_TOGETHER=together
export TOGETHER_API_KEY=<together_api_key>
pytest -s -v tests/integration/inference/ \
--stack-config=together \
--stack-config=inference=remote::ollama,safety=inline::llama-guard,agents=inline::meta-reference \
--text-model=$TEXT_MODELS \
--vision-model=$VISION_MODELS \
--embedding-model=$EMBEDDING_MODELS
```
Same thing but instead of using the distribution, use an adhoc stack with just one provider (`fireworks` for inference):
Another example: Running Vector IO tests for embedding models:
```bash
export FIREWORKS_API_KEY=<fireworks_api_key>
pytest -s -v tests/integration/inference/ \
--stack-config=inference=fireworks \
--text-model=$TEXT_MODELS \
--vision-model=$VISION_MODELS \
--embedding-model=$EMBEDDING_MODELS
```
Running Vector IO tests for a number of embedding models:
```bash
EMBEDDING_MODELS=all-MiniLM-L6-v2
pytest -s -v tests/integration/vector_io/ \
--stack-config=inference=sentence-transformers,vector_io=sqlite-vec \
--embedding-model=$EMBEDDING_MODELS
--stack-config=inference=inline::sentence-transformers,vector_io=inline::sqlite-vec \
--embedding-model=sentence-transformers/all-MiniLM-L6-v2
```
## Recording Modes
The testing system supports three modes controlled by environment variables:
### LIVE Mode (Default)
Tests make real API calls:
```bash
LLAMA_STACK_TEST_INFERENCE_MODE=live pytest tests/integration/
```
### RECORD Mode
Captures API interactions for later replay:
```bash
LLAMA_STACK_TEST_INFERENCE_MODE=record \
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
pytest tests/integration/inference/test_new_feature.py
```
### REPLAY Mode
Uses cached responses instead of making API calls:
```bash
LLAMA_STACK_TEST_INFERENCE_MODE=replay \
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
pytest tests/integration/
```
Note that right now you must specify the recording directory. This is because different tests use different recording directories and we don't (yet) have a fool-proof way to map a test to a recording directory. We are working on this.
## Managing Recordings
### Viewing Recordings
```bash
# See what's recorded
sqlite3 recordings/index.sqlite "SELECT endpoint, model, timestamp FROM recordings;"
# Inspect specific response
cat recordings/responses/abc123.json | jq '.'
```
### Re-recording Tests
#### Remote Re-recording (Recommended)
Use the automated workflow script for easier re-recording:
```bash
./scripts/github/schedule-record-workflow.sh --test-subdirs "inference,agents"
```
See the [main testing guide](../README.md#remote-re-recording-recommended) for full details.
#### Local Re-recording
```bash
# Re-record specific tests
LLAMA_STACK_TEST_INFERENCE_MODE=record \
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
pytest -s -v --stack-config=server:starter tests/integration/inference/test_modified.py
```
Note that when re-recording tests, you must use a Stack pointing to a server (i.e., `server:starter`). This subtlety exists because the set of tests run in server are a superset of the set of tests run in the library client.
## Writing Tests
### Basic Test Pattern
```python
def test_basic_completion(llama_stack_client, text_model_id):
response = llama_stack_client.inference.completion(
model_id=text_model_id,
content=CompletionMessage(role="user", content="Hello"),
)
# Test structure, not AI output quality
assert response.completion_message is not None
assert isinstance(response.completion_message.content, str)
assert len(response.completion_message.content) > 0
```
### Provider-Specific Tests
```python
def test_asymmetric_embeddings(llama_stack_client, embedding_model_id):
if embedding_model_id not in MODELS_SUPPORTING_TASK_TYPE:
pytest.skip(f"Model {embedding_model_id} doesn't support task types")
query_response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id,
contents=["What is machine learning?"],
task_type="query",
)
assert query_response.embeddings is not None
```

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Shared pytest fixtures for batch tests."""
import json
import time
import warnings
from contextlib import contextmanager
from io import BytesIO
import pytest
from llama_stack.apis.files import OpenAIFilePurpose
class BatchHelper:
"""Helper class for creating and managing batch input files."""
def __init__(self, client):
"""Initialize with either a batch_client or openai_client."""
self.client = client
@contextmanager
def create_file(self, content: str | list[dict], filename_prefix="batch_input"):
"""Context manager for creating and cleaning up batch input files.
Args:
content: Either a list of batch request dictionaries or raw string content
filename_prefix: Prefix for the generated filename (or full filename if content is string)
Yields:
The uploaded file object
"""
if isinstance(content, str):
# Handle raw string content (e.g., malformed JSONL, empty files)
file_content = content.encode("utf-8")
else:
# Handle list of batch request dictionaries
jsonl_content = "\n".join(json.dumps(req) for req in content)
file_content = jsonl_content.encode("utf-8")
filename = filename_prefix if filename_prefix.endswith(".jsonl") else f"{filename_prefix}.jsonl"
with BytesIO(file_content) as file_buffer:
file_buffer.name = filename
uploaded_file = self.client.files.create(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
try:
yield uploaded_file
finally:
try:
self.client.files.delete(uploaded_file.id)
except Exception:
warnings.warn(
f"Failed to cleanup file {uploaded_file.id}: {uploaded_file.filename}",
stacklevel=2,
)
def wait_for(
self,
batch_id: str,
max_wait_time: int = 60,
sleep_interval: int | None = None,
expected_statuses: set[str] | None = None,
timeout_action: str = "fail",
):
"""Wait for a batch to reach a terminal status.
Args:
batch_id: The batch ID to monitor
max_wait_time: Maximum time to wait in seconds (default: 60 seconds)
sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s)
expected_statuses: Set of expected terminal statuses (default: {"completed"})
timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip)
Returns:
The final batch object
Raises:
pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail"
pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status
"""
if sleep_interval is None:
# Default to 1/10th of max_wait_time, with min 1s and max 15s
sleep_interval = max(1, min(15, max_wait_time // 10))
if expected_statuses is None:
expected_statuses = {"completed"}
terminal_statuses = {"completed", "failed", "cancelled", "expired"}
unexpected_statuses = terminal_statuses - expected_statuses
start_time = time.time()
while time.time() - start_time < max_wait_time:
current_batch = self.client.batches.retrieve(batch_id)
if current_batch.status in expected_statuses:
return current_batch
elif current_batch.status in unexpected_statuses:
error_msg = f"Batch reached unexpected status: {current_batch.status}"
if timeout_action == "skip":
pytest.skip(error_msg)
else:
pytest.fail(error_msg)
time.sleep(sleep_interval)
timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds"
if timeout_action == "skip":
pytest.skip(timeout_msg)
else:
pytest.fail(timeout_msg)
@pytest.fixture
def batch_helper(openai_client):
"""Fixture that provides a BatchHelper instance for OpenAI client."""
return BatchHelper(openai_client)

View file

@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Integration tests for the Llama Stack batch processing functionality.
This module contains comprehensive integration tests for the batch processing API,
using the OpenAI-compatible client interface for consistency.
Test Categories:
1. Core Batch Operations:
- test_batch_creation_and_retrieval: Comprehensive batch creation, structure validation, and retrieval
- test_batch_listing: Basic batch listing functionality
- test_batch_immediate_cancellation: Batch cancellation workflow
# TODO: cancel during processing
2. End-to-End Processing:
- test_batch_e2e_chat_completions: Full chat completions workflow with output and error validation
Note: Error conditions and edge cases are primarily tested in test_batches_errors.py
for better organization and separation of concerns.
CLEANUP WARNING: These tests currently create batches that are not automatically
cleaned up after test completion. This may lead to resource accumulation over
multiple test runs. Only test_batch_immediate_cancellation properly cancels its batch.
The test_batch_e2e_chat_completions test does clean up its output and error files.
"""
import json
class TestBatchesIntegration:
"""Integration tests for the batches API."""
def test_batch_creation_and_retrieval(self, openai_client, batch_helper, text_model_id):
"""Test comprehensive batch creation and retrieval scenarios."""
test_metadata = {
"test_type": "comprehensive",
"purpose": "creation_and_retrieval_test",
"version": "1.0",
"tags": "test,batch",
}
batch_requests = [
{
"custom_id": "request-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests, "batch_creation_test") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
metadata=test_metadata,
)
assert batch.endpoint == "/v1/chat/completions"
assert batch.input_file_id == uploaded_file.id
assert batch.completion_window == "24h"
assert batch.metadata == test_metadata
retrieved_batch = openai_client.batches.retrieve(batch.id)
assert retrieved_batch.id == batch.id
assert retrieved_batch.object == batch.object
assert retrieved_batch.endpoint == batch.endpoint
assert retrieved_batch.input_file_id == batch.input_file_id
assert retrieved_batch.completion_window == batch.completion_window
assert retrieved_batch.metadata == batch.metadata
def test_batch_listing(self, openai_client, batch_helper, text_model_id):
"""
Test batch listing.
This test creates multiple batches and verifies that they can be listed.
It also deletes the input files before execution, which means the batches
will appear as failed due to missing input files. This is expected and
a good thing, because it means no inference is performed.
"""
batch_ids = []
for i in range(2):
batch_requests = [
{
"custom_id": f"request-{i}",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": f"Hello {i}"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests, f"batch_input_{i}") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
batch_ids.append(batch.id)
batch_list = openai_client.batches.list()
assert isinstance(batch_list.data, list)
listed_batch_ids = {b.id for b in batch_list.data}
for batch_id in batch_ids:
assert batch_id in listed_batch_ids
def test_batch_immediate_cancellation(self, openai_client, batch_helper, text_model_id):
"""Test immediate batch cancellation."""
batch_requests = [
{
"custom_id": "request-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests) as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
# hopefully cancel the batch before it completes
cancelling_batch = openai_client.batches.cancel(batch.id)
assert cancelling_batch.status in ["cancelling", "cancelled"]
assert isinstance(cancelling_batch.cancelling_at, int), (
f"cancelling_at should be int, got {type(cancelling_batch.cancelling_at)}"
)
final_batch = batch_helper.wait_for(
batch.id,
max_wait_time=3 * 60, # often takes 10-11 minutes, give it 3 min
expected_statuses={"cancelled"},
timeout_action="skip",
)
assert final_batch.status == "cancelled"
assert isinstance(final_batch.cancelled_at, int), (
f"cancelled_at should be int, got {type(final_batch.cancelled_at)}"
)
def test_batch_e2e_chat_completions(self, openai_client, batch_helper, text_model_id):
"""Test end-to-end batch processing for chat completions with both successful and failed operations."""
batch_requests = [
{
"custom_id": "success-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Say hello"}],
"max_tokens": 20,
},
},
{
"custom_id": "error-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"rolez": "user", "contentz": "This should fail"}], # Invalid keys to trigger error
# note: ollama does not validate max_tokens values or the "role" key, so they won't trigger an error
},
},
]
with batch_helper.create_file(batch_requests) as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={"test": "e2e_success_and_errors_test"},
)
final_batch = batch_helper.wait_for(
batch.id,
max_wait_time=3 * 60, # often takes 2-3 minutes
expected_statuses={"completed"},
timeout_action="skip",
)
# Expecting a completed batch with both successful and failed requests
# Batch(id='batch_xxx',
# completion_window='24h',
# created_at=...,
# endpoint='/v1/chat/completions',
# input_file_id='file-xxx',
# object='batch',
# status='completed',
# output_file_id='file-xxx',
# error_file_id='file-xxx',
# request_counts=BatchRequestCounts(completed=1, failed=1, total=2))
assert final_batch.status == "completed"
assert final_batch.request_counts is not None
assert final_batch.request_counts.total == 2
assert final_batch.request_counts.completed == 1
assert final_batch.request_counts.failed == 1
assert final_batch.output_file_id is not None, "Output file should exist for successful requests"
output_content = openai_client.files.content(final_batch.output_file_id)
if isinstance(output_content, str):
output_text = output_content
else:
output_text = output_content.content.decode("utf-8")
output_lines = output_text.strip().split("\n")
for line in output_lines:
result = json.loads(line)
assert "id" in result
assert "custom_id" in result
assert result["custom_id"] == "success-1"
assert "response" in result
assert result["response"]["status_code"] == 200
assert "body" in result["response"]
assert "choices" in result["response"]["body"]
assert final_batch.error_file_id is not None, "Error file should exist for failed requests"
error_content = openai_client.files.content(final_batch.error_file_id)
if isinstance(error_content, str):
error_text = error_content
else:
error_text = error_content.content.decode("utf-8")
error_lines = error_text.strip().split("\n")
for line in error_lines:
result = json.loads(line)
assert "id" in result
assert "custom_id" in result
assert result["custom_id"] == "error-1"
assert "error" in result
error = result["error"]
assert error is not None
assert "code" in error or "message" in error, "Error should have code or message"
deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
assert deleted_output_file.deleted, f"Output file {final_batch.output_file_id} was not deleted successfully"
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"

View file

@ -0,0 +1,693 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Error handling and edge case tests for the Llama Stack batch processing functionality.
This module focuses exclusively on testing error conditions, validation failures,
and edge cases for batch operations to ensure robust error handling and graceful
degradation.
Test Categories:
1. File and Input Validation:
- test_batch_nonexistent_file_id: Handling invalid file IDs
- test_batch_malformed_jsonl: Processing malformed JSONL input files
- test_file_malformed_batch_file: Handling malformed files at upload time
- test_batch_missing_required_fields: Validation of required request fields
2. API Endpoint and Model Validation:
- test_batch_invalid_endpoint: Invalid endpoint handling during creation
- test_batch_error_handling_invalid_model: Error handling with nonexistent models
- test_batch_endpoint_mismatch: Validation of endpoint/URL consistency
3. Batch Lifecycle Error Handling:
- test_batch_retrieve_nonexistent: Retrieving non-existent batches
- test_batch_cancel_nonexistent: Cancelling non-existent batches
- test_batch_cancel_completed: Attempting to cancel completed batches
4. Parameter and Configuration Validation:
- test_batch_invalid_completion_window: Invalid completion window values
- test_batch_invalid_metadata_types: Invalid metadata type validation
- test_batch_missing_required_body_fields: Validation of required fields in request body
5. Feature Restriction and Compatibility:
- test_batch_streaming_not_supported: Streaming request rejection
- test_batch_mixed_streaming_requests: Mixed streaming/non-streaming validation
Note: Core functionality and OpenAI compatibility tests are located in
test_batches_integration.py for better organization and separation of concerns.
CLEANUP WARNING: These tests create batches to test error conditions but do not
automatically clean them up after test completion. While most error tests create
batches that fail quickly, some may create valid batches that consume resources.
"""
import pytest
from openai import BadRequestError, ConflictError, NotFoundError
class TestBatchesErrorHandling:
"""Error handling and edge case tests for the batches API using OpenAI client."""
def test_batch_nonexistent_file_id(self, openai_client, batch_helper):
"""Test batch creation with nonexistent input file ID."""
batch = openai_client.batches.create(
input_file_id="file-nonexistent-xyz",
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# BatchError(
# code='invalid_request',
# line=None,
# message='Cannot find file ..., or organization ... does not have access to it.',
# param='file_id')
# ], object='list'),
# failed_at=1754566971,
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 1
error = final_batch.errors.data[0]
assert error.code == "invalid_request"
assert "cannot find file" in error.message.lower()
def test_batch_invalid_endpoint(self, openai_client, batch_helper, text_model_id):
"""Test batch creation with invalid endpoint."""
batch_requests = [
{
"custom_id": "invalid-endpoint",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests) as uploaded_file:
with pytest.raises(BadRequestError) as exc_info:
openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/invalid/endpoint",
completion_window="24h",
)
# Expected -
# Error code: 400 - {
# 'error': {
# 'message': "Invalid value: '/v1/invalid/endpoint'. Supported values are: '/v1/chat/completions', '/v1/completions', '/v1/embeddings', and '/v1/responses'.",
# 'type': 'invalid_request_error',
# 'param': 'endpoint',
# 'code': 'invalid_value'
# }
# }
error_msg = str(exc_info.value).lower()
assert exc_info.value.status_code == 400
assert "invalid value" in error_msg
assert "/v1/invalid/endpoint" in error_msg
assert "supported values" in error_msg
assert "endpoint" in error_msg
assert "invalid_value" in error_msg
def test_batch_malformed_jsonl(self, openai_client, batch_helper):
"""
Test batch with malformed JSONL input.
The /v1/files endpoint requires valid JSONL format, so we provide a well formed line
before a malformed line to ensure we get to the /v1/batches validation stage.
"""
with batch_helper.create_file(
"""{"custom_id": "valid", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test"}}
{invalid json here""",
"malformed_batch_input.jsonl",
) as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# ...,
# BatchError(code='invalid_json_line',
# line=2,
# message='This line is not parseable as valid JSON.',
# param=None)
# ], object='list'),
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) > 0
error = final_batch.errors.data[-1] # get last error because first may be about the "test" model
assert error.code == "invalid_json_line"
assert error.line == 2
assert "not" in error.message.lower()
assert "valid json" in error.message.lower()
@pytest.mark.xfail(reason="Not all file providers validate content")
@pytest.mark.parametrize("batch_requests", ["", "{malformed json"], ids=["empty", "malformed"])
def test_file_malformed_batch_file(self, openai_client, batch_helper, batch_requests):
"""Test file upload with malformed content."""
with pytest.raises(BadRequestError) as exc_info:
with batch_helper.create_file(batch_requests, "malformed_batch_input_file.jsonl"):
# /v1/files rejects the file, we don't get to batch creation
pass
error_msg = str(exc_info.value).lower()
assert exc_info.value.status_code == 400
assert "invalid file format" in error_msg
assert "jsonl" in error_msg
def test_batch_retrieve_nonexistent(self, openai_client):
"""Test retrieving nonexistent batch."""
with pytest.raises(NotFoundError) as exc_info:
openai_client.batches.retrieve("batch-nonexistent-xyz")
error_msg = str(exc_info.value).lower()
assert exc_info.value.status_code == 404
assert "no batch found" in error_msg or "not found" in error_msg
def test_batch_cancel_nonexistent(self, openai_client):
"""Test cancelling nonexistent batch."""
with pytest.raises(NotFoundError) as exc_info:
openai_client.batches.cancel("batch-nonexistent-xyz")
error_msg = str(exc_info.value).lower()
assert exc_info.value.status_code == 404
assert "no batch found" in error_msg or "not found" in error_msg
def test_batch_cancel_completed(self, openai_client, batch_helper, text_model_id):
"""Test cancelling already completed batch."""
batch_requests = [
{
"custom_id": "cancel-completed",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Quick test"}],
"max_tokens": 5,
},
}
]
with batch_helper.create_file(batch_requests, "cancel_test_batch_input") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(
batch.id,
max_wait_time=3 * 60, # often take 10-11 min, give it 3 min
expected_statuses={"completed"},
timeout_action="skip",
)
deleted_file = openai_client.files.delete(final_batch.output_file_id)
assert deleted_file.deleted, f"File {final_batch.output_file_id} was not deleted successfully"
with pytest.raises(ConflictError) as exc_info:
openai_client.batches.cancel(batch.id)
# Expecting -
# Error code: 409 - {
# 'error': {
# 'message': "Cannot cancel a batch with status 'completed'.",
# 'type': 'invalid_request_error',
# 'param': None,
# 'code': None
# }
# }
#
# NOTE: Same for "failed", cancelling "cancelled" batches is allowed
error_msg = str(exc_info.value).lower()
assert exc_info.value.status_code == 409
assert "cannot cancel" in error_msg
def test_batch_missing_required_fields(self, openai_client, batch_helper, text_model_id):
"""Test batch with requests missing required fields."""
batch_requests = [
{
# Missing custom_id
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "No custom_id"}],
"max_tokens": 10,
},
},
{
"custom_id": "no-method",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "No method"}],
"max_tokens": 10,
},
},
{
"custom_id": "no-url",
"method": "POST",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "No URL"}],
"max_tokens": 10,
},
},
{
"custom_id": "no-body",
"method": "POST",
"url": "/v1/chat/completions",
},
]
with batch_helper.create_file(batch_requests, "missing_fields_batch_input") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(
# data=[
# BatchError(
# code='missing_required_parameter',
# line=1,
# message="Missing required parameter: 'custom_id'.",
# param='custom_id'
# ),
# BatchError(
# code='missing_required_parameter',
# line=2,
# message="Missing required parameter: 'method'.",
# param='method'
# ),
# BatchError(
# code='missing_required_parameter',
# line=3,
# message="Missing required parameter: 'url'.",
# param='url'
# ),
# BatchError(
# code='missing_required_parameter',
# line=4,
# message="Missing required parameter: 'body'.",
# param='body'
# )
# ], object='list'),
# failed_at=1754566945,
# ...)
# )
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 4
no_custom_id_error = final_batch.errors.data[0]
assert no_custom_id_error.code == "missing_required_parameter"
assert no_custom_id_error.line == 1
assert "missing" in no_custom_id_error.message.lower()
assert "custom_id" in no_custom_id_error.message.lower()
no_method_error = final_batch.errors.data[1]
assert no_method_error.code == "missing_required_parameter"
assert no_method_error.line == 2
assert "missing" in no_method_error.message.lower()
assert "method" in no_method_error.message.lower()
no_url_error = final_batch.errors.data[2]
assert no_url_error.code == "missing_required_parameter"
assert no_url_error.line == 3
assert "missing" in no_url_error.message.lower()
assert "url" in no_url_error.message.lower()
no_body_error = final_batch.errors.data[3]
assert no_body_error.code == "missing_required_parameter"
assert no_body_error.line == 4
assert "missing" in no_body_error.message.lower()
assert "body" in no_body_error.message.lower()
def test_batch_invalid_completion_window(self, openai_client, batch_helper, text_model_id):
"""Test batch creation with invalid completion window."""
batch_requests = [
{
"custom_id": "invalid-completion-window",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests) as uploaded_file:
for window in ["1h", "48h", "invalid", ""]:
with pytest.raises(BadRequestError) as exc_info:
openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window=window,
)
assert exc_info.value.status_code == 400
error_msg = str(exc_info.value).lower()
assert "error" in error_msg
assert "completion_window" in error_msg
def test_batch_streaming_not_supported(self, openai_client, batch_helper, text_model_id):
"""Test that streaming responses are not supported in batches."""
batch_requests = [
{
"custom_id": "streaming-test",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
"stream": True, # Not supported
},
}
]
with batch_helper.create_file(batch_requests, "streaming_batch_input") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# BatchError(code='streaming_unsupported',
# line=1,
# message='Chat Completions: Streaming is not supported in the Batch API.',
# param='body.stream')
# ], object='list'),
# failed_at=1754566965,
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 1
error = final_batch.errors.data[0]
assert error.code == "streaming_unsupported"
assert error.line == 1
assert "streaming" in error.message.lower()
assert "not supported" in error.message.lower()
assert error.param == "body.stream"
assert final_batch.failed_at is not None
def test_batch_mixed_streaming_requests(self, openai_client, batch_helper, text_model_id):
"""
Test batch with mixed streaming and non-streaming requests.
This is distinct from test_batch_streaming_not_supported, which tests a single
streaming request, to ensure an otherwise valid batch fails when a single
streaming request is included.
"""
batch_requests = [
{
"custom_id": "valid-non-streaming-request",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello without streaming"}],
"max_tokens": 10,
},
},
{
"custom_id": "streaming-request",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello with streaming"}],
"max_tokens": 10,
"stream": True, # Not supported
},
},
]
with batch_helper.create_file(batch_requests, "mixed_streaming_batch_input") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# BatchError(
# code='streaming_unsupported',
# line=2,
# message='Chat Completions: Streaming is not supported in the Batch API.',
# param='body.stream')
# ], object='list'),
# failed_at=1754574442,
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 1
error = final_batch.errors.data[0]
assert error.code == "streaming_unsupported"
assert error.line == 2
assert "streaming" in error.message.lower()
assert "not supported" in error.message.lower()
assert error.param == "body.stream"
assert final_batch.failed_at is not None
def test_batch_endpoint_mismatch(self, openai_client, batch_helper, text_model_id):
"""Test batch creation with mismatched endpoint and request URL."""
batch_requests = [
{
"custom_id": "endpoint-mismatch",
"method": "POST",
"url": "/v1/embeddings", # Different from batch endpoint
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
},
}
]
with batch_helper.create_file(batch_requests, "endpoint_mismatch_batch_input") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions", # Different from request URL
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# BatchError(
# code='invalid_url',
# line=1,
# message='The URL provided for this request does not match the batch endpoint.',
# param='url')
# ], object='list'),
# failed_at=1754566972,
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 1
error = final_batch.errors.data[0]
assert error.line == 1
assert error.code == "invalid_url"
assert "does not match" in error.message.lower()
assert "endpoint" in error.message.lower()
assert final_batch.failed_at is not None
def test_batch_error_handling_invalid_model(self, openai_client, batch_helper):
"""Test batch error handling with invalid model."""
batch_requests = [
{
"custom_id": "invalid-model",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "nonexistent-model-xyz",
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests) as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# BatchError(code='model_not_found',
# line=1,
# message="The provided model 'nonexistent-model-xyz' is not supported by the Batch API.",
# param='body.model')
# ], object='list'),
# failed_at=1754566978,
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 1
error = final_batch.errors.data[0]
assert error.line == 1
assert error.code == "model_not_found"
assert "not supported" in error.message.lower()
assert error.param == "body.model"
assert final_batch.failed_at is not None
def test_batch_missing_required_body_fields(self, openai_client, batch_helper, text_model_id):
"""Test batch with requests missing required fields in body (model and messages)."""
batch_requests = [
{
"custom_id": "missing-model",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
# Missing model field
"messages": [{"role": "user", "content": "Hello without model"}],
"max_tokens": 10,
},
},
{
"custom_id": "missing-messages",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
# Missing messages field
"max_tokens": 10,
},
},
]
with batch_helper.create_file(batch_requests, "missing_body_fields_batch_input") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# BatchError(
# code='invalid_request',
# line=1,
# message='Model parameter is required.',
# param='body.model'),
# BatchError(
# code='invalid_request',
# line=2,
# message='Messages parameter is required.',
# param='body.messages')
# ], object='list'),
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 2
model_error = final_batch.errors.data[0]
assert model_error.line == 1
assert "model" in model_error.message.lower()
assert model_error.param == "body.model"
messages_error = final_batch.errors.data[1]
assert messages_error.line == 2
assert "messages" in messages_error.message.lower()
assert messages_error.param == "body.messages"
assert final_batch.failed_at is not None
def test_batch_invalid_metadata_types(self, openai_client, batch_helper, text_model_id):
"""Test batch creation with invalid metadata types (like lists)."""
batch_requests = [
{
"custom_id": "invalid-metadata-type",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests) as uploaded_file:
with pytest.raises(Exception) as exc_info:
openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={
"tags": ["tag1", "tag2"], # Invalid type, should be a string
},
)
# Expecting -
# Error code: 400 - {'error':
# {'message': "Invalid type for 'metadata.tags': expected a string,
# but got an array instead.",
# 'type': 'invalid_request_error', 'param': 'metadata.tags',
# 'code': 'invalid_type'}}
error_msg = str(exc_info.value).lower()
assert "400" in error_msg
assert "tags" in error_msg
assert "string" in error_msg

View file

@ -270,7 +270,7 @@ def openai_client(client_with_models):
@pytest.fixture(params=["openai_client", "client_with_models"])
def compat_client(request, client_with_models):
if isinstance(client_with_models, LlamaStackAsLibraryClient):
if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient):
# OpenAI client expects a server, so unless we also rewrite OpenAI client's requests
# to go via the Stack library client (which itself rewrites requests to be served inline),
# we cannot do this.

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import os
import re
from pathlib import Path
import pytest
@ -48,19 +47,6 @@ def _load_all_verification_configs():
return {"providers": all_provider_configs}
def case_id_generator(case):
"""Generate a test ID from the case's 'case_id' field, or use a default."""
case_id = case.get("case_id")
if isinstance(case_id, str | int):
return re.sub(r"\\W|^(?=\\d)", "_", str(case_id))
return None
# Helper to get the base test name from the request object
def get_base_test_name(request):
return request.node.originalname
# --- End Helper Functions ---

View file

@ -1,16 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
import yaml
def load_test_cases(name: str):
fixture_dir = Path(__file__).parent / "test_cases"
yaml_path = fixture_dir / f"{name}.yaml"
with open(yaml_path) as f:
return yaml.safe_load(f)

View file

@ -0,0 +1,262 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
import pytest
from pydantic import BaseModel
class ResponsesTestCase(BaseModel):
# Input can be a simple string or complex message structure
input: str | list[dict[str, Any]]
expected: str
# Tools as flexible dict structure (gets validated at runtime by the API)
tools: list[dict[str, Any]] | None = None
# Multi-turn conversations with input/output pairs
turns: list[tuple[str | list[dict[str, Any]], str]] | None = None
# File search specific fields
file_content: str | None = None
file_path: str | None = None
# Streaming flag
stream: bool | None = None
# Basic response test cases
basic_test_cases = [
pytest.param(
ResponsesTestCase(
input="Which planet do humans live on?",
expected="earth",
),
id="earth",
),
pytest.param(
ResponsesTestCase(
input="Which planet has rings around it with a name starting with letter S?",
expected="saturn",
),
id="saturn",
),
pytest.param(
ResponsesTestCase(
input=[
{
"role": "user",
"content": [
{
"type": "input_text",
"text": "what teams are playing in this image?",
}
],
},
{
"role": "user",
"content": [
{
"type": "input_image",
"image_url": "https://upload.wikimedia.org/wikipedia/commons/3/3b/LeBron_James_Layup_%28Cleveland_vs_Brooklyn_2018%29.jpg",
}
],
},
],
expected="brooklyn nets",
),
id="image_input",
),
]
# Multi-turn test cases
multi_turn_test_cases = [
pytest.param(
ResponsesTestCase(
input="", # Not used for multi-turn
expected="", # Not used for multi-turn
turns=[
("Which planet do humans live on?", "earth"),
("What is the name of the planet from your previous response?", "earth"),
],
),
id="earth",
),
]
# Web search test cases
web_search_test_cases = [
pytest.param(
ResponsesTestCase(
input="How many experts does the Llama 4 Maverick model have?",
tools=[{"type": "web_search", "search_context_size": "low"}],
expected="128",
),
id="llama_experts",
),
]
# File search test cases
file_search_test_cases = [
pytest.param(
ResponsesTestCase(
input="How many experts does the Llama 4 Maverick model have?",
tools=[{"type": "file_search"}],
expected="128",
file_content="Llama 4 Maverick has 128 experts",
),
id="llama_experts",
),
pytest.param(
ResponsesTestCase(
input="How many experts does the Llama 4 Maverick model have?",
tools=[{"type": "file_search"}],
expected="128",
file_path="pdfs/llama_stack_and_models.pdf",
),
id="llama_experts_pdf",
),
]
# MCP tool test cases
mcp_tool_test_cases = [
pytest.param(
ResponsesTestCase(
input="What is the boiling point of myawesomeliquid in Celsius?",
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
expected="Hello, world!",
),
id="boiling_point_tool",
),
]
# Custom tool test cases
custom_tool_test_cases = [
pytest.param(
ResponsesTestCase(
input="What's the weather like in San Francisco?",
tools=[
{
"type": "function",
"name": "get_weather",
"description": "Get current temperature for a given location.",
"parameters": {
"additionalProperties": False,
"properties": {
"location": {
"description": "City and country e.g. Bogotá, Colombia",
"type": "string",
}
},
"required": ["location"],
"type": "object",
},
}
],
expected="", # No specific expected output for custom tools
),
id="sf_weather",
),
]
# Image test cases
image_test_cases = [
pytest.param(
ResponsesTestCase(
input=[
{
"role": "user",
"content": [
{
"type": "input_text",
"text": "Identify the type of animal in this image.",
},
{
"type": "input_image",
"image_url": "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg",
},
],
},
],
expected="llama",
),
id="llama_image",
),
]
# Multi-turn image test cases
multi_turn_image_test_cases = [
pytest.param(
ResponsesTestCase(
input="", # Not used for multi-turn
expected="", # Not used for multi-turn
turns=[
(
[
{
"role": "user",
"content": [
{
"type": "input_text",
"text": "What type of animal is in this image? Please respond with a single word that starts with the letter 'L'.",
},
{
"type": "input_image",
"image_url": "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg",
},
],
},
],
"llama",
),
(
"What country do you find this animal primarily in? What continent?",
"peru",
),
],
),
id="llama_image_understanding",
),
]
# Multi-turn tool execution test cases
multi_turn_tool_execution_test_cases = [
pytest.param(
ResponsesTestCase(
input="I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Do this as a series of steps, where each step is a separate message. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response.",
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
expected="yes",
),
id="user_file_access_check",
),
pytest.param(
ResponsesTestCase(
input="I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me the boiling point in Celsius.",
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
expected="100°C",
),
id="experiment_results_lookup",
),
]
# Multi-turn tool execution streaming test cases
multi_turn_tool_execution_streaming_test_cases = [
pytest.param(
ResponsesTestCase(
input="Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response.",
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
expected="no",
stream=True,
),
id="user_permissions_workflow",
),
pytest.param(
ResponsesTestCase(
input="I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Return only one tool call per step. Please stream your analysis process.",
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
expected="85%",
stream=True,
),
id="experiment_analysis_streaming",
),
]

View file

@ -1,397 +0,0 @@
test_chat_basic:
test_name: test_chat_basic
test_params:
case:
- case_id: "earth"
input:
messages:
- content: Which planet do humans live on?
role: user
output: Earth
- case_id: "saturn"
input:
messages:
- content: Which planet has rings around it with a name starting with letter
S?
role: user
output: Saturn
test_chat_input_validation:
test_name: test_chat_input_validation
test_params:
case:
- case_id: "messages_missing"
input:
messages: []
output:
error:
status_code: 400
- case_id: "messages_role_invalid"
input:
messages:
- content: Which planet do humans live on?
role: fake_role
output:
error:
status_code: 400
- case_id: "tool_choice_invalid"
input:
messages:
- content: Which planet do humans live on?
role: user
tool_choice: invalid
output:
error:
status_code: 400
- case_id: "tool_choice_no_tools"
input:
messages:
- content: Which planet do humans live on?
role: user
tool_choice: required
output:
error:
status_code: 400
- case_id: "tools_type_invalid"
input:
messages:
- content: Which planet do humans live on?
role: user
tools:
- type: invalid
output:
error:
status_code: 400
test_chat_image:
test_name: test_chat_image
test_params:
case:
- input:
messages:
- content:
- text: What is in this image?
type: text
- image_url:
url: https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg
type: image_url
role: user
output: llama
test_chat_structured_output:
test_name: test_chat_structured_output
test_params:
case:
- case_id: "calendar"
input:
messages:
- content: Extract the event information.
role: system
- content: Alice and Bob are going to a science fair on Friday.
role: user
response_format:
json_schema:
name: calendar_event
schema:
properties:
date:
title: Date
type: string
name:
title: Name
type: string
participants:
items:
type: string
title: Participants
type: array
required:
- name
- date
- participants
title: CalendarEvent
type: object
type: json_schema
output: valid_calendar_event
- case_id: "math"
input:
messages:
- content: You are a helpful math tutor. Guide the user through the solution
step by step.
role: system
- content: how can I solve 8x + 7 = -23
role: user
response_format:
json_schema:
name: math_reasoning
schema:
$defs:
Step:
properties:
explanation:
title: Explanation
type: string
output:
title: Output
type: string
required:
- explanation
- output
title: Step
type: object
properties:
final_answer:
title: Final Answer
type: string
steps:
items:
$ref: '#/$defs/Step'
title: Steps
type: array
required:
- steps
- final_answer
title: MathReasoning
type: object
type: json_schema
output: valid_math_reasoning
test_tool_calling:
test_name: test_tool_calling
test_params:
case:
- input:
messages:
- content: You are a helpful assistant that can use tools to get information.
role: system
- content: What's the weather like in San Francisco?
role: user
tools:
- function:
description: Get current temperature for a given location.
name: get_weather
parameters:
additionalProperties: false
properties:
location:
description: "City and country e.g. Bogot\xE1, Colombia"
type: string
required:
- location
type: object
type: function
output: get_weather_tool_call
test_chat_multi_turn_tool_calling:
test_name: test_chat_multi_turn_tool_calling
test_params:
case:
- case_id: "text_then_weather_tool"
input:
messages:
- - role: user
content: "What's the name of the Sun in latin?"
- - role: user
content: "What's the weather like in San Francisco?"
tools:
- function:
description: Get the current weather
name: get_weather
parameters:
type: object
properties:
location:
description: "The city and state (both required), e.g. San Francisco, CA."
type: string
required: ["location"]
type: function
tool_responses:
- response: "{'response': '70 degrees and foggy'}"
expected:
- num_tool_calls: 0
answer: ["sol"]
- num_tool_calls: 1
tool_name: get_weather
tool_arguments:
location: "San Francisco, CA"
- num_tool_calls: 0
answer: ["foggy", "70 degrees"]
- case_id: "weather_tool_then_text"
input:
messages:
- - role: user
content: "What's the weather like in San Francisco?"
tools:
- function:
description: Get the current weather
name: get_weather
parameters:
type: object
properties:
location:
description: "The city and state (both required), e.g. San Francisco, CA."
type: string
required: ["location"]
type: function
tool_responses:
- response: "{'response': '70 degrees and foggy'}"
expected:
- num_tool_calls: 1
tool_name: get_weather
tool_arguments:
location: "San Francisco, CA"
- num_tool_calls: 0
answer: ["foggy", "70 degrees"]
- case_id: "add_product_tool"
input:
messages:
- - role: user
content: "Please add a new product with name 'Widget', price 19.99, in stock, and tags ['new', 'sale'] and give me the product id."
tools:
- function:
description: Add a new product
name: addProduct
parameters:
type: object
properties:
name:
description: "Name of the product"
type: string
price:
description: "Price of the product"
type: number
inStock:
description: "Availability status of the product."
type: boolean
tags:
description: "List of product tags"
type: array
items:
type: string
required: ["name", "price", "inStock"]
type: function
tool_responses:
- response: "{'response': 'Successfully added product with id: 123'}"
expected:
- num_tool_calls: 1
tool_name: addProduct
tool_arguments:
name: "Widget"
price: 19.99
inStock: true
tags:
- "new"
- "sale"
- num_tool_calls: 0
answer: ["123", "product id: 123"]
- case_id: "get_then_create_event_tool"
input:
messages:
- - role: system
content: "Todays date is 2025-03-01."
- role: user
content: "Do i have any meetings on March 3rd at 10 am? Yes or no?"
- - role: user
content: "Alright then, Create an event named 'Team Building', scheduled for that time same time, in the 'Main Conference Room' and add Alice, Bob, Charlie to it. Give me the created event id."
tools:
- function:
description: Create a new event
name: create_event
parameters:
type: object
properties:
name:
description: "Name of the event"
type: string
date:
description: "Date of the event in ISO format"
type: string
time:
description: "Event Time (HH:MM)"
type: string
location:
description: "Location of the event"
type: string
participants:
description: "List of participant names"
type: array
items:
type: string
required: ["name", "date", "time", "location", "participants"]
type: function
- function:
description: Get an event by date and time
name: get_event
parameters:
type: object
properties:
date:
description: "Date of the event in ISO format"
type: string
time:
description: "Event Time (HH:MM)"
type: string
required: ["date", "time"]
type: function
tool_responses:
- response: "{'response': 'No events found for 2025-03-03 at 10:00'}"
- response: "{'response': 'Successfully created new event with id: e_123'}"
expected:
- num_tool_calls: 1
tool_name: get_event
tool_arguments:
date: "2025-03-03"
time: "10:00"
- num_tool_calls: 0
answer: ["no", "no events found", "no meetings"]
- num_tool_calls: 1
tool_name: create_event
tool_arguments:
name: "Team Building"
date: "2025-03-03"
time: "10:00"
location: "Main Conference Room"
participants:
- "Alice"
- "Bob"
- "Charlie"
- num_tool_calls: 0
answer: ["e_123", "event id: e_123"]
- case_id: "compare_monthly_expense_tool"
input:
messages:
- - role: system
content: "Todays date is 2025-03-01."
- role: user
content: "what was my monthly expense in Jan of this year?"
- - role: user
content: "Was it less than Feb of last year? Only answer with yes or no."
tools:
- function:
description: Get monthly expense summary
name: getMonthlyExpenseSummary
parameters:
type: object
properties:
month:
description: "Month of the year (1-12)"
type: integer
year:
description: "Year"
type: integer
required: ["month", "year"]
type: function
tool_responses:
- response: "{'response': 'Total expenses for January 2025: $1000'}"
- response: "{'response': 'Total expenses for February 2024: $2000'}"
expected:
- num_tool_calls: 1
tool_name: getMonthlyExpenseSummary
tool_arguments:
month: 1
year: 2025
- num_tool_calls: 0
answer: ["1000", "$1,000", "1,000"]
- num_tool_calls: 1
tool_name: getMonthlyExpenseSummary
tool_arguments:
month: 2
year: 2024
- num_tool_calls: 0
answer: ["yes"]

View file

@ -1,166 +0,0 @@
test_response_basic:
test_name: test_response_basic
test_params:
case:
- case_id: "earth"
input: "Which planet do humans live on?"
output: "earth"
- case_id: "saturn"
input: "Which planet has rings around it with a name starting with letter S?"
output: "saturn"
- case_id: "image_input"
input:
- role: user
content:
- type: input_text
text: "what teams are playing in this image?"
- role: user
content:
- type: input_image
image_url: "https://upload.wikimedia.org/wikipedia/commons/3/3b/LeBron_James_Layup_%28Cleveland_vs_Brooklyn_2018%29.jpg"
output: "brooklyn nets"
test_response_multi_turn:
test_name: test_response_multi_turn
test_params:
case:
- case_id: "earth"
turns:
- input: "Which planet do humans live on?"
output: "earth"
- input: "What is the name of the planet from your previous response?"
output: "earth"
test_response_web_search:
test_name: test_response_web_search
test_params:
case:
- case_id: "llama_experts"
input: "How many experts does the Llama 4 Maverick model have?"
tools:
- type: web_search
search_context_size: "low"
output: "128"
test_response_file_search:
test_name: test_response_file_search
test_params:
case:
- case_id: "llama_experts"
input: "How many experts does the Llama 4 Maverick model have?"
tools:
- type: file_search
# vector_store_ids param for file_search tool gets added by the test runner
file_content: "Llama 4 Maverick has 128 experts"
output: "128"
- case_id: "llama_experts_pdf"
input: "How many experts does the Llama 4 Maverick model have?"
tools:
- type: file_search
# vector_store_ids param for file_search toolgets added by the test runner
file_path: "pdfs/llama_stack_and_models.pdf"
output: "128"
test_response_mcp_tool:
test_name: test_response_mcp_tool
test_params:
case:
- case_id: "boiling_point_tool"
input: "What is the boiling point of myawesomeliquid in Celsius?"
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
output: "Hello, world!"
test_response_custom_tool:
test_name: test_response_custom_tool
test_params:
case:
- case_id: "sf_weather"
input: "What's the weather like in San Francisco?"
tools:
- type: function
name: get_weather
description: Get current temperature for a given location.
parameters:
additionalProperties: false
properties:
location:
description: "City and country e.g. Bogot\xE1, Colombia"
type: string
required:
- location
type: object
test_response_image:
test_name: test_response_image
test_params:
case:
- case_id: "llama_image"
input:
- role: user
content:
- type: input_text
text: "Identify the type of animal in this image."
- type: input_image
image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
output: "llama"
# the models are really poor at tool calling after seeing images :/
test_response_multi_turn_image:
test_name: test_response_multi_turn_image
test_params:
case:
- case_id: "llama_image_understanding"
turns:
- input:
- role: user
content:
- type: input_text
text: "What type of animal is in this image? Please respond with a single word that starts with the letter 'L'."
- type: input_image
image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
output: "llama"
- input: "What country do you find this animal primarily in? What continent?"
output: "peru"
test_response_multi_turn_tool_execution:
test_name: test_response_multi_turn_tool_execution
test_params:
case:
- case_id: "user_file_access_check"
input: "I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Do this as a series of steps, where each step is a separate message. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response."
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
output: "yes"
- case_id: "experiment_results_lookup"
input: "I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me what you found."
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
output: "100°C"
test_response_multi_turn_tool_execution_streaming:
test_name: test_response_multi_turn_tool_execution_streaming
test_params:
case:
- case_id: "user_permissions_workflow"
input: "Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step."
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
stream: true
output: "no"
- case_id: "experiment_analysis_streaming"
input: "I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Please stream your analysis process."
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
stream: true
output: "85%"

View file

@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
def new_vector_store(openai_client, name):
"""Create a new vector store, cleaning up any existing one with the same name."""
# Ensure we don't reuse an existing vector store
vector_stores = openai_client.vector_stores.list()
for vector_store in vector_stores:
if vector_store.name == name:
openai_client.vector_stores.delete(vector_store_id=vector_store.id)
# Create a new vector store
vector_store = openai_client.vector_stores.create(name=name)
return vector_store
def upload_file(openai_client, name, file_path):
"""Upload a file, cleaning up any existing file with the same name."""
# Ensure we don't reuse an existing file
files = openai_client.files.list()
for file in files:
if file.filename == name:
openai_client.files.delete(file_id=file.id)
# Upload a text file with our document content
return openai_client.files.create(file=open(file_path, "rb"), purpose="assistants")
def wait_for_file_attachment(compat_client, vector_store_id, file_id):
"""Wait for a file to be attached to a vector store."""
file_attach_response = compat_client.vector_stores.files.retrieve(
vector_store_id=vector_store_id,
file_id=file_id,
)
while file_attach_response.status == "in_progress":
time.sleep(0.1)
file_attach_response = compat_client.vector_stores.files.retrieve(
vector_store_id=vector_store_id,
file_id=file_id,
)
assert file_attach_response.status == "completed", f"Expected file to be attached, got {file_attach_response}"
assert not file_attach_response.last_error
return file_attach_response
def setup_mcp_tools(tools, mcp_server_info):
"""Replace placeholder MCP server URLs with actual server info."""
# Create a deep copy to avoid modifying the original test case
import copy
tools_copy = copy.deepcopy(tools)
for tool in tools_copy:
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
tool["server_url"] = mcp_server_info["server_url"]
return tools_copy

View file

@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
class StreamingValidator:
"""Helper class for validating streaming response events."""
def __init__(self, chunks: list[Any]):
self.chunks = chunks
self.event_types = [chunk.type for chunk in chunks]
def assert_basic_event_sequence(self):
"""Verify basic created -> completed event sequence."""
assert len(self.chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(self.chunks)}"
assert self.chunks[0].type == "response.created", (
f"First chunk should be response.created, got {self.chunks[0].type}"
)
assert self.chunks[-1].type == "response.completed", (
f"Last chunk should be response.completed, got {self.chunks[-1].type}"
)
# Verify event order
created_index = self.event_types.index("response.created")
completed_index = self.event_types.index("response.completed")
assert created_index < completed_index, "response.created should come before response.completed"
def assert_response_consistency(self):
"""Verify response ID consistency across events."""
response_ids = set()
for chunk in self.chunks:
if hasattr(chunk, "response_id"):
response_ids.add(chunk.response_id)
elif hasattr(chunk, "response") and hasattr(chunk.response, "id"):
response_ids.add(chunk.response.id)
assert len(response_ids) == 1, f"All events should reference the same response_id, found: {response_ids}"
def assert_has_incremental_content(self):
"""Verify that content is delivered incrementally via delta events."""
delta_events = [
i for i, event_type in enumerate(self.event_types) if event_type == "response.output_text.delta"
]
assert len(delta_events) > 0, "Expected delta events for true incremental streaming, but found none"
# Verify delta events have content
non_empty_deltas = 0
delta_content_total = ""
for delta_idx in delta_events:
chunk = self.chunks[delta_idx]
if hasattr(chunk, "delta") and chunk.delta:
delta_content_total += chunk.delta
non_empty_deltas += 1
assert non_empty_deltas > 0, "Delta events found but none contain content"
assert len(delta_content_total) > 0, "Delta events found but total delta content is empty"
return delta_content_total
def assert_content_quality(self, expected_content: str):
"""Verify the final response contains expected content."""
final_chunk = self.chunks[-1]
if hasattr(final_chunk, "response"):
output_text = final_chunk.response.output_text.lower().strip()
assert len(output_text) > 0, "Response should have content"
assert expected_content.lower() in output_text, f"Expected '{expected_content}' in response"
def assert_has_tool_calls(self):
"""Verify tool call streaming events are present."""
# Check for tool call events
delta_events = [
chunk
for chunk in self.chunks
if chunk.type in ["response.function_call_arguments.delta", "response.mcp_call.arguments.delta"]
]
done_events = [
chunk
for chunk in self.chunks
if chunk.type in ["response.function_call_arguments.done", "response.mcp_call.arguments.done"]
]
assert len(delta_events) > 0, f"Expected tool call delta events, got chunk types: {self.event_types}"
assert len(done_events) > 0, f"Expected tool call done events, got chunk types: {self.event_types}"
# Verify output item events
item_added_events = [chunk for chunk in self.chunks if chunk.type == "response.output_item.added"]
item_done_events = [chunk for chunk in self.chunks if chunk.type == "response.output_item.done"]
assert len(item_added_events) > 0, (
f"Expected response.output_item.added events, got chunk types: {self.event_types}"
)
assert len(item_done_events) > 0, (
f"Expected response.output_item.done events, got chunk types: {self.event_types}"
)
def assert_has_mcp_events(self):
"""Verify MCP-specific streaming events are present."""
# Tool execution progress events
mcp_in_progress_events = [chunk for chunk in self.chunks if chunk.type == "response.mcp_call.in_progress"]
mcp_completed_events = [chunk for chunk in self.chunks if chunk.type == "response.mcp_call.completed"]
assert len(mcp_in_progress_events) > 0, (
f"Expected response.mcp_call.in_progress events, got chunk types: {self.event_types}"
)
assert len(mcp_completed_events) > 0, (
f"Expected response.mcp_call.completed events, got chunk types: {self.event_types}"
)
# MCP list tools events
mcp_list_tools_in_progress_events = [
chunk for chunk in self.chunks if chunk.type == "response.mcp_list_tools.in_progress"
]
mcp_list_tools_completed_events = [
chunk for chunk in self.chunks if chunk.type == "response.mcp_list_tools.completed"
]
assert len(mcp_list_tools_in_progress_events) > 0, (
f"Expected response.mcp_list_tools.in_progress events, got chunk types: {self.event_types}"
)
assert len(mcp_list_tools_completed_events) > 0, (
f"Expected response.mcp_list_tools.completed events, got chunk types: {self.event_types}"
)
def assert_rich_streaming(self, min_chunks: int = 10):
"""Verify we have substantial streaming activity."""
assert len(self.chunks) > min_chunks, (
f"Expected rich streaming with many events, got only {len(self.chunks)} chunks"
)
def validate_event_structure(self):
"""Validate the structure of various event types."""
for chunk in self.chunks:
if chunk.type == "response.created":
assert chunk.response.status == "in_progress"
elif chunk.type == "response.completed":
assert chunk.response.status == "completed"
elif hasattr(chunk, "item_id"):
assert chunk.item_id, "Events with item_id should have non-empty item_id"
elif hasattr(chunk, "sequence_number"):
assert isinstance(chunk.sequence_number, int), "sequence_number should be an integer"

View file

@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import pytest
from fixtures.test_cases import basic_test_cases, image_test_cases, multi_turn_image_test_cases, multi_turn_test_cases
from streaming_assertions import StreamingValidator
@pytest.mark.parametrize("case", basic_test_cases)
def test_response_non_streaming_basic(compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
stream=False,
)
output_text = response.output_text.lower().strip()
assert len(output_text) > 0
assert case.expected.lower() in output_text
retrieved_response = compat_client.responses.retrieve(response_id=response.id)
assert retrieved_response.output_text == response.output_text
next_response = compat_client.responses.create(
model=text_model_id,
input="Repeat your previous response in all caps.",
previous_response_id=response.id,
)
next_output_text = next_response.output_text.strip()
assert case.expected.upper() in next_output_text
@pytest.mark.parametrize("case", basic_test_cases)
def test_response_streaming_basic(compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
stream=True,
)
# Track events and timing to verify proper streaming
events = []
event_times = []
response_id = ""
start_time = time.time()
for chunk in response:
current_time = time.time()
event_times.append(current_time - start_time)
events.append(chunk)
if chunk.type == "response.created":
# Verify response.created is emitted first and immediately
assert len(events) == 1, "response.created should be the first event"
assert event_times[0] < 0.1, "response.created should be emitted immediately"
assert chunk.response.status == "in_progress"
response_id = chunk.response.id
elif chunk.type == "response.completed":
# Verify response.completed comes after response.created
assert len(events) >= 2, "response.completed should come after response.created"
assert chunk.response.status == "completed"
assert chunk.response.id == response_id, "Response ID should be consistent"
# Verify content quality
output_text = chunk.response.output_text.lower().strip()
assert len(output_text) > 0, "Response should have content"
assert case.expected.lower() in output_text, f"Expected '{case.expected}' in response"
# Use validator for common checks
validator = StreamingValidator(events)
validator.assert_basic_event_sequence()
validator.assert_response_consistency()
# Verify stored response matches streamed response
retrieved_response = compat_client.responses.retrieve(response_id=response_id)
final_event = events[-1]
assert retrieved_response.output_text == final_event.response.output_text
@pytest.mark.parametrize("case", basic_test_cases)
def test_response_streaming_incremental_content(compat_client, text_model_id, case):
"""Test that streaming actually delivers content incrementally, not just at the end."""
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
stream=True,
)
# Track all events and their content to verify incremental streaming
events = []
content_snapshots = []
event_times = []
start_time = time.time()
for chunk in response:
current_time = time.time()
event_times.append(current_time - start_time)
events.append(chunk)
# Track content at each event based on event type
if chunk.type == "response.output_text.delta":
# For delta events, track the delta content
content_snapshots.append(chunk.delta)
elif hasattr(chunk, "response") and hasattr(chunk.response, "output_text"):
# For response.created/completed events, track the full output_text
content_snapshots.append(chunk.response.output_text)
else:
content_snapshots.append("")
validator = StreamingValidator(events)
validator.assert_basic_event_sequence()
# Check if we have incremental content updates
event_types = [event.type for event in events]
created_index = event_types.index("response.created")
completed_index = event_types.index("response.completed")
# The key test: verify content progression
created_content = content_snapshots[created_index]
completed_content = content_snapshots[completed_index]
# Verify that response.created has empty or minimal content
assert len(created_content) == 0, f"response.created should have empty content, got: {repr(created_content[:100])}"
# Verify that response.completed has the full content
assert len(completed_content) > 0, "response.completed should have content"
assert case.expected.lower() in completed_content.lower(), f"Expected '{case.expected}' in final content"
# Use validator for incremental content checks
delta_content_total = validator.assert_has_incremental_content()
# Verify that the accumulated delta content matches the final content
assert delta_content_total.strip() == completed_content.strip(), (
f"Delta content '{delta_content_total}' should match final content '{completed_content}'"
)
# Verify timing: delta events should come between created and completed
delta_events = [i for i, event_type in enumerate(event_types) if event_type == "response.output_text.delta"]
for delta_idx in delta_events:
assert created_index < delta_idx < completed_index, (
f"Delta event at index {delta_idx} should be between created ({created_index}) and completed ({completed_index})"
)
@pytest.mark.parametrize("case", multi_turn_test_cases)
def test_response_non_streaming_multi_turn(compat_client, text_model_id, case):
previous_response_id = None
for turn_input, turn_expected in case.turns:
response = compat_client.responses.create(
model=text_model_id,
input=turn_input,
previous_response_id=previous_response_id,
)
previous_response_id = response.id
output_text = response.output_text.lower()
assert turn_expected.lower() in output_text
@pytest.mark.parametrize("case", image_test_cases)
def test_response_non_streaming_image(compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
stream=False,
)
output_text = response.output_text.lower()
assert case.expected.lower() in output_text
@pytest.mark.parametrize("case", multi_turn_image_test_cases)
def test_response_non_streaming_multi_turn_image(compat_client, text_model_id, case):
previous_response_id = None
for turn_input, turn_expected in case.turns:
response = compat_client.responses.create(
model=text_model_id,
input=turn_input,
previous_response_id=previous_response_id,
)
previous_response_id = response.id
output_text = response.output_text.lower()
assert turn_expected.lower() in output_text

View file

@ -0,0 +1,318 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import time
import pytest
from llama_stack import LlamaStackAsLibraryClient
from .helpers import new_vector_store, upload_file
@pytest.mark.parametrize(
"text_format",
# Not testing json_object because most providers don't actually support it.
[
{"type": "text"},
{
"type": "json_schema",
"name": "capitals",
"description": "A schema for the capital of each country",
"schema": {"type": "object", "properties": {"capital": {"type": "string"}}},
"strict": True,
},
],
)
def test_response_text_format(compat_client, text_model_id, text_format):
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API text format is not yet supported in library client.")
stream = False
response = compat_client.responses.create(
model=text_model_id,
input="What is the capital of France?",
stream=stream,
text={"format": text_format},
)
# by_alias=True is needed because otherwise Pydantic renames our "schema" field
assert response.text.format.model_dump(exclude_none=True, by_alias=True) == text_format
assert "paris" in response.output_text.lower()
if text_format["type"] == "json_schema":
assert "paris" in json.loads(response.output_text)["capital"].lower()
@pytest.fixture
def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_factory):
"""Create a vector store with multiple files that have different attributes for filtering tests."""
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
vector_store = new_vector_store(compat_client, "test_vector_store_with_filters")
tmp_path = tmp_path_factory.mktemp("filter_test_files")
# Create multiple files with different attributes
files_data = [
{
"name": "us_marketing_q1.txt",
"content": "US promotional campaigns for Q1 2023. Revenue increased by 15% in the US region.",
"attributes": {
"region": "us",
"category": "marketing",
"date": 1672531200, # Jan 1, 2023
},
},
{
"name": "us_engineering_q2.txt",
"content": "US technical updates for Q2 2023. New features deployed in the US region.",
"attributes": {
"region": "us",
"category": "engineering",
"date": 1680307200, # Apr 1, 2023
},
},
{
"name": "eu_marketing_q1.txt",
"content": "European advertising campaign results for Q1 2023. Strong growth in EU markets.",
"attributes": {
"region": "eu",
"category": "marketing",
"date": 1672531200, # Jan 1, 2023
},
},
{
"name": "asia_sales_q3.txt",
"content": "Asia Pacific revenue figures for Q3 2023. Record breaking quarter in Asia.",
"attributes": {
"region": "asia",
"category": "sales",
"date": 1688169600, # Jul 1, 2023
},
},
]
file_ids = []
for file_data in files_data:
# Create file
file_path = tmp_path / file_data["name"]
file_path.write_text(file_data["content"])
# Upload file
file_response = upload_file(compat_client, file_data["name"], str(file_path))
file_ids.append(file_response.id)
# Attach file to vector store with attributes
file_attach_response = compat_client.vector_stores.files.create(
vector_store_id=vector_store.id,
file_id=file_response.id,
attributes=file_data["attributes"],
)
# Wait for attachment
while file_attach_response.status == "in_progress":
time.sleep(0.1)
file_attach_response = compat_client.vector_stores.files.retrieve(
vector_store_id=vector_store.id,
file_id=file_response.id,
)
assert file_attach_response.status == "completed"
yield vector_store
# Cleanup: delete vector store and files
try:
compat_client.vector_stores.delete(vector_store_id=vector_store.id)
for file_id in file_ids:
try:
compat_client.files.delete(file_id=file_id)
except Exception:
pass # File might already be deleted
except Exception:
pass # Best effort cleanup
def test_response_file_search_filter_by_region(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with region equality filter."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {"type": "eq", "key": "region", "value": "us"},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="What are the updates from the US region?",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
# Verify file search was called with US filter
assert len(response.output) > 1
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return US files (not EU or Asia files)
for result in response.output[0].results:
assert "us" in result.text.lower() or "US" in result.text
# Ensure non-US regions are NOT returned
assert "european" not in result.text.lower()
assert "asia" not in result.text.lower()
def test_response_file_search_filter_by_category(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with category equality filter."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {"type": "eq", "key": "category", "value": "marketing"},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="Show me all marketing reports",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return marketing files (not engineering or sales)
for result in response.output[0].results:
# Marketing files should have promotional/advertising content
assert "promotional" in result.text.lower() or "advertising" in result.text.lower()
# Ensure non-marketing categories are NOT returned
assert "technical" not in result.text.lower()
assert "revenue figures" not in result.text.lower()
def test_response_file_search_filter_by_date_range(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with date range filter using compound AND."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {
"type": "and",
"filters": [
{
"type": "gte",
"key": "date",
"value": 1672531200, # Jan 1, 2023
},
{
"type": "lt",
"key": "date",
"value": 1680307200, # Apr 1, 2023
},
],
},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="What happened in Q1 2023?",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return Q1 files (not Q2 or Q3)
for result in response.output[0].results:
assert "q1" in result.text.lower()
# Ensure non-Q1 quarters are NOT returned
assert "q2" not in result.text.lower()
assert "q3" not in result.text.lower()
def test_response_file_search_filter_compound_and(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with compound AND filter (region AND category)."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {
"type": "and",
"filters": [
{"type": "eq", "key": "region", "value": "us"},
{"type": "eq", "key": "category", "value": "engineering"},
],
},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="What are the engineering updates from the US?",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return US engineering files
assert len(response.output[0].results) >= 1
for result in response.output[0].results:
assert "us" in result.text.lower() and "technical" in result.text.lower()
# Ensure it's not from other regions or categories
assert "european" not in result.text.lower() and "asia" not in result.text.lower()
assert "promotional" not in result.text.lower() and "revenue" not in result.text.lower()
def test_response_file_search_filter_compound_or(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with compound OR filter (marketing OR sales)."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {
"type": "or",
"filters": [
{"type": "eq", "key": "category", "value": "marketing"},
{"type": "eq", "key": "category", "value": "sales"},
],
},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="Show me marketing and sales documents",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should return marketing and sales files, but NOT engineering
categories_found = set()
for result in response.output[0].results:
text_lower = result.text.lower()
if "promotional" in text_lower or "advertising" in text_lower:
categories_found.add("marketing")
if "revenue figures" in text_lower:
categories_found.add("sales")
# Ensure engineering files are NOT returned
assert "technical" not in text_lower, f"Engineering file should not be returned, but got: {result.text}"
# Verify we got at least one of the expected categories
assert len(categories_found) > 0, "Should have found at least one marketing or sales file"
assert categories_found.issubset({"marketing", "sales"}), f"Found unexpected categories: {categories_found}"

View file

@ -1,922 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import time
import httpx
import openai
import pytest
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.core.datatypes import AuthenticationRequiredError
from tests.common.mcp import dependency_tools, make_mcp_server
from .fixtures.fixtures import case_id_generator
from .fixtures.load import load_test_cases
responses_test_cases = load_test_cases("responses")
def _new_vector_store(openai_client, name):
# Ensure we don't reuse an existing vector store
vector_stores = openai_client.vector_stores.list()
for vector_store in vector_stores:
if vector_store.name == name:
openai_client.vector_stores.delete(vector_store_id=vector_store.id)
# Create a new vector store
vector_store = openai_client.vector_stores.create(
name=name,
)
return vector_store
def _upload_file(openai_client, name, file_path):
# Ensure we don't reuse an existing file
files = openai_client.files.list()
for file in files:
if file.filename == name:
openai_client.files.delete(file_id=file.id)
# Upload a text file with our document content
return openai_client.files.create(file=open(file_path, "rb"), purpose="assistants")
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_basic"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_non_streaming_basic(request, compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case["input"],
stream=False,
)
output_text = response.output_text.lower().strip()
assert len(output_text) > 0
assert case["output"].lower() in output_text
retrieved_response = compat_client.responses.retrieve(response_id=response.id)
assert retrieved_response.output_text == response.output_text
next_response = compat_client.responses.create(
model=text_model_id,
input="Repeat your previous response in all caps.",
previous_response_id=response.id,
)
next_output_text = next_response.output_text.strip()
assert case["output"].upper() in next_output_text
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_basic"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_streaming_basic(request, compat_client, text_model_id, case):
import time
response = compat_client.responses.create(
model=text_model_id,
input=case["input"],
stream=True,
)
# Track events and timing to verify proper streaming
events = []
event_times = []
response_id = ""
start_time = time.time()
for chunk in response:
current_time = time.time()
event_times.append(current_time - start_time)
events.append(chunk)
if chunk.type == "response.created":
# Verify response.created is emitted first and immediately
assert len(events) == 1, "response.created should be the first event"
assert event_times[0] < 0.1, "response.created should be emitted immediately"
assert chunk.response.status == "in_progress"
response_id = chunk.response.id
elif chunk.type == "response.completed":
# Verify response.completed comes after response.created
assert len(events) >= 2, "response.completed should come after response.created"
assert chunk.response.status == "completed"
assert chunk.response.id == response_id, "Response ID should be consistent"
# Verify content quality
output_text = chunk.response.output_text.lower().strip()
assert len(output_text) > 0, "Response should have content"
assert case["output"].lower() in output_text, f"Expected '{case['output']}' in response"
# Verify we got both required events
event_types = [event.type for event in events]
assert "response.created" in event_types, "Missing response.created event"
assert "response.completed" in event_types, "Missing response.completed event"
# Verify event order
created_index = event_types.index("response.created")
completed_index = event_types.index("response.completed")
assert created_index < completed_index, "response.created should come before response.completed"
# Verify stored response matches streamed response
retrieved_response = compat_client.responses.retrieve(response_id=response_id)
final_event = events[-1]
assert retrieved_response.output_text == final_event.response.output_text
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_basic"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_streaming_incremental_content(request, compat_client, text_model_id, case):
"""Test that streaming actually delivers content incrementally, not just at the end."""
import time
response = compat_client.responses.create(
model=text_model_id,
input=case["input"],
stream=True,
)
# Track all events and their content to verify incremental streaming
events = []
content_snapshots = []
event_times = []
start_time = time.time()
for chunk in response:
current_time = time.time()
event_times.append(current_time - start_time)
events.append(chunk)
# Track content at each event based on event type
if chunk.type == "response.output_text.delta":
# For delta events, track the delta content
content_snapshots.append(chunk.delta)
elif hasattr(chunk, "response") and hasattr(chunk.response, "output_text"):
# For response.created/completed events, track the full output_text
content_snapshots.append(chunk.response.output_text)
else:
content_snapshots.append("")
# Verify we have the expected events
event_types = [event.type for event in events]
assert "response.created" in event_types, "Missing response.created event"
assert "response.completed" in event_types, "Missing response.completed event"
# Check if we have incremental content updates
created_index = event_types.index("response.created")
completed_index = event_types.index("response.completed")
# The key test: verify content progression
created_content = content_snapshots[created_index]
completed_content = content_snapshots[completed_index]
# Verify that response.created has empty or minimal content
assert len(created_content) == 0, f"response.created should have empty content, got: {repr(created_content[:100])}"
# Verify that response.completed has the full content
assert len(completed_content) > 0, "response.completed should have content"
assert case["output"].lower() in completed_content.lower(), f"Expected '{case['output']}' in final content"
# Check for true incremental streaming by looking for delta events
delta_events = [i for i, event_type in enumerate(event_types) if event_type == "response.output_text.delta"]
# Assert that we have delta events (true incremental streaming)
assert len(delta_events) > 0, "Expected delta events for true incremental streaming, but found none"
# Verify delta events have content and accumulate to final content
delta_content_total = ""
non_empty_deltas = 0
for delta_idx in delta_events:
delta_content = content_snapshots[delta_idx]
if delta_content:
delta_content_total += delta_content
non_empty_deltas += 1
# Assert that we have meaningful delta content
assert non_empty_deltas > 0, "Delta events found but none contain content"
assert len(delta_content_total) > 0, "Delta events found but total delta content is empty"
# Verify that the accumulated delta content matches the final content
assert delta_content_total.strip() == completed_content.strip(), (
f"Delta content '{delta_content_total}' should match final content '{completed_content}'"
)
# Verify timing: delta events should come between created and completed
for delta_idx in delta_events:
assert created_index < delta_idx < completed_index, (
f"Delta event at index {delta_idx} should be between created ({created_index}) and completed ({completed_index})"
)
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_multi_turn"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_non_streaming_multi_turn(request, compat_client, text_model_id, case):
previous_response_id = None
for turn in case["turns"]:
response = compat_client.responses.create(
model=text_model_id,
input=turn["input"],
previous_response_id=previous_response_id,
tools=turn["tools"] if "tools" in turn else None,
)
previous_response_id = response.id
output_text = response.output_text.lower()
assert turn["output"].lower() in output_text
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_web_search"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_non_streaming_web_search(request, compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case["input"],
tools=case["tools"],
stream=False,
)
assert len(response.output) > 1
assert response.output[0].type == "web_search_call"
assert response.output[0].status == "completed"
assert response.output[1].type == "message"
assert response.output[1].status == "completed"
assert response.output[1].role == "assistant"
assert len(response.output[1].content) > 0
assert case["output"].lower() in response.output_text.lower().strip()
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_file_search"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_non_streaming_file_search(request, compat_client, text_model_id, tmp_path, case):
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
vector_store = _new_vector_store(compat_client, "test_vector_store")
if "file_content" in case:
file_name = "test_response_non_streaming_file_search.txt"
file_path = tmp_path / file_name
file_path.write_text(case["file_content"])
elif "file_path" in case:
file_path = os.path.join(os.path.dirname(__file__), "fixtures", case["file_path"])
file_name = os.path.basename(file_path)
else:
raise ValueError(f"No file content or path provided for case {case['case_id']}")
file_response = _upload_file(compat_client, file_name, file_path)
# Attach our file to the vector store
file_attach_response = compat_client.vector_stores.files.create(
vector_store_id=vector_store.id,
file_id=file_response.id,
)
# Wait for the file to be attached
while file_attach_response.status == "in_progress":
time.sleep(0.1)
file_attach_response = compat_client.vector_stores.files.retrieve(
vector_store_id=vector_store.id,
file_id=file_response.id,
)
assert file_attach_response.status == "completed", f"Expected file to be attached, got {file_attach_response}"
assert not file_attach_response.last_error
# Update our tools with the right vector store id
tools = case["tools"]
for tool in tools:
if tool["type"] == "file_search":
tool["vector_store_ids"] = [vector_store.id]
# Create the response request, which should query our vector store
response = compat_client.responses.create(
model=text_model_id,
input=case["input"],
tools=tools,
stream=False,
include=["file_search_call.results"],
)
# Verify the file_search_tool was called
assert len(response.output) > 1
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].queries # ensure it's some non-empty list
assert response.output[0].results
assert case["output"].lower() in response.output[0].results[0].text.lower()
assert response.output[0].results[0].score > 0
# Verify the output_text generated by the response
assert case["output"].lower() in response.output_text.lower().strip()
def test_response_non_streaming_file_search_empty_vector_store(request, compat_client, text_model_id):
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
vector_store = _new_vector_store(compat_client, "test_vector_store")
# Create the response request, which should query our vector store
response = compat_client.responses.create(
model=text_model_id,
input="How many experts does the Llama 4 Maverick model have?",
tools=[{"type": "file_search", "vector_store_ids": [vector_store.id]}],
stream=False,
include=["file_search_call.results"],
)
# Verify the file_search_tool was called
assert len(response.output) > 1
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].queries # ensure it's some non-empty list
assert not response.output[0].results # ensure we don't get any results
# Verify some output_text was generated by the response
assert response.output_text
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_mcp_tool"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_non_streaming_mcp_tool(request, compat_client, text_model_id, case):
with make_mcp_server() as mcp_server_info:
tools = case["tools"]
for tool in tools:
if tool["type"] == "mcp":
tool["server_url"] = mcp_server_info["server_url"]
response = compat_client.responses.create(
model=text_model_id,
input=case["input"],
tools=tools,
stream=False,
)
assert len(response.output) >= 3
list_tools = response.output[0]
assert list_tools.type == "mcp_list_tools"
assert list_tools.server_label == "localmcp"
assert len(list_tools.tools) == 2
assert {t.name for t in list_tools.tools} == {"get_boiling_point", "greet_everyone"}
call = response.output[1]
assert call.type == "mcp_call"
assert call.name == "get_boiling_point"
assert json.loads(call.arguments) == {"liquid_name": "myawesomeliquid", "celsius": True}
assert call.error is None
assert "-100" in call.output
# sometimes the model will call the tool again, so we need to get the last message
message = response.output[-1]
text_content = message.content[0].text
assert "boiling point" in text_content.lower()
with make_mcp_server(required_auth_token="test-token") as mcp_server_info:
tools = case["tools"]
for tool in tools:
if tool["type"] == "mcp":
tool["server_url"] = mcp_server_info["server_url"]
exc_type = (
AuthenticationRequiredError
if isinstance(compat_client, LlamaStackAsLibraryClient)
else (httpx.HTTPStatusError, openai.AuthenticationError)
)
with pytest.raises(exc_type):
compat_client.responses.create(
model=text_model_id,
input=case["input"],
tools=tools,
stream=False,
)
for tool in tools:
if tool["type"] == "mcp":
tool["server_url"] = mcp_server_info["server_url"]
tool["headers"] = {"Authorization": "Bearer test-token"}
response = compat_client.responses.create(
model=text_model_id,
input=case["input"],
tools=tools,
stream=False,
)
assert len(response.output) >= 3
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_custom_tool"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_non_streaming_custom_tool(request, compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case["input"],
tools=case["tools"],
stream=False,
)
assert len(response.output) == 1
assert response.output[0].type == "function_call"
assert response.output[0].status == "completed"
assert response.output[0].name == "get_weather"
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_image"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_non_streaming_image(request, compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case["input"],
stream=False,
)
output_text = response.output_text.lower()
assert case["output"].lower() in output_text
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_multi_turn_image"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_non_streaming_multi_turn_image(request, compat_client, text_model_id, case):
previous_response_id = None
for turn in case["turns"]:
response = compat_client.responses.create(
model=text_model_id,
input=turn["input"],
previous_response_id=previous_response_id,
tools=turn["tools"] if "tools" in turn else None,
)
previous_response_id = response.id
output_text = response.output_text.lower()
assert turn["output"].lower() in output_text
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_multi_turn_tool_execution"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_non_streaming_multi_turn_tool_execution(request, compat_client, text_model_id, case):
"""Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
tools = case["tools"]
# Replace the placeholder URL with the actual server URL
for tool in tools:
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
tool["server_url"] = mcp_server_info["server_url"]
response = compat_client.responses.create(
input=case["input"],
model=text_model_id,
tools=tools,
)
# Verify we have MCP tool calls in the output
mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
mcp_calls = [output for output in response.output if output.type == "mcp_call"]
message_outputs = [output for output in response.output if output.type == "message"]
# Should have exactly 1 MCP list tools message (at the beginning)
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
assert mcp_list_tools[0].server_label == "localmcp"
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
expected_tool_names = {
"get_user_id",
"get_user_permissions",
"check_file_access",
"get_experiment_id",
"get_experiment_results",
}
assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
for mcp_call in mcp_calls:
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
final_message = message_outputs[-1]
assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
assert len(final_message.content) > 0, "Final message should have content"
expected_output = case["output"]
assert expected_output.lower() in response.output_text.lower(), (
f"Expected '{expected_output}' to appear in response: {response.output_text}"
)
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_multi_turn_tool_execution_streaming"]["test_params"]["case"],
ids=case_id_generator,
)
async def test_response_streaming_multi_turn_tool_execution(request, compat_client, text_model_id, case):
"""Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
tools = case["tools"]
# Replace the placeholder URL with the actual server URL
for tool in tools:
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
tool["server_url"] = mcp_server_info["server_url"]
stream = compat_client.responses.create(
input=case["input"],
model=text_model_id,
tools=tools,
stream=True,
)
chunks = []
for chunk in stream:
chunks.append(chunk)
# Should have at least response.created and response.completed
assert len(chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(chunks)}"
# First chunk should be response.created
assert chunks[0].type == "response.created", f"First chunk should be response.created, got {chunks[0].type}"
# Last chunk should be response.completed
assert chunks[-1].type == "response.completed", (
f"Last chunk should be response.completed, got {chunks[-1].type}"
)
# Get the final response from the last chunk
final_chunk = chunks[-1]
if hasattr(final_chunk, "response"):
final_response = final_chunk.response
# Verify multi-turn MCP tool execution results
mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
message_outputs = [output for output in final_response.output if output.type == "message"]
# Should have exactly 1 MCP list tools message (at the beginning)
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
assert mcp_list_tools[0].server_label == "localmcp"
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
expected_tool_names = {
"get_user_id",
"get_user_permissions",
"check_file_access",
"get_experiment_id",
"get_experiment_results",
}
assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names
# Should have at least 1 MCP call (the model should call at least one tool)
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
# All MCP calls should be completed (verifies our tool execution works)
for mcp_call in mcp_calls:
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
# Should have at least one final message response
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
# Final message should be from assistant and completed
final_message = message_outputs[-1]
assert final_message.role == "assistant", (
f"Final message should be from assistant, got {final_message.role}"
)
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
assert len(final_message.content) > 0, "Final message should have content"
# Check that the expected output appears in the response
expected_output = case["output"]
assert expected_output.lower() in final_response.output_text.lower(), (
f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
)
@pytest.mark.parametrize(
"text_format",
# Not testing json_object because most providers don't actually support it.
[
{"type": "text"},
{
"type": "json_schema",
"name": "capitals",
"description": "A schema for the capital of each country",
"schema": {"type": "object", "properties": {"capital": {"type": "string"}}},
"strict": True,
},
],
)
def test_response_text_format(request, compat_client, text_model_id, text_format):
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API text format is not yet supported in library client.")
stream = False
response = compat_client.responses.create(
model=text_model_id,
input="What is the capital of France?",
stream=stream,
text={"format": text_format},
)
# by_alias=True is needed because otherwise Pydantic renames our "schema" field
assert response.text.format.model_dump(exclude_none=True, by_alias=True) == text_format
assert "paris" in response.output_text.lower()
if text_format["type"] == "json_schema":
assert "paris" in json.loads(response.output_text)["capital"].lower()
@pytest.fixture
def vector_store_with_filtered_files(request, compat_client, text_model_id, tmp_path_factory):
"""Create a vector store with multiple files that have different attributes for filtering tests."""
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
vector_store = _new_vector_store(compat_client, "test_vector_store_with_filters")
tmp_path = tmp_path_factory.mktemp("filter_test_files")
# Create multiple files with different attributes
files_data = [
{
"name": "us_marketing_q1.txt",
"content": "US promotional campaigns for Q1 2023. Revenue increased by 15% in the US region.",
"attributes": {
"region": "us",
"category": "marketing",
"date": 1672531200, # Jan 1, 2023
},
},
{
"name": "us_engineering_q2.txt",
"content": "US technical updates for Q2 2023. New features deployed in the US region.",
"attributes": {
"region": "us",
"category": "engineering",
"date": 1680307200, # Apr 1, 2023
},
},
{
"name": "eu_marketing_q1.txt",
"content": "European advertising campaign results for Q1 2023. Strong growth in EU markets.",
"attributes": {
"region": "eu",
"category": "marketing",
"date": 1672531200, # Jan 1, 2023
},
},
{
"name": "asia_sales_q3.txt",
"content": "Asia Pacific revenue figures for Q3 2023. Record breaking quarter in Asia.",
"attributes": {
"region": "asia",
"category": "sales",
"date": 1688169600, # Jul 1, 2023
},
},
]
file_ids = []
for file_data in files_data:
# Create file
file_path = tmp_path / file_data["name"]
file_path.write_text(file_data["content"])
# Upload file
file_response = _upload_file(compat_client, file_data["name"], str(file_path))
file_ids.append(file_response.id)
# Attach file to vector store with attributes
file_attach_response = compat_client.vector_stores.files.create(
vector_store_id=vector_store.id, file_id=file_response.id, attributes=file_data["attributes"]
)
# Wait for attachment
while file_attach_response.status == "in_progress":
time.sleep(0.1)
file_attach_response = compat_client.vector_stores.files.retrieve(
vector_store_id=vector_store.id,
file_id=file_response.id,
)
assert file_attach_response.status == "completed"
yield vector_store
# Cleanup: delete vector store and files
try:
compat_client.vector_stores.delete(vector_store_id=vector_store.id)
for file_id in file_ids:
try:
compat_client.files.delete(file_id=file_id)
except Exception:
pass # File might already be deleted
except Exception:
pass # Best effort cleanup
def test_response_file_search_filter_by_region(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with region equality filter."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {"type": "eq", "key": "region", "value": "us"},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="What are the updates from the US region?",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
# Verify file search was called with US filter
assert len(response.output) > 1
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return US files (not EU or Asia files)
for result in response.output[0].results:
assert "us" in result.text.lower() or "US" in result.text
# Ensure non-US regions are NOT returned
assert "european" not in result.text.lower()
assert "asia" not in result.text.lower()
def test_response_file_search_filter_by_category(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with category equality filter."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {"type": "eq", "key": "category", "value": "marketing"},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="Show me all marketing reports",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return marketing files (not engineering or sales)
for result in response.output[0].results:
# Marketing files should have promotional/advertising content
assert "promotional" in result.text.lower() or "advertising" in result.text.lower()
# Ensure non-marketing categories are NOT returned
assert "technical" not in result.text.lower()
assert "revenue figures" not in result.text.lower()
def test_response_file_search_filter_by_date_range(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with date range filter using compound AND."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {
"type": "and",
"filters": [
{
"type": "gte",
"key": "date",
"value": 1672531200, # Jan 1, 2023
},
{
"type": "lt",
"key": "date",
"value": 1680307200, # Apr 1, 2023
},
],
},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="What happened in Q1 2023?",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return Q1 files (not Q2 or Q3)
for result in response.output[0].results:
assert "q1" in result.text.lower()
# Ensure non-Q1 quarters are NOT returned
assert "q2" not in result.text.lower()
assert "q3" not in result.text.lower()
def test_response_file_search_filter_compound_and(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with compound AND filter (region AND category)."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {
"type": "and",
"filters": [
{"type": "eq", "key": "region", "value": "us"},
{"type": "eq", "key": "category", "value": "engineering"},
],
},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="What are the engineering updates from the US?",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should only return US engineering files
assert len(response.output[0].results) >= 1
for result in response.output[0].results:
assert "us" in result.text.lower() and "technical" in result.text.lower()
# Ensure it's not from other regions or categories
assert "european" not in result.text.lower() and "asia" not in result.text.lower()
assert "promotional" not in result.text.lower() and "revenue" not in result.text.lower()
def test_response_file_search_filter_compound_or(compat_client, text_model_id, vector_store_with_filtered_files):
"""Test file search with compound OR filter (marketing OR sales)."""
tools = [
{
"type": "file_search",
"vector_store_ids": [vector_store_with_filtered_files.id],
"filters": {
"type": "or",
"filters": [
{"type": "eq", "key": "category", "value": "marketing"},
{"type": "eq", "key": "category", "value": "sales"},
],
},
}
]
response = compat_client.responses.create(
model=text_model_id,
input="Show me marketing and sales documents",
tools=tools,
stream=False,
include=["file_search_call.results"],
)
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].results
# Should return marketing and sales files, but NOT engineering
categories_found = set()
for result in response.output[0].results:
text_lower = result.text.lower()
if "promotional" in text_lower or "advertising" in text_lower:
categories_found.add("marketing")
if "revenue figures" in text_lower:
categories_found.add("sales")
# Ensure engineering files are NOT returned
assert "technical" not in text_lower, f"Engineering file should not be returned, but got: {result.text}"
# Verify we got at least one of the expected categories
assert len(categories_found) > 0, "Should have found at least one marketing or sales file"
assert categories_found.issubset({"marketing", "sales"}), f"Found unexpected categories: {categories_found}"

View file

@ -0,0 +1,335 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import httpx
import openai
import pytest
from fixtures.test_cases import (
custom_tool_test_cases,
file_search_test_cases,
mcp_tool_test_cases,
multi_turn_tool_execution_streaming_test_cases,
multi_turn_tool_execution_test_cases,
web_search_test_cases,
)
from helpers import new_vector_store, setup_mcp_tools, upload_file, wait_for_file_attachment
from streaming_assertions import StreamingValidator
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.core.datatypes import AuthenticationRequiredError
from tests.common.mcp import dependency_tools, make_mcp_server
@pytest.mark.parametrize("case", web_search_test_cases)
def test_response_non_streaming_web_search(compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=case.tools,
stream=False,
)
assert len(response.output) > 1
assert response.output[0].type == "web_search_call"
assert response.output[0].status == "completed"
assert response.output[1].type == "message"
assert response.output[1].status == "completed"
assert response.output[1].role == "assistant"
assert len(response.output[1].content) > 0
assert case.expected.lower() in response.output_text.lower().strip()
@pytest.mark.parametrize("case", file_search_test_cases)
def test_response_non_streaming_file_search(compat_client, text_model_id, tmp_path, case):
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
vector_store = new_vector_store(compat_client, "test_vector_store")
if case.file_content:
file_name = "test_response_non_streaming_file_search.txt"
file_path = tmp_path / file_name
file_path.write_text(case.file_content)
elif case.file_path:
file_path = os.path.join(os.path.dirname(__file__), "fixtures", case.file_path)
file_name = os.path.basename(file_path)
else:
raise ValueError("No file content or path provided for case")
file_response = upload_file(compat_client, file_name, file_path)
# Attach our file to the vector store
compat_client.vector_stores.files.create(
vector_store_id=vector_store.id,
file_id=file_response.id,
)
# Wait for the file to be attached
wait_for_file_attachment(compat_client, vector_store.id, file_response.id)
# Update our tools with the right vector store id
tools = case.tools
for tool in tools:
if tool["type"] == "file_search":
tool["vector_store_ids"] = [vector_store.id]
# Create the response request, which should query our vector store
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=tools,
stream=False,
include=["file_search_call.results"],
)
# Verify the file_search_tool was called
assert len(response.output) > 1
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].queries # ensure it's some non-empty list
assert response.output[0].results
assert case.expected.lower() in response.output[0].results[0].text.lower()
assert response.output[0].results[0].score > 0
# Verify the output_text generated by the response
assert case.expected.lower() in response.output_text.lower().strip()
def test_response_non_streaming_file_search_empty_vector_store(compat_client, text_model_id):
if isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API file search is not yet supported in library client.")
vector_store = new_vector_store(compat_client, "test_vector_store")
# Create the response request, which should query our vector store
response = compat_client.responses.create(
model=text_model_id,
input="How many experts does the Llama 4 Maverick model have?",
tools=[{"type": "file_search", "vector_store_ids": [vector_store.id]}],
stream=False,
include=["file_search_call.results"],
)
# Verify the file_search_tool was called
assert len(response.output) > 1
assert response.output[0].type == "file_search_call"
assert response.output[0].status == "completed"
assert response.output[0].queries # ensure it's some non-empty list
assert not response.output[0].results # ensure we don't get any results
# Verify some output_text was generated by the response
assert response.output_text
@pytest.mark.parametrize("case", mcp_tool_test_cases)
def test_response_non_streaming_mcp_tool(compat_client, text_model_id, case):
if not isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("in-process MCP server is only supported in library client")
with make_mcp_server() as mcp_server_info:
tools = setup_mcp_tools(case.tools, mcp_server_info)
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=tools,
stream=False,
)
assert len(response.output) >= 3
list_tools = response.output[0]
assert list_tools.type == "mcp_list_tools"
assert list_tools.server_label == "localmcp"
assert len(list_tools.tools) == 2
assert {t.name for t in list_tools.tools} == {
"get_boiling_point",
"greet_everyone",
}
call = response.output[1]
assert call.type == "mcp_call"
assert call.name == "get_boiling_point"
assert json.loads(call.arguments) == {
"liquid_name": "myawesomeliquid",
"celsius": True,
}
assert call.error is None
assert "-100" in call.output
# sometimes the model will call the tool again, so we need to get the last message
message = response.output[-1]
text_content = message.content[0].text
assert "boiling point" in text_content.lower()
with make_mcp_server(required_auth_token="test-token") as mcp_server_info:
tools = setup_mcp_tools(case.tools, mcp_server_info)
exc_type = (
AuthenticationRequiredError
if isinstance(compat_client, LlamaStackAsLibraryClient)
else (httpx.HTTPStatusError, openai.AuthenticationError)
)
with pytest.raises(exc_type):
compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=tools,
stream=False,
)
for tool in tools:
if tool["type"] == "mcp":
tool["headers"] = {"Authorization": "Bearer test-token"}
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=tools,
stream=False,
)
assert len(response.output) >= 3
@pytest.mark.parametrize("case", custom_tool_test_cases)
def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
response = compat_client.responses.create(
model=text_model_id,
input=case.input,
tools=case.tools,
stream=False,
)
assert len(response.output) == 1
assert response.output[0].type == "function_call"
assert response.output[0].status == "completed"
assert response.output[0].name == "get_weather"
@pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases)
def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
"""Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
if not isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("in-process MCP server is only supported in library client")
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
tools = setup_mcp_tools(case.tools, mcp_server_info)
response = compat_client.responses.create(
input=case.input,
model=text_model_id,
tools=tools,
)
# Verify we have MCP tool calls in the output
mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
mcp_calls = [output for output in response.output if output.type == "mcp_call"]
message_outputs = [output for output in response.output if output.type == "message"]
# Should have exactly 1 MCP list tools message (at the beginning)
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
assert mcp_list_tools[0].server_label == "localmcp"
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
expected_tool_names = {
"get_user_id",
"get_user_permissions",
"check_file_access",
"get_experiment_id",
"get_experiment_results",
}
assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
for mcp_call in mcp_calls:
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
final_message = message_outputs[-1]
assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
assert len(final_message.content) > 0, "Final message should have content"
expected_output = case.expected
assert expected_output.lower() in response.output_text.lower(), (
f"Expected '{expected_output}' to appear in response: {response.output_text}"
)
@pytest.mark.parametrize("case", multi_turn_tool_execution_streaming_test_cases)
def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
"""Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
if not isinstance(compat_client, LlamaStackAsLibraryClient):
pytest.skip("in-process MCP server is only supported in library client")
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
tools = setup_mcp_tools(case.tools, mcp_server_info)
stream = compat_client.responses.create(
input=case.input,
model=text_model_id,
tools=tools,
stream=True,
)
chunks = []
for chunk in stream:
chunks.append(chunk)
# Use validator for common streaming checks
validator = StreamingValidator(chunks)
validator.assert_basic_event_sequence()
validator.assert_response_consistency()
validator.assert_has_tool_calls()
validator.assert_has_mcp_events()
validator.assert_rich_streaming()
# Get the final response from the last chunk
final_chunk = chunks[-1]
if hasattr(final_chunk, "response"):
final_response = final_chunk.response
# Verify multi-turn MCP tool execution results
mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
message_outputs = [output for output in final_response.output if output.type == "message"]
# Should have exactly 1 MCP list tools message (at the beginning)
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
assert mcp_list_tools[0].server_label == "localmcp"
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
expected_tool_names = {
"get_user_id",
"get_user_permissions",
"check_file_access",
"get_experiment_id",
"get_experiment_results",
}
assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names
# Should have at least 1 MCP call (the model should call at least one tool)
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
# All MCP calls should be completed (verifies our tool execution works)
for mcp_call in mcp_calls:
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
# Should have at least one final message response
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
# Final message should be from assistant and completed
final_message = message_outputs[-1]
assert final_message.role == "assistant", (
f"Final message should be from assistant, got {final_message.role}"
)
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
assert len(final_message.content) > 0, "Final message should have content"
# Check that the expected output appears in the response
expected_output = case.expected
assert expected_output.lower() in final_response.output_text.lower(), (
f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
)

View file

@ -0,0 +1,56 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "user",
"content": "Quick test"
}
],
"max_tokens": 5
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": {
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
"__data__": {
"id": "chatcmpl-651",
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "I'm ready to help",
"refusal": null,
"role": "assistant",
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": null
}
}
],
"created": 1755294941,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 5,
"prompt_tokens": 27,
"total_tokens": 32,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

@ -0,0 +1,56 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "user",
"content": "Say hello"
}
],
"max_tokens": 20
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": {
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
"__data__": {
"id": "chatcmpl-987",
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "Hello! It's nice to meet you. Is there something I can help you with or would you",
"refusal": null,
"role": "assistant",
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": null
}
}
],
"created": 1755294921,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 20,
"prompt_tokens": 27,
"total_tokens": 47,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

@ -14,7 +14,7 @@
"models": [
{
"model": "nomic-embed-text:latest",
"modified_at": "2025-08-05T14:04:07.946926-07:00",
"modified_at": "2025-08-15T21:55:08.088554Z",
"digest": "0a109f422b47e3a30ba2b10eca18548e944e8a23073ee3f3e947efcf3c45e59f",
"size": 274302450,
"details": {
@ -28,41 +28,9 @@
"quantization_level": "F16"
}
},
{
"model": "llama3.2-vision:11b",
"modified_at": "2025-07-30T18:45:02.517873-07:00",
"digest": "6f2f9757ae97e8a3f8ea33d6adb2b11d93d9a35bef277cd2c0b1b5af8e8d0b1e",
"size": 7816589186,
"details": {
"parent_model": "",
"format": "gguf",
"family": "mllama",
"families": [
"mllama"
],
"parameter_size": "10.7B",
"quantization_level": "Q4_K_M"
}
},
{
"model": "llama3.2-vision:latest",
"modified_at": "2025-07-29T20:18:47.920468-07:00",
"digest": "6f2f9757ae97e8a3f8ea33d6adb2b11d93d9a35bef277cd2c0b1b5af8e8d0b1e",
"size": 7816589186,
"details": {
"parent_model": "",
"format": "gguf",
"family": "mllama",
"families": [
"mllama"
],
"parameter_size": "10.7B",
"quantization_level": "Q4_K_M"
}
},
{
"model": "llama-guard3:1b",
"modified_at": "2025-07-25T14:39:44.978630-07:00",
"modified_at": "2025-07-31T04:44:58Z",
"digest": "494147e06bf99e10dbe67b63a07ac81c162f18ef3341aa3390007ac828571b3b",
"size": 1600181919,
"details": {
@ -78,7 +46,7 @@
},
{
"model": "all-minilm:l6-v2",
"modified_at": "2025-07-24T15:15:11.129290-07:00",
"modified_at": "2025-07-31T04:42:15Z",
"digest": "1b226e2802dbb772b5fc32a58f103ca1804ef7501331012de126ab22f67475ef",
"size": 45960996,
"details": {
@ -92,57 +60,9 @@
"quantization_level": "F16"
}
},
{
"model": "llama3.2:1b",
"modified_at": "2025-07-17T22:02:24.953208-07:00",
"digest": "baf6a787fdffd633537aa2eb51cfd54cb93ff08e28040095462bb63daf552878",
"size": 1321098329,
"details": {
"parent_model": "",
"format": "gguf",
"family": "llama",
"families": [
"llama"
],
"parameter_size": "1.2B",
"quantization_level": "Q8_0"
}
},
{
"model": "all-minilm:latest",
"modified_at": "2025-06-03T16:50:10.946583-07:00",
"digest": "1b226e2802dbb772b5fc32a58f103ca1804ef7501331012de126ab22f67475ef",
"size": 45960996,
"details": {
"parent_model": "",
"format": "gguf",
"family": "bert",
"families": [
"bert"
],
"parameter_size": "23M",
"quantization_level": "F16"
}
},
{
"model": "llama3.2:3b",
"modified_at": "2025-05-01T11:15:23.797447-07:00",
"digest": "a80c4f17acd55265feec403c7aef86be0c25983ab279d83f3bcd3abbcb5b8b72",
"size": 2019393189,
"details": {
"parent_model": "",
"format": "gguf",
"family": "llama",
"families": [
"llama"
],
"parameter_size": "3.2B",
"quantization_level": "Q4_K_M"
}
},
{
"model": "llama3.2:3b-instruct-fp16",
"modified_at": "2025-04-30T15:33:48.939665-07:00",
"modified_at": "2025-07-31T04:42:05Z",
"digest": "195a8c01d91ec3cb1e0aad4624a51f2602c51fa7d96110f8ab5a20c84081804d",
"size": 6433703586,
"details": {

View file

@ -21,7 +21,7 @@
"__type__": "ollama._types.GenerateResponse",
"__data__": {
"model": "llama3.2:3b-instruct-fp16",
"created_at": "2025-08-04T22:55:14.141947Z",
"created_at": "2025-08-15T20:24:49.18651486Z",
"done": false,
"done_reason": null,
"total_duration": null,
@ -39,7 +39,7 @@
"__type__": "ollama._types.GenerateResponse",
"__data__": {
"model": "llama3.2:3b-instruct-fp16",
"created_at": "2025-08-04T22:55:14.194979Z",
"created_at": "2025-08-15T20:24:49.370611348Z",
"done": false,
"done_reason": null,
"total_duration": null,
@ -57,7 +57,7 @@
"__type__": "ollama._types.GenerateResponse",
"__data__": {
"model": "llama3.2:3b-instruct-fp16",
"created_at": "2025-08-04T22:55:14.248312Z",
"created_at": "2025-08-15T20:24:49.557000029Z",
"done": false,
"done_reason": null,
"total_duration": null,
@ -75,7 +75,7 @@
"__type__": "ollama._types.GenerateResponse",
"__data__": {
"model": "llama3.2:3b-instruct-fp16",
"created_at": "2025-08-04T22:55:14.301911Z",
"created_at": "2025-08-15T20:24:49.746777116Z",
"done": false,
"done_reason": null,
"total_duration": null,
@ -93,7 +93,7 @@
"__type__": "ollama._types.GenerateResponse",
"__data__": {
"model": "llama3.2:3b-instruct-fp16",
"created_at": "2025-08-04T22:55:14.354437Z",
"created_at": "2025-08-15T20:24:49.942233333Z",
"done": false,
"done_reason": null,
"total_duration": null,
@ -111,7 +111,7 @@
"__type__": "ollama._types.GenerateResponse",
"__data__": {
"model": "llama3.2:3b-instruct-fp16",
"created_at": "2025-08-04T22:55:14.406821Z",
"created_at": "2025-08-15T20:24:50.126788846Z",
"done": false,
"done_reason": null,
"total_duration": null,
@ -129,7 +129,7 @@
"__type__": "ollama._types.GenerateResponse",
"__data__": {
"model": "llama3.2:3b-instruct-fp16",
"created_at": "2025-08-04T22:55:14.457633Z",
"created_at": "2025-08-15T20:24:50.311346131Z",
"done": false,
"done_reason": null,
"total_duration": null,
@ -147,7 +147,7 @@
"__type__": "ollama._types.GenerateResponse",
"__data__": {
"model": "llama3.2:3b-instruct-fp16",
"created_at": "2025-08-04T22:55:14.507857Z",
"created_at": "2025-08-15T20:24:50.501507173Z",
"done": false,
"done_reason": null,
"total_duration": null,
@ -165,7 +165,7 @@
"__type__": "ollama._types.GenerateResponse",
"__data__": {
"model": "llama3.2:3b-instruct-fp16",
"created_at": "2025-08-04T22:55:14.558847Z",
"created_at": "2025-08-15T20:24:50.692296777Z",
"done": false,
"done_reason": null,
"total_duration": null,
@ -183,7 +183,7 @@
"__type__": "ollama._types.GenerateResponse",
"__data__": {
"model": "llama3.2:3b-instruct-fp16",
"created_at": "2025-08-04T22:55:14.609969Z",
"created_at": "2025-08-15T20:24:50.878846539Z",
"done": false,
"done_reason": null,
"total_duration": null,
@ -201,15 +201,15 @@
"__type__": "ollama._types.GenerateResponse",
"__data__": {
"model": "llama3.2:3b-instruct-fp16",
"created_at": "2025-08-04T22:55:14.660997Z",
"created_at": "2025-08-15T20:24:51.063200561Z",
"done": true,
"done_reason": "stop",
"total_duration": 715356542,
"load_duration": 59747500,
"total_duration": 33982453650,
"load_duration": 2909001805,
"prompt_eval_count": 341,
"prompt_eval_duration": 128000000,
"prompt_eval_duration": 29194357307,
"eval_count": 11,
"eval_duration": 526000000,
"eval_duration": 1878247732,
"response": "",
"thinking": null,
"context": null

File diff suppressed because it is too large Load diff

View file

@ -11,26 +11,7 @@
"body": {
"__type__": "ollama._types.ProcessResponse",
"__data__": {
"models": [
{
"model": "llama3.2:3b",
"name": "llama3.2:3b",
"digest": "a80c4f17acd55265feec403c7aef86be0c25983ab279d83f3bcd3abbcb5b8b72",
"expires_at": "2025-08-06T15:57:21.573326-04:00",
"size": 4030033920,
"size_vram": 4030033920,
"details": {
"parent_model": "",
"format": "gguf",
"family": "llama",
"families": [
"llama"
],
"parameter_size": "3.2B",
"quantization_level": "Q4_K_M"
}
}
]
"models": []
}
},
"is_streaming": false

View file

@ -0,0 +1,109 @@
{
"request": {
"method": "POST",
"url": "http://localhost:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "user",
"content": "What's the weather in Tokyo? YOU MUST USE THE get_weather function to get the weather."
}
],
"response_format": {
"type": "text"
},
"stream": true,
"tools": [
{
"type": "function",
"function": {
"type": "function",
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to get the weather for"
}
}
},
"strict": null
}
}
]
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-620",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"index": 0,
"id": "call_490d5ur7",
"function": {
"arguments": "{\"city\":\"Tokyo\"}",
"name": "get_weather"
},
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1755228972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-620",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": "tool_calls",
"index": 0,
"logprobs": null
}
],
"created": 1755228972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}

View file

@ -9,10 +9,11 @@ import time
from io import BytesIO
import pytest
from llama_stack_client import BadRequestError, LlamaStackClient
from llama_stack_client import BadRequestError
from openai import BadRequestError as OpenAIBadRequestError
from llama_stack.apis.vector_io import Chunk
from llama_stack.core.library_client import LlamaStackAsLibraryClient
logger = logging.getLogger(__name__)
@ -475,9 +476,6 @@ def test_openai_vector_store_attach_file(compat_client_with_empty_stores, client
"""Test OpenAI vector store attach file."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient")
compat_client = compat_client_with_empty_stores
# Create a vector store
@ -526,9 +524,6 @@ def test_openai_vector_store_attach_files_on_creation(compat_client_with_empty_s
"""Test OpenAI vector store attach files on creation."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient")
compat_client = compat_client_with_empty_stores
# Create some files and attach them to the vector store
@ -582,9 +577,6 @@ def test_openai_vector_store_list_files(compat_client_with_empty_stores, client_
"""Test OpenAI vector store list files."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
pytest.skip("Vector Store Files list is not yet supported with LlamaStackClient")
compat_client = compat_client_with_empty_stores
# Create a vector store
@ -597,16 +589,20 @@ def test_openai_vector_store_list_files(compat_client_with_empty_stores, client_
file_buffer.name = f"openai_test_{i}.txt"
file = compat_client.files.create(file=file_buffer, purpose="assistants")
compat_client.vector_stores.files.create(
response = compat_client.vector_stores.files.create(
vector_store_id=vector_store.id,
file_id=file.id,
)
assert response is not None
assert response.status == "completed", (
f"Failed to attach file {file.id} to vector store {vector_store.id}: {response=}"
)
file_ids.append(file.id)
files_list = compat_client.vector_stores.files.list(vector_store_id=vector_store.id)
assert files_list
assert files_list.object == "list"
assert files_list.data
assert files_list.data is not None
assert not files_list.has_more
assert len(files_list.data) == 3
assert set(file_ids) == {file.id for file in files_list.data}
@ -642,12 +638,13 @@ def test_openai_vector_store_list_files_invalid_vector_store(compat_client_with_
"""Test OpenAI vector store list files with invalid vector store ID."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
pytest.skip("Vector Store Files list is not yet supported with LlamaStackClient")
compat_client = compat_client_with_empty_stores
if isinstance(compat_client, LlamaStackAsLibraryClient):
errors = ValueError
else:
errors = (BadRequestError, OpenAIBadRequestError)
with pytest.raises((BadRequestError, OpenAIBadRequestError)):
with pytest.raises(errors):
compat_client.vector_stores.files.list(vector_store_id="abc123")
@ -655,9 +652,6 @@ def test_openai_vector_store_retrieve_file_contents(compat_client_with_empty_sto
"""Test OpenAI vector store retrieve file contents."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
pytest.skip("Vector Store Files retrieve contents is not yet supported with LlamaStackClient")
compat_client = compat_client_with_empty_stores
# Create a vector store
@ -685,9 +679,15 @@ def test_openai_vector_store_retrieve_file_contents(compat_client_with_empty_sto
file_id=file.id,
)
assert file_contents
assert file_contents.content[0]["type"] == "text"
assert file_contents.content[0]["text"] == test_content.decode("utf-8")
assert file_contents is not None
assert len(file_contents.content) == 1
content = file_contents.content[0]
# llama-stack-client returns a model, openai-python is a badboy and returns a dict
if not isinstance(content, dict):
content = content.model_dump()
assert content["type"] == "text"
assert content["text"] == test_content.decode("utf-8")
assert file_contents.filename == file_name
assert file_contents.attributes == attributes
@ -696,9 +696,6 @@ def test_openai_vector_store_delete_file(compat_client_with_empty_stores, client
"""Test OpenAI vector store delete file."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
pytest.skip("Vector Store Files list is not yet supported with LlamaStackClient")
compat_client = compat_client_with_empty_stores
# Create a vector store
@ -751,9 +748,6 @@ def test_openai_vector_store_delete_file_removes_from_vector_store(compat_client
"""Test OpenAI vector store delete file removes from vector store."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient")
compat_client = compat_client_with_empty_stores
# Create a vector store
@ -792,9 +786,6 @@ def test_openai_vector_store_update_file(compat_client_with_empty_stores, client
"""Test OpenAI vector store update file."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
pytest.skip("Vector Store Files update is not yet supported with LlamaStackClient")
compat_client = compat_client_with_empty_stores
# Create a vector store
@ -840,9 +831,6 @@ def test_create_vector_store_files_duplicate_vector_store_name(compat_client_wit
"""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
pytest.skip("Vector Store Files create is not yet supported with LlamaStackClient")
compat_client = compat_client_with_empty_stores
# Create a vector store with files

View file

@ -0,0 +1,347 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
import pytest
from llama_stack.apis.agents import Session
from llama_stack.core.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import (
AgentPersistence,
AgentSessionInfo,
)
from llama_stack.providers.utils.kvstore import KVStore
@pytest.fixture
def mock_kvstore():
return AsyncMock(spec=KVStore)
@pytest.fixture
def mock_policy():
return []
@pytest.fixture
def agent_persistence(mock_kvstore, mock_policy):
return AgentPersistence(agent_id="test-agent-123", kvstore=mock_kvstore, policy=mock_policy)
@pytest.fixture
def sample_session():
return AgentSessionInfo(
session_id="session-123",
session_name="Test Session",
started_at=datetime.now(UTC),
owner=User(principal="user-123", attributes=None),
turns=[],
identifier="test-session",
type="session",
)
@pytest.fixture
def sample_session_json(sample_session):
return sample_session.model_dump_json()
class TestAgentPersistenceListSessions:
def setup_mock_kvstore(self, mock_kvstore, session_keys=None, turn_keys=None, invalid_keys=None, custom_data=None):
"""Helper to setup mock kvstore with sessions, turns, and custom/invalid data
Args:
mock_kvstore: The mock KVStore object
session_keys: List of session keys or dict mapping keys to custom session data
turn_keys: List of turn keys or dict mapping keys to custom turn data
invalid_keys: Dict mapping keys to invalid/corrupt data
custom_data: Additional custom data to add to the mock responses
"""
all_keys = []
mock_data = {}
# session keys
if session_keys:
if isinstance(session_keys, dict):
all_keys.extend(session_keys.keys())
mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in session_keys.items()})
else:
all_keys.extend(session_keys)
for key in session_keys:
session_id = key.split(":")[-1]
mock_data[key] = json.dumps(
{
"session_id": session_id,
"session_name": f"Session {session_id}",
"started_at": datetime.now(UTC).isoformat(),
"turns": [],
}
)
# turn keys
if turn_keys:
if isinstance(turn_keys, dict):
all_keys.extend(turn_keys.keys())
mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in turn_keys.items()})
else:
all_keys.extend(turn_keys)
for key in turn_keys:
parts = key.split(":")
session_id = parts[-2]
turn_id = parts[-1]
mock_data[key] = json.dumps(
{
"turn_id": turn_id,
"session_id": session_id,
"input_messages": [],
"started_at": datetime.now(UTC).isoformat(),
}
)
if invalid_keys:
all_keys.extend(invalid_keys.keys())
mock_data.update(invalid_keys)
if custom_data:
mock_data.update(custom_data)
values_list = list(mock_data.values())
mock_kvstore.values_in_range.return_value = values_list
async def mock_get(key):
return mock_data.get(key)
mock_kvstore.get.side_effect = mock_get
return mock_data
@pytest.mark.parametrize(
"scenario",
[
{
# from this issue: https://github.com/meta-llama/llama-stack/issues/3048
"name": "reported_bug",
"session_keys": ["session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
"turn_keys": [
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad"
],
"expected_sessions": ["1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
},
{
"name": "basic_filtering",
"session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
"turn_keys": ["session:test-agent-123:session-1:turn-1", "session:test-agent-123:session-1:turn-2"],
"expected_sessions": ["session-1", "session-2"],
},
{
"name": "multiple_turns_per_session",
"session_keys": ["session:test-agent-123:session-456"],
"turn_keys": [
"session:test-agent-123:session-456:turn-789",
"session:test-agent-123:session-456:turn-790",
],
"expected_sessions": ["session-456"],
},
{
"name": "multiple_sessions_with_turns",
"session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
"turn_keys": [
"session:test-agent-123:session-1:turn-1",
"session:test-agent-123:session-1:turn-2",
"session:test-agent-123:session-2:turn-3",
],
"expected_sessions": ["session-1", "session-2"],
},
],
)
async def test_list_sessions_key_filtering(self, agent_persistence, mock_kvstore, scenario):
self.setup_mock_kvstore(mock_kvstore, session_keys=scenario["session_keys"], turn_keys=scenario["turn_keys"])
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
result = await agent_persistence.list_sessions()
assert len(result) == len(scenario["expected_sessions"])
session_ids = {s.session_id for s in result}
for expected_id in scenario["expected_sessions"]:
assert expected_id in session_ids
# no errors should be logged
mock_log.error.assert_not_called()
@pytest.mark.parametrize(
"error_scenario",
[
{
"name": "invalid_json",
"valid_keys": ["session:test-agent-123:valid-session"],
"invalid_data": {"session:test-agent-123:invalid-json": "corrupted-json-data{"},
"expected_valid_sessions": ["valid-session"],
"expected_error_count": 1,
},
{
"name": "missing_fields",
"valid_keys": ["session:test-agent-123:valid-session"],
"invalid_data": {
"session:test-agent-123:invalid-schema": json.dumps(
{
"session_id": "invalid-schema",
"session_name": "Missing Fields",
# missing `started_at` and `turns`
}
)
},
"expected_valid_sessions": ["valid-session"],
"expected_error_count": 1,
},
{
"name": "multiple_invalid",
"valid_keys": ["session:test-agent-123:valid-session-1", "session:test-agent-123:valid-session-2"],
"invalid_data": {
"session:test-agent-123:corrupted-json": "not-valid-json{",
"session:test-agent-123:incomplete-data": json.dumps({"incomplete": "data"}),
},
"expected_valid_sessions": ["valid-session-1", "valid-session-2"],
"expected_error_count": 2,
},
],
)
async def test_list_sessions_error_handling(self, agent_persistence, mock_kvstore, error_scenario):
session_keys = {}
for key in error_scenario["valid_keys"]:
session_id = key.split(":")[-1]
session_keys[key] = {
"session_id": session_id,
"session_name": f"Valid {session_id}",
"started_at": datetime.now(UTC).isoformat(),
"turns": [],
}
self.setup_mock_kvstore(mock_kvstore, session_keys=session_keys, invalid_keys=error_scenario["invalid_data"])
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
result = await agent_persistence.list_sessions()
# only valid sessions should be returned
assert len(result) == len(error_scenario["expected_valid_sessions"])
session_ids = {s.session_id for s in result}
for expected_id in error_scenario["expected_valid_sessions"]:
assert expected_id in session_ids
# error should be logged
assert mock_log.error.call_count > 0
assert mock_log.error.call_count == error_scenario["expected_error_count"]
async def test_list_sessions_empty(self, agent_persistence, mock_kvstore):
mock_kvstore.values_in_range.return_value = []
result = await agent_persistence.list_sessions()
assert result == []
mock_kvstore.values_in_range.assert_called_once_with(
start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
)
async def test_list_sessions_properties(self, agent_persistence, mock_kvstore):
session_data = {
"session_id": "session-123",
"session_name": "Test Session",
"started_at": datetime.now(UTC).isoformat(),
"owner": {"principal": "user-123", "attributes": None},
"turns": [],
}
self.setup_mock_kvstore(mock_kvstore, session_keys={"session:test-agent-123:session-123": session_data})
result = await agent_persistence.list_sessions()
assert len(result) == 1
assert isinstance(result[0], Session)
assert result[0].session_id == "session-123"
assert result[0].session_name == "Test Session"
assert result[0].turns == []
assert hasattr(result[0], "started_at")
async def test_list_sessions_kvstore_exception(self, agent_persistence, mock_kvstore):
mock_kvstore.values_in_range.side_effect = Exception("KVStore error")
with pytest.raises(Exception, match="KVStore error"):
await agent_persistence.list_sessions()
async def test_bug_data_loss_with_real_data(self, agent_persistence, mock_kvstore):
# tests the handling of the issue reported in: https://github.com/meta-llama/llama-stack/issues/3048
session_data = {
"session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
"session_name": "Test Session",
"started_at": datetime.now(UTC).isoformat(),
"turns": [],
}
turn_data = {
"turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
"session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
"input_messages": [
{"role": "user", "content": "if i had a cluster i would want to call it persistence01", "context": None}
],
"steps": [
{
"turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
"step_id": "c0f797dd-3d34-4bc5-a8f4-db6af9455132",
"started_at": "2025-08-05T14:31:50.000484Z",
"completed_at": "2025-08-05T14:31:51.303691Z",
"step_type": "inference",
"model_response": {
"role": "assistant",
"content": "OK, I can create a cluster named 'persistence01' for you.",
"stop_reason": "end_of_turn",
"tool_calls": [],
},
}
],
"output_message": {
"role": "assistant",
"content": "OK, I can create a cluster named 'persistence01' for you.",
"stop_reason": "end_of_turn",
"tool_calls": [],
},
"output_attachments": [],
"started_at": "2025-08-05T14:31:49.999950Z",
"completed_at": "2025-08-05T14:31:51.305384Z",
}
mock_data = {
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d": json.dumps(session_data),
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad": json.dumps(
turn_data
),
}
mock_kvstore.values_in_range.return_value = list(mock_data.values())
async def mock_get(key):
return mock_data.get(key)
mock_kvstore.get.side_effect = mock_get
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
result = await agent_persistence.list_sessions()
assert len(result) == 1
assert result[0].session_id == "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"
# confirm no errors logged
mock_log.error.assert_not_called()
async def test_list_sessions_key_range_construction(self, agent_persistence, mock_kvstore):
mock_kvstore.values_in_range.return_value = []
await agent_persistence.list_sessions()
mock_kvstore.values_in_range.assert_called_once_with(
start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
)

View file

@ -41,7 +41,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@ -136,9 +136,12 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
input=input_text,
model=model,
temperature=0.1,
stream=True, # Enable streaming to test content part events
)
# Verify
# For streaming response, collect all chunks
chunks = [chunk async for chunk in result]
mock_inference_api.openai_chat_completion.assert_called_once_with(
model=model,
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
@ -147,11 +150,32 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
stream=True,
temperature=0.1,
)
# Should have content part events for text streaming
# Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
assert len(chunks) >= 4
assert chunks[0].type == "response.created"
# Check for content part events
content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"]
content_part_done_events = [c for c in chunks if c.type == "response.content_part.done"]
text_delta_events = [c for c in chunks if c.type == "response.output_text.delta"]
assert len(content_part_added_events) >= 1, "Should have content_part.added event for text"
assert len(content_part_done_events) >= 1, "Should have content_part.done event for text"
assert len(text_delta_events) >= 1, "Should have text delta events"
# Verify final event is completion
assert chunks[-1].type == "response.completed"
# When streaming, the final response is in the last chunk
final_response = chunks[-1].response
assert final_response.model == model
assert len(final_response.output) == 1
assert isinstance(final_response.output[0], OpenAIResponseMessage)
openai_responses_impl.responses_store.store_response_object.assert_called_once()
assert result.model == model
assert len(result.output) == 1
assert isinstance(result.output[0], OpenAIResponseMessage)
assert result.output[0].content[0].text == "Dublin"
assert final_response.output[0].content[0].text == "Dublin"
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
@ -272,7 +296,11 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
assert len(chunks) == 2 # Should have response.created and response.completed
# Verify event types
# Should have: response.created, output_item.added, function_call_arguments.delta,
# function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 6
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
@ -284,11 +312,17 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
# Check streaming events
assert chunks[1].type == "response.output_item.added"
assert chunks[2].type == "response.function_call_arguments.delta"
assert chunks[3].type == "response.function_call_arguments.done"
assert chunks[4].type == "response.output_item.done"
# Check response.completed event (should have the tool call)
assert chunks[1].type == "response.completed"
assert len(chunks[1].response.output) == 1
assert chunks[1].response.output[0].type == "function_call"
assert chunks[1].response.output[0].name == "get_weather"
assert chunks[5].type == "response.completed"
assert len(chunks[5].response.output) == 1
assert chunks[5].response.output[0].type == "function_call"
assert chunks[5].response.output[0].name == "get_weather"
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):

View file

@ -0,0 +1,310 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatText,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
convert_chat_choice_to_response_message,
convert_response_content_to_chat_content,
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
get_message_type_by_role,
is_function_tool_call,
)
class TestConvertChatChoiceToResponseMessage:
@pytest.mark.asyncio
async def test_convert_string_content(self):
choice = OpenAIChoice(
message=OpenAIAssistantMessageParam(content="Test message"),
finish_reason="stop",
index=0,
)
result = await convert_chat_choice_to_response_message(choice)
assert result.role == "assistant"
assert result.status == "completed"
assert len(result.content) == 1
assert isinstance(result.content[0], OpenAIResponseOutputMessageContentOutputText)
assert result.content[0].text == "Test message"
@pytest.mark.asyncio
async def test_convert_text_param_content(self):
choice = OpenAIChoice(
message=OpenAIAssistantMessageParam(
content=[OpenAIChatCompletionContentPartTextParam(text="Test text param")]
),
finish_reason="stop",
index=0,
)
with pytest.raises(ValueError) as exc_info:
await convert_chat_choice_to_response_message(choice)
assert "does not yet support output content type" in str(exc_info.value)
class TestConvertResponseContentToChatContent:
@pytest.mark.asyncio
async def test_convert_string_content(self):
result = await convert_response_content_to_chat_content("Simple string")
assert result == "Simple string"
@pytest.mark.asyncio
async def test_convert_text_content_parts(self):
content = [
OpenAIResponseInputMessageContentText(text="First part"),
OpenAIResponseOutputMessageContentOutputText(text="Second part"),
]
result = await convert_response_content_to_chat_content(content)
assert len(result) == 2
assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
assert result[0].text == "First part"
assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
assert result[1].text == "Second part"
@pytest.mark.asyncio
async def test_convert_image_content(self):
content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]
result = await convert_response_content_to_chat_content(content)
assert len(result) == 1
assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)
assert result[0].image_url.url == "https://example.com/image.jpg"
assert result[0].image_url.detail == "high"
class TestConvertResponseInputToChatMessages:
@pytest.mark.asyncio
async def test_convert_string_input(self):
result = await convert_response_input_to_chat_messages("User message")
assert len(result) == 1
assert isinstance(result[0], OpenAIUserMessageParam)
assert result[0].content == "User message"
@pytest.mark.asyncio
async def test_convert_function_tool_call_output(self):
input_items = [
OpenAIResponseInputFunctionToolCallOutput(
output="Tool output",
call_id="call_123",
)
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 1
assert isinstance(result[0], OpenAIToolMessageParam)
assert result[0].content == "Tool output"
assert result[0].tool_call_id == "call_123"
@pytest.mark.asyncio
async def test_convert_function_tool_call(self):
input_items = [
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_456",
name="test_function",
arguments='{"param": "value"}',
)
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 1
assert isinstance(result[0], OpenAIAssistantMessageParam)
assert len(result[0].tool_calls) == 1
assert result[0].tool_calls[0].id == "call_456"
assert result[0].tool_calls[0].function.name == "test_function"
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
@pytest.mark.asyncio
async def test_convert_response_message(self):
input_items = [
OpenAIResponseMessage(
role="user",
content=[OpenAIResponseInputMessageContentText(text="User text")],
)
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 1
assert isinstance(result[0], OpenAIUserMessageParam)
# Content should be converted to chat content format
assert len(result[0].content) == 1
assert result[0].content[0].text == "User text"
class TestConvertResponseTextToChatResponseFormat:
@pytest.mark.asyncio
async def test_convert_text_format(self):
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatText)
assert result.type == "text"
@pytest.mark.asyncio
async def test_convert_json_object_format(self):
text = OpenAIResponseText(format={"type": "json_object"})
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatJSONObject)
@pytest.mark.asyncio
async def test_convert_json_schema_format(self):
schema_def = {"type": "object", "properties": {"test": {"type": "string"}}}
text = OpenAIResponseText(
format={
"type": "json_schema",
"name": "test_schema",
"schema": schema_def,
}
)
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatJSONSchema)
assert result.json_schema["name"] == "test_schema"
assert result.json_schema["schema"] == schema_def
@pytest.mark.asyncio
async def test_default_text_format(self):
text = OpenAIResponseText()
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatText)
assert result.type == "text"
class TestGetMessageTypeByRole:
@pytest.mark.asyncio
async def test_user_role(self):
result = await get_message_type_by_role("user")
assert result == OpenAIUserMessageParam
@pytest.mark.asyncio
async def test_system_role(self):
result = await get_message_type_by_role("system")
assert result == OpenAISystemMessageParam
@pytest.mark.asyncio
async def test_assistant_role(self):
result = await get_message_type_by_role("assistant")
assert result == OpenAIAssistantMessageParam
@pytest.mark.asyncio
async def test_developer_role(self):
result = await get_message_type_by_role("developer")
assert result == OpenAIDeveloperMessageParam
@pytest.mark.asyncio
async def test_unknown_role(self):
result = await get_message_type_by_role("unknown")
assert result is None
class TestIsFunctionToolCall:
def test_is_function_tool_call_true(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=OpenAIChatCompletionToolCallFunction(
name="test_function",
arguments="{}",
),
)
tools = [
OpenAIResponseInputToolFunction(
type="function", name="test_function", parameters={"type": "object", "properties": {}}
),
OpenAIResponseInputToolWebSearch(type="web_search"),
]
result = is_function_tool_call(tool_call, tools)
assert result is True
def test_is_function_tool_call_false_different_name(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=OpenAIChatCompletionToolCallFunction(
name="other_function",
arguments="{}",
),
)
tools = [
OpenAIResponseInputToolFunction(
type="function", name="test_function", parameters={"type": "object", "properties": {}}
),
]
result = is_function_tool_call(tool_call, tools)
assert result is False
def test_is_function_tool_call_false_no_function(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=None,
)
tools = [
OpenAIResponseInputToolFunction(
type="function", name="test_function", parameters={"type": "object", "properties": {}}
),
]
result = is_function_tool_call(tool_call, tools)
assert result is False
def test_is_function_tool_call_false_wrong_type(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=OpenAIChatCompletionToolCallFunction(
name="web_search",
arguments="{}",
),
)
tools = [
OpenAIResponseInputToolWebSearch(type="web_search"),
]
result = is_function_tool_call(tool_call, tools)
assert result is False

View file

@ -0,0 +1,753 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Test suite for the reference implementation of the Batches API.
The tests are categorized and outlined below, keep this updated:
- Batch creation with various parameters and validation:
* test_create_and_retrieve_batch_success (positive)
* test_create_batch_without_metadata (positive)
* test_create_batch_completion_window (negative)
* test_create_batch_invalid_endpoints (negative)
* test_create_batch_invalid_metadata (negative)
- Batch retrieval and error handling for non-existent batches:
* test_retrieve_batch_not_found (negative)
- Batch cancellation with proper status transitions:
* test_cancel_batch_success (positive)
* test_cancel_batch_invalid_statuses (negative)
* test_cancel_batch_not_found (negative)
- Batch listing with pagination and filtering:
* test_list_batches_empty (positive)
* test_list_batches_single_batch (positive)
* test_list_batches_multiple_batches (positive)
* test_list_batches_with_limit (positive)
* test_list_batches_with_pagination (positive)
* test_list_batches_invalid_after (negative)
- Data persistence in the underlying key-value store:
* test_kvstore_persistence (positive)
- Batch processing concurrency control:
* test_max_concurrent_batches (positive)
- Input validation testing (direct _validate_input method tests):
* test_validate_input_file_not_found (negative)
* test_validate_input_file_exists_empty_content (positive)
* test_validate_input_file_mixed_valid_invalid_json (mixed)
* test_validate_input_invalid_model (negative)
* test_validate_input_url_mismatch (negative)
* test_validate_input_multiple_errors_per_request (negative)
* test_validate_input_invalid_request_format (negative)
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
The tests use temporary SQLite databases for isolation and mock external
dependencies like inference, files, and models APIs.
"""
import json
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.batches import BatchObject
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class TestReferenceBatchesImpl:
"""Test the reference implementation of the Batches API."""
@pytest.fixture
async def provider(self):
"""Create a test provider instance with temporary database."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_batches.db"
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
# Create kvstore and mock APIs
from unittest.mock import AsyncMock
from llama_stack.providers.utils.kvstore import kvstore_impl
kvstore = await kvstore_impl(config.kvstore)
mock_inference = AsyncMock()
mock_files = AsyncMock()
mock_models = AsyncMock()
provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
await provider.initialize()
# unit tests should not require background processing
provider.process_batches = False
yield provider
await provider.shutdown()
@pytest.fixture
def sample_batch_data(self):
"""Sample batch data for testing."""
return {
"input_file_id": "file_abc123",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"metadata": {"test": "true", "priority": "high"},
}
def _validate_batch_type(self, batch, expected_metadata=None):
"""
Helper function to validate batch object structure and field types.
Note: This validates the direct BatchObject from the provider, not the
client library response which has a different structure.
Args:
batch: The BatchObject instance to validate.
expected_metadata: Optional expected metadata dictionary to validate against.
"""
assert isinstance(batch.id, str)
assert isinstance(batch.completion_window, str)
assert isinstance(batch.created_at, int)
assert isinstance(batch.endpoint, str)
assert isinstance(batch.input_file_id, str)
assert batch.object == "batch"
assert batch.status in [
"validating",
"failed",
"in_progress",
"finalizing",
"completed",
"expired",
"cancelling",
"cancelled",
]
if expected_metadata is not None:
assert batch.metadata == expected_metadata
timestamp_fields = [
"cancelled_at",
"cancelling_at",
"completed_at",
"expired_at",
"expires_at",
"failed_at",
"finalizing_at",
"in_progress_at",
]
for field in timestamp_fields:
field_value = getattr(batch, field, None)
if field_value is not None:
assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}"
file_id_fields = ["error_file_id", "output_file_id"]
for field in file_id_fields:
field_value = getattr(batch, field, None)
if field_value is not None:
assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}"
if hasattr(batch, "request_counts") and batch.request_counts is not None:
assert isinstance(batch.request_counts.completed, int), (
f"request_counts.completed should be int, got {type(batch.request_counts.completed)}"
)
assert isinstance(batch.request_counts.failed, int), (
f"request_counts.failed should be int, got {type(batch.request_counts.failed)}"
)
assert isinstance(batch.request_counts.total, int), (
f"request_counts.total should be int, got {type(batch.request_counts.total)}"
)
if hasattr(batch, "errors") and batch.errors is not None:
assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}"
if hasattr(batch.errors, "data") and batch.errors.data is not None:
assert isinstance(batch.errors.data, list), (
f"errors.data should be list or None, got {type(batch.errors.data)}"
)
for i, error_item in enumerate(batch.errors.data):
assert isinstance(error_item, dict), (
f"errors.data[{i}] should be object or dict, got {type(error_item)}"
)
if hasattr(error_item, "code") and error_item.code is not None:
assert isinstance(error_item.code, str), (
f"errors.data[{i}].code should be str or None, got {type(error_item.code)}"
)
if hasattr(error_item, "line") and error_item.line is not None:
assert isinstance(error_item.line, int), (
f"errors.data[{i}].line should be int or None, got {type(error_item.line)}"
)
if hasattr(error_item, "message") and error_item.message is not None:
assert isinstance(error_item.message, str), (
f"errors.data[{i}].message should be str or None, got {type(error_item.message)}"
)
if hasattr(error_item, "param") and error_item.param is not None:
assert isinstance(error_item.param, str), (
f"errors.data[{i}].param should be str or None, got {type(error_item.param)}"
)
if hasattr(batch.errors, "object") and batch.errors.object is not None:
assert isinstance(batch.errors.object, str), (
f"errors.object should be str or None, got {type(batch.errors.object)}"
)
assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}"
async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
"""Test successful batch creation and retrieval."""
created_batch = await provider.create_batch(**sample_batch_data)
self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
assert created_batch.id.startswith("batch_")
assert len(created_batch.id) > 13
assert created_batch.object == "batch"
assert created_batch.endpoint == sample_batch_data["endpoint"]
assert created_batch.input_file_id == sample_batch_data["input_file_id"]
assert created_batch.completion_window == sample_batch_data["completion_window"]
assert created_batch.status == "validating"
assert created_batch.metadata == sample_batch_data["metadata"]
assert isinstance(created_batch.created_at, int)
assert created_batch.created_at > 0
retrieved_batch = await provider.retrieve_batch(created_batch.id)
self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
assert retrieved_batch.id == created_batch.id
assert retrieved_batch.input_file_id == created_batch.input_file_id
assert retrieved_batch.endpoint == created_batch.endpoint
assert retrieved_batch.status == created_batch.status
assert retrieved_batch.metadata == created_batch.metadata
async def test_create_batch_without_metadata(self, provider):
"""Test batch creation without optional metadata."""
batch = await provider.create_batch(
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
)
assert batch.metadata is None
async def test_create_batch_completion_window(self, provider):
"""Test batch creation with invalid completion window."""
with pytest.raises(ValueError, match="Invalid completion_window"):
await provider.create_batch(
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
)
@pytest.mark.parametrize(
"endpoint",
[
"/v1/embeddings",
"/v1/completions",
"/v1/invalid/endpoint",
"",
],
)
async def test_create_batch_invalid_endpoints(self, provider, endpoint):
"""Test batch creation with various invalid endpoints."""
with pytest.raises(ValueError, match="Invalid endpoint"):
await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
async def test_create_batch_invalid_metadata(self, provider):
"""Test that batch creation fails with invalid metadata."""
with pytest.raises(ValueError, match="should be a valid string"):
await provider.create_batch(
input_file_id="file_123",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={123: "invalid_key"}, # Non-string key
)
with pytest.raises(ValueError, match="should be a valid string"):
await provider.create_batch(
input_file_id="file_123",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={"valid_key": 456}, # Non-string value
)
async def test_retrieve_batch_not_found(self, provider):
"""Test error when retrieving non-existent batch."""
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
await provider.retrieve_batch("nonexistent_batch")
async def test_cancel_batch_success(self, provider, sample_batch_data):
"""Test successful batch cancellation."""
created_batch = await provider.create_batch(**sample_batch_data)
assert created_batch.status == "validating"
cancelled_batch = await provider.cancel_batch(created_batch.id)
assert cancelled_batch.id == created_batch.id
assert cancelled_batch.status in ["cancelling", "cancelled"]
assert isinstance(cancelled_batch.cancelling_at, int)
assert cancelled_batch.cancelling_at >= created_batch.created_at
@pytest.mark.parametrize("status", ["failed", "expired", "completed"])
async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
"""Test error when cancelling batch in final states."""
provider.process_batches = False
created_batch = await provider.create_batch(**sample_batch_data)
# directly update status in kvstore
await provider._update_batch(created_batch.id, status=status)
with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
await provider.cancel_batch(created_batch.id)
async def test_cancel_batch_not_found(self, provider):
"""Test error when cancelling non-existent batch."""
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
await provider.cancel_batch("nonexistent_batch")
async def test_list_batches_empty(self, provider):
"""Test listing batches when none exist."""
response = await provider.list_batches()
assert response.object == "list"
assert response.data == []
assert response.first_id is None
assert response.last_id is None
assert response.has_more is False
async def test_list_batches_single_batch(self, provider, sample_batch_data):
"""Test listing batches with single batch."""
created_batch = await provider.create_batch(**sample_batch_data)
response = await provider.list_batches()
assert len(response.data) == 1
self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
assert response.data[0].id == created_batch.id
assert response.first_id == created_batch.id
assert response.last_id == created_batch.id
assert response.has_more is False
async def test_list_batches_multiple_batches(self, provider):
"""Test listing multiple batches."""
batches = [
await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
)
for i in range(3)
]
response = await provider.list_batches()
assert len(response.data) == 3
batch_ids = {batch.id for batch in response.data}
expected_ids = {batch.id for batch in batches}
assert batch_ids == expected_ids
assert response.has_more is False
assert response.first_id in expected_ids
assert response.last_id in expected_ids
async def test_list_batches_with_limit(self, provider):
"""Test listing batches with limit parameter."""
batches = [
await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
)
for i in range(3)
]
response = await provider.list_batches(limit=2)
assert len(response.data) == 2
assert response.has_more is True
assert response.first_id == response.data[0].id
assert response.last_id == response.data[1].id
batch_ids = {batch.id for batch in response.data}
expected_ids = {batch.id for batch in batches}
assert batch_ids.issubset(expected_ids)
async def test_list_batches_with_pagination(self, provider):
"""Test listing batches with pagination using 'after' parameter."""
for i in range(3):
await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
)
# Get first page
first_page = await provider.list_batches(limit=1)
assert len(first_page.data) == 1
assert first_page.has_more is True
# Get second page using 'after'
second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
assert len(second_page.data) == 1
assert second_page.data[0].id != first_page.data[0].id
# Verify we got the next batch in order
all_batches = await provider.list_batches()
expected_second_batch_id = all_batches.data[1].id
assert second_page.data[0].id == expected_second_batch_id
async def test_list_batches_invalid_after(self, provider, sample_batch_data):
"""Test listing batches with invalid 'after' parameter."""
await provider.create_batch(**sample_batch_data)
response = await provider.list_batches(after="nonexistent_batch")
# Should return all batches (no filtering when 'after' batch not found)
assert len(response.data) == 1
async def test_kvstore_persistence(self, provider, sample_batch_data):
"""Test that batches are properly persisted in kvstore."""
batch = await provider.create_batch(**sample_batch_data)
stored_data = await provider.kvstore.get(f"batch:{batch.id}")
assert stored_data is not None
stored_batch_dict = json.loads(stored_data)
assert stored_batch_dict["id"] == batch.id
assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"]
async def test_validate_input_file_not_found(self, provider):
"""Test _validate_input when input file does not exist."""
provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found"))
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="nonexistent_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_request"
assert errors[0].message == "Cannot find file nonexistent_file."
assert errors[0].param == "input_file_id"
assert errors[0].line is None
async def test_validate_input_file_exists_empty_content(self, provider):
"""Test _validate_input when file exists but is empty."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b""
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="empty_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 0
assert len(requests) == 0
async def test_validate_input_file_mixed_valid_invalid_json(self, provider):
"""Test _validate_input when file contains valid and invalid JSON lines."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
# Line 1: valid JSON with proper body args, Line 2: invalid JSON
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="mixed_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
# Should have 1 JSON parsing error from line 2, and 1 valid request from line 1
assert len(errors) == 1
assert len(requests) == 1
assert errors[0].code == "invalid_json_line"
assert errors[0].line == 2
assert errors[0].message == "This line is not parseable as valid JSON."
assert requests[0].custom_id == "req-1"
assert requests[0].method == "POST"
assert requests[0].url == "/v1/chat/completions"
assert requests[0].body["model"] == "test-model"
assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}]
async def test_validate_input_invalid_model(self, provider):
"""Test _validate_input when file contains request with non-existent model."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found"))
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="invalid_model_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "model_not_found"
assert errors[0].line == 1
assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported"
assert errors[0].param == "body.model"
@pytest.mark.parametrize(
"param_name,param_path,error_code,error_message",
[
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
("model", "body.model", "invalid_request", "Model parameter is required"),
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
],
)
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
"""Test _validate_input when file contains request with missing required parameters."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
base_request = {
"custom_id": "req-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
}
# Remove the specific parameter being tested
if "." in param_path:
top_level, nested_param = param_path.split(".", 1)
del base_request[top_level][nested_param]
else:
del base_request[param_name]
mock_response.body = json.dumps(base_request).encode()
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id=f"missing_{param_name}_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == error_code
assert errors[0].line == 1
assert errors[0].message == error_message
assert errors[0].param == param_path
async def test_validate_input_url_mismatch(self, provider):
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions", # This doesn't match the URL in the request
input_file_id="url_mismatch_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_url"
assert errors[0].line == 1
assert errors[0].message == "URL provided for this request does not match the batch endpoint"
assert errors[0].param == "url"
async def test_validate_input_multiple_errors_per_request(self, provider):
"""Test _validate_input when a single request has multiple validation errors."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
# Request missing custom_id, has invalid URL, and missing model in body
mock_response.body = (
b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}'
)
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request
input_file_id="multiple_errors_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) >= 2 # At least missing custom_id and URL mismatch
assert len(requests) == 0
for error in errors:
assert error.line == 1
error_codes = {error.code for error in errors}
assert "missing_required_parameter" in error_codes # missing custom_id
assert "invalid_url" in error_codes # URL mismatch
async def test_validate_input_invalid_request_format(self, provider):
"""Test _validate_input when file contains non-object JSON (array, string, number)."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b'["not", "a", "request", "object"]'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="invalid_format_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_request"
assert errors[0].line == 1
assert errors[0].message == "Each line must be a JSON dictionary object"
@pytest.mark.parametrize(
"param_name,param_path,invalid_value,error_message",
[
("custom_id", "custom_id", 12345, "Custom_id must be a string"),
("url", "url", 123, "URL must be a string"),
("method", "method", ["POST"], "Method must be a string"),
("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"),
("model", "body.model", 123, "Model must be a string"),
("messages", "body.messages", "invalid messages format", "Messages must be an array"),
],
)
async def test_validate_input_invalid_parameter_types(
self, provider, param_name, param_path, invalid_value, error_message
):
"""Test _validate_input when file contains request with parameters that have invalid types."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
base_request = {
"custom_id": "req-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
}
# Override the specific parameter with invalid value
if "." in param_path:
top_level, nested_param = param_path.split(".", 1)
base_request[top_level][nested_param] = invalid_value
else:
base_request[param_name] = invalid_value
mock_response.body = json.dumps(base_request).encode()
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id=f"invalid_{param_name}_type_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_request"
assert errors[0].line == 1
assert errors[0].message == error_message
assert errors[0].param == param_path
async def test_max_concurrent_batches(self, provider):
"""Test max_concurrent_batches configuration and concurrency control."""
import asyncio
provider._batch_semaphore = asyncio.Semaphore(2)
provider.process_batches = True # enable because we're testing background processing
active_batches = 0
async def add_and_wait(batch_id: str):
nonlocal active_batches
active_batches += 1
await asyncio.sleep(float("inf"))
# the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl,
# so we can replace _process_batch_impl with our mock to control concurrency
provider._process_batch_impl = add_and_wait
for _ in range(3):
await provider.create_batch(
input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
)
await asyncio.sleep(0.042) # let tasks start
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"

View file

@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
convert_message_to_openai_dict_new,
openai_messages_to_messages,
)
@ -182,3 +183,42 @@ def test_user_message_accepts_images():
assert len(msg.content) == 2
assert msg.content[0].text == "Describe this image:"
assert msg.content[1].image_url.url == "http://example.com/image.jpg"
async def test_convert_message_to_openai_dict_new_user_message():
"""Test convert_message_to_openai_dict_new with UserMessage."""
message = UserMessage(content="Hello, world!", role="user")
result = await convert_message_to_openai_dict_new(message)
assert result["role"] == "user"
assert result["content"] == "Hello, world!"
async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
"""Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
message = CompletionMessage(
content="I'll help you find the weather.",
tool_calls=[
ToolCall(
call_id="call_123",
tool_name="get_weather",
arguments={"city": "Sligo"},
arguments_json='{"city": "Sligo"}',
)
],
stop_reason=StopReason.end_of_turn,
)
result = await convert_message_to_openai_dict_new(message)
# This would have failed with "Cannot instantiate typing.Union" before the fix
assert result["role"] == "assistant"
assert result["content"] == "I'll help you find the weather."
assert "tool_calls" in result
assert result["tool_calls"] is not None
assert len(result["tool_calls"]) == 1
tool_call = result["tool_calls"][0]
assert tool_call.id == "call_123"
assert tool_call.type == "function"
assert tool_call.function.name == "get_weather"
assert tool_call.function.arguments == '{"city": "Sligo"}'