mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 13:44:38 +00:00
Merge branch 'main' into content-extension
This commit is contained in:
commit
3e11e1472c
334 changed files with 22841 additions and 8940 deletions
120
tests/README.md
120
tests/README.md
|
@ -1,9 +1,115 @@
|
|||
# Llama Stack Tests
|
||||
There are two obvious types of tests:
|
||||
|
||||
Llama Stack has multiple layers of testing done to ensure continuous functionality and prevent regressions to the codebase.
|
||||
| Type | Location | Purpose |
|
||||
|------|----------|---------|
|
||||
| **Unit** | [`tests/unit/`](unit/README.md) | Fast, isolated component testing |
|
||||
| **Integration** | [`tests/integration/`](integration/README.md) | End-to-end workflows with record-replay |
|
||||
|
||||
| Testing Type | Details |
|
||||
|--------------|---------|
|
||||
| Unit | [unit/README.md](unit/README.md) |
|
||||
| Integration | [integration/README.md](integration/README.md) |
|
||||
| Verification | [verifications/README.md](verifications/README.md) |
|
||||
Both have their place. For unit tests, it is important to create minimal mocks and instead rely more on "fakes". Mocks are too brittle. In either case, tests must be very fast and reliable.
|
||||
|
||||
### Record-replay for integration tests
|
||||
|
||||
Testing AI applications end-to-end creates some challenges:
|
||||
- **API costs** accumulate quickly during development and CI
|
||||
- **Non-deterministic responses** make tests unreliable
|
||||
- **Multiple providers** require testing the same logic across different APIs
|
||||
|
||||
Our solution: **Record real API responses once, replay them for fast, deterministic tests.** This is better than mocking because AI APIs have complex response structures and streaming behavior. Mocks can miss edge cases that real APIs exhibit. A single test can exercise underlying APIs in multiple complex ways making it really hard to mock.
|
||||
|
||||
This gives you:
|
||||
- Cost control - No repeated API calls during development
|
||||
- Speed - Instant test execution with cached responses
|
||||
- Reliability - Consistent results regardless of external service state
|
||||
- Provider coverage - Same tests work across OpenAI, Anthropic, local models, etc.
|
||||
|
||||
### Testing Quick Start
|
||||
|
||||
You can run the unit tests with:
|
||||
```bash
|
||||
uv run --group unit pytest -sv tests/unit/
|
||||
```
|
||||
|
||||
For running integration tests, you must provide a few things:
|
||||
|
||||
- A stack config. This is a pointer to a stack. You have a few ways to point to a stack:
|
||||
- **`server:<config>`** - automatically start a server with the given config (e.g., `server:starter`). This provides one-step testing by auto-starting the server if the port is available, or reusing an existing server if already running.
|
||||
- **`server:<config>:<port>`** - same as above but with a custom port (e.g., `server:starter:8322`)
|
||||
- a URL which points to a Llama Stack distribution server
|
||||
- a distribution name (e.g., `starter`) or a path to a `run.yaml` file
|
||||
- a comma-separated list of api=provider pairs, e.g. `inference=fireworks,safety=llama-guard,agents=meta-reference`. This is most useful for testing a single API surface.
|
||||
|
||||
- Whether you are using replay or live mode for inference. This is specified with the LLAMA_STACK_TEST_INFERENCE_MODE environment variable. The default mode currently is "live" -- that is certainly surprising, but we will fix this soon.
|
||||
|
||||
- Any API keys you need to use should be set in the environment, or can be passed in with the --env option.
|
||||
|
||||
You can run the integration tests in replay mode with:
|
||||
```bash
|
||||
# Run all tests with existing recordings
|
||||
LLAMA_STACK_TEST_INFERENCE_MODE=replay \
|
||||
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
|
||||
uv run --group test \
|
||||
pytest -sv tests/integration/ --stack-config=starter
|
||||
```
|
||||
|
||||
If you don't specify LLAMA_STACK_TEST_INFERENCE_MODE, by default it will be in "live" mode -- that is, it will make real API calls.
|
||||
|
||||
```bash
|
||||
# Test against live APIs
|
||||
FIREWORKS_API_KEY=your_key pytest -sv tests/integration/inference --stack-config=starter
|
||||
```
|
||||
|
||||
### Re-recording tests
|
||||
|
||||
#### Local Re-recording (Manual Setup Required)
|
||||
|
||||
If you want to re-record tests locally, you can do so with:
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_TEST_INFERENCE_MODE=record \
|
||||
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
|
||||
uv run --group test \
|
||||
pytest -sv tests/integration/ --stack-config=starter -k "<appropriate test name>"
|
||||
```
|
||||
|
||||
This will record new API responses and overwrite the existing recordings.
|
||||
|
||||
```{warning}
|
||||
|
||||
You must be careful when re-recording. CI workflows assume a specific setup for running the replay-mode tests. You must re-record the tests in the same way as the CI workflows. This means
|
||||
- you need Ollama running and serving some specific models.
|
||||
- you are using the `starter` distribution.
|
||||
```
|
||||
|
||||
#### Remote Re-recording (Recommended)
|
||||
|
||||
**For easier re-recording without local setup**, use the automated recording workflow:
|
||||
|
||||
```bash
|
||||
# Record tests for specific test subdirectories
|
||||
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents,inference"
|
||||
|
||||
# Record with vision tests enabled
|
||||
./scripts/github/schedule-record-workflow.sh --test-subdirs "inference" --run-vision-tests
|
||||
|
||||
# Record with specific provider
|
||||
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents" --test-provider vllm
|
||||
```
|
||||
|
||||
This script:
|
||||
- 🚀 **Runs in GitHub Actions** - no local Ollama setup required
|
||||
- 🔍 **Auto-detects your branch** and associated PR
|
||||
- 🍴 **Works from forks** - handles repository context automatically
|
||||
- ✅ **Commits recordings back** to your branch
|
||||
|
||||
**Prerequisites:**
|
||||
- GitHub CLI: `brew install gh && gh auth login`
|
||||
- jq: `brew install jq`
|
||||
- Your branch pushed to a remote
|
||||
|
||||
**Supported providers:** `vllm`, `ollama`
|
||||
|
||||
|
||||
### Next Steps
|
||||
|
||||
- [Integration Testing Guide](integration/README.md) - Detailed usage and configuration
|
||||
- [Unit Testing Guide](unit/README.md) - Fast component testing
|
||||
|
|
|
@ -1,6 +1,20 @@
|
|||
# Llama Stack Integration Tests
|
||||
# Integration Testing Guide
|
||||
|
||||
We use `pytest` for parameterizing and running tests. You can see all options with:
|
||||
Integration tests verify complete workflows across different providers using Llama Stack's record-replay system.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Run all integration tests with existing recordings
|
||||
LLAMA_STACK_TEST_INFERENCE_MODE=replay \
|
||||
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
|
||||
uv run --group test \
|
||||
pytest -sv tests/integration/ --stack-config=starter
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
You can see all options with:
|
||||
```bash
|
||||
cd tests/integration
|
||||
|
||||
|
@ -10,11 +24,11 @@ pytest --help
|
|||
|
||||
Here are the most important options:
|
||||
- `--stack-config`: specify the stack config to use. You have four ways to point to a stack:
|
||||
- **`server:<config>`** - automatically start a server with the given config (e.g., `server:fireworks`). This provides one-step testing by auto-starting the server if the port is available, or reusing an existing server if already running.
|
||||
- **`server:<config>:<port>`** - same as above but with a custom port (e.g., `server:together:8322`)
|
||||
- **`server:<config>`** - automatically start a server with the given config (e.g., `server:starter`). This provides one-step testing by auto-starting the server if the port is available, or reusing an existing server if already running.
|
||||
- **`server:<config>:<port>`** - same as above but with a custom port (e.g., `server:starter:8322`)
|
||||
- a URL which points to a Llama Stack distribution server
|
||||
- a template (e.g., `starter`) or a path to a `run.yaml` file
|
||||
- a comma-separated list of api=provider pairs, e.g. `inference=fireworks,safety=llama-guard,agents=meta-reference`. This is most useful for testing a single API surface.
|
||||
- a distribution name (e.g., `starter`) or a path to a `run.yaml` file
|
||||
- a comma-separated list of api=provider pairs, e.g. `inference=ollama,safety=llama-guard,agents=meta-reference`. This is most useful for testing a single API surface.
|
||||
- `--env`: set environment variables, e.g. --env KEY=value. this is a utility option to set environment variables required by various providers.
|
||||
|
||||
Model parameters can be influenced by the following options:
|
||||
|
@ -32,83 +46,139 @@ if no model is specified.
|
|||
|
||||
### Testing against a Server
|
||||
|
||||
Run all text inference tests by auto-starting a server with the `fireworks` config:
|
||||
Run all text inference tests by auto-starting a server with the `starter` config:
|
||||
|
||||
```bash
|
||||
pytest -s -v tests/integration/inference/test_text_inference.py \
|
||||
--stack-config=server:fireworks \
|
||||
--text-model=meta-llama/Llama-3.1-8B-Instruct
|
||||
OLLAMA_URL=http://localhost:11434 \
|
||||
pytest -s -v tests/integration/inference/test_text_inference.py \
|
||||
--stack-config=server:starter \
|
||||
--text-model=ollama/llama3.2:3b-instruct-fp16 \
|
||||
--embedding-model=sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
|
||||
Run tests with auto-server startup on a custom port:
|
||||
|
||||
```bash
|
||||
pytest -s -v tests/integration/inference/ \
|
||||
--stack-config=server:together:8322 \
|
||||
--text-model=meta-llama/Llama-3.1-8B-Instruct
|
||||
```
|
||||
|
||||
Run multiple test suites with auto-server (eliminates manual server management):
|
||||
|
||||
```bash
|
||||
# Auto-start server and run all integration tests
|
||||
export FIREWORKS_API_KEY=<your_key>
|
||||
|
||||
pytest -s -v tests/integration/inference/ tests/integration/safety/ tests/integration/agents/ \
|
||||
--stack-config=server:fireworks \
|
||||
--text-model=meta-llama/Llama-3.1-8B-Instruct
|
||||
OLLAMA_URL=http://localhost:11434 \
|
||||
pytest -s -v tests/integration/inference/ \
|
||||
--stack-config=server:starter:8322 \
|
||||
--text-model=ollama/llama3.2:3b-instruct-fp16 \
|
||||
--embedding-model=sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
|
||||
### Testing with Library Client
|
||||
|
||||
Run all text inference tests with the `starter` distribution using the `together` provider:
|
||||
The library client constructs the Stack "in-process" instead of using a server. This is useful during the iterative development process since you don't need to constantly start and stop servers.
|
||||
|
||||
|
||||
You can do this by simply using `--stack-config=starter` instead of `--stack-config=server:starter`.
|
||||
|
||||
|
||||
### Using ad-hoc distributions
|
||||
|
||||
Sometimes, you may want to make up a distribution on the fly. This is useful for testing a single provider or a single API or a small combination of providers. You can do so by specifying a comma-separated list of api=provider pairs to the `--stack-config` option, e.g. `inference=remote::ollama,safety=inline::llama-guard,agents=inline::meta-reference`.
|
||||
|
||||
```bash
|
||||
ENABLE_TOGETHER=together pytest -s -v tests/integration/inference/test_text_inference.py \
|
||||
--stack-config=starter \
|
||||
--text-model=meta-llama/Llama-3.1-8B-Instruct
|
||||
```
|
||||
|
||||
Run all text inference tests with the `starter` distribution using the `together` provider and `meta-llama/Llama-3.1-8B-Instruct`:
|
||||
|
||||
```bash
|
||||
ENABLE_TOGETHER=together pytest -s -v tests/integration/inference/test_text_inference.py \
|
||||
--stack-config=starter \
|
||||
--text-model=meta-llama/Llama-3.1-8B-Instruct
|
||||
```
|
||||
|
||||
Running all inference tests for a number of models using the `together` provider:
|
||||
|
||||
```bash
|
||||
TEXT_MODELS=meta-llama/Llama-3.1-8B-Instruct,meta-llama/Llama-3.1-70B-Instruct
|
||||
VISION_MODELS=meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||
EMBEDDING_MODELS=all-MiniLM-L6-v2
|
||||
ENABLE_TOGETHER=together
|
||||
export TOGETHER_API_KEY=<together_api_key>
|
||||
|
||||
pytest -s -v tests/integration/inference/ \
|
||||
--stack-config=together \
|
||||
--stack-config=inference=remote::ollama,safety=inline::llama-guard,agents=inline::meta-reference \
|
||||
--text-model=$TEXT_MODELS \
|
||||
--vision-model=$VISION_MODELS \
|
||||
--embedding-model=$EMBEDDING_MODELS
|
||||
```
|
||||
|
||||
Same thing but instead of using the distribution, use an adhoc stack with just one provider (`fireworks` for inference):
|
||||
|
||||
```bash
|
||||
export FIREWORKS_API_KEY=<fireworks_api_key>
|
||||
|
||||
pytest -s -v tests/integration/inference/ \
|
||||
--stack-config=inference=fireworks \
|
||||
--text-model=$TEXT_MODELS \
|
||||
--vision-model=$VISION_MODELS \
|
||||
--embedding-model=$EMBEDDING_MODELS
|
||||
```
|
||||
|
||||
Running Vector IO tests for a number of embedding models:
|
||||
Another example: Running Vector IO tests for embedding models:
|
||||
|
||||
```bash
|
||||
uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=inline::sqlite-vec,files=localfs" \
|
||||
tests/integration/vector_io --embedding-model \
|
||||
sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
|
||||
## Recording Modes
|
||||
|
||||
The testing system supports three modes controlled by environment variables:
|
||||
|
||||
### LIVE Mode (Default)
|
||||
Tests make real API calls:
|
||||
```bash
|
||||
LLAMA_STACK_TEST_INFERENCE_MODE=live pytest tests/integration/
|
||||
```
|
||||
|
||||
### RECORD Mode
|
||||
Captures API interactions for later replay:
|
||||
```bash
|
||||
LLAMA_STACK_TEST_INFERENCE_MODE=record \
|
||||
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
|
||||
pytest tests/integration/inference/test_new_feature.py
|
||||
```
|
||||
|
||||
### REPLAY Mode
|
||||
Uses cached responses instead of making API calls:
|
||||
```bash
|
||||
LLAMA_STACK_TEST_INFERENCE_MODE=replay \
|
||||
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
|
||||
pytest tests/integration/
|
||||
```
|
||||
|
||||
Note that right now you must specify the recording directory. This is because different tests use different recording directories and we don't (yet) have a fool-proof way to map a test to a recording directory. We are working on this.
|
||||
|
||||
## Managing Recordings
|
||||
|
||||
### Viewing Recordings
|
||||
```bash
|
||||
# See what's recorded
|
||||
sqlite3 recordings/index.sqlite "SELECT endpoint, model, timestamp FROM recordings;"
|
||||
|
||||
# Inspect specific response
|
||||
cat recordings/responses/abc123.json | jq '.'
|
||||
```
|
||||
|
||||
### Re-recording Tests
|
||||
|
||||
#### Remote Re-recording (Recommended)
|
||||
Use the automated workflow script for easier re-recording:
|
||||
```bash
|
||||
./scripts/github/schedule-record-workflow.sh --test-subdirs "inference,agents"
|
||||
```
|
||||
See the [main testing guide](../README.md#remote-re-recording-recommended) for full details.
|
||||
|
||||
#### Local Re-recording
|
||||
```bash
|
||||
# Re-record specific tests
|
||||
LLAMA_STACK_TEST_INFERENCE_MODE=record \
|
||||
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
|
||||
pytest -s -v --stack-config=server:starter tests/integration/inference/test_modified.py
|
||||
```
|
||||
|
||||
Note that when re-recording tests, you must use a Stack pointing to a server (i.e., `server:starter`). This subtlety exists because the set of tests run in server are a superset of the set of tests run in the library client.
|
||||
|
||||
## Writing Tests
|
||||
|
||||
### Basic Test Pattern
|
||||
```python
|
||||
def test_basic_completion(llama_stack_client, text_model_id):
|
||||
response = llama_stack_client.inference.completion(
|
||||
model_id=text_model_id,
|
||||
content=CompletionMessage(role="user", content="Hello"),
|
||||
)
|
||||
|
||||
# Test structure, not AI output quality
|
||||
assert response.completion_message is not None
|
||||
assert isinstance(response.completion_message.content, str)
|
||||
assert len(response.completion_message.content) > 0
|
||||
```
|
||||
|
||||
### Provider-Specific Tests
|
||||
```python
|
||||
def test_asymmetric_embeddings(llama_stack_client, embedding_model_id):
|
||||
if embedding_model_id not in MODELS_SUPPORTING_TASK_TYPE:
|
||||
pytest.skip(f"Model {embedding_model_id} doesn't support task types")
|
||||
|
||||
query_response = llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id,
|
||||
contents=["What is machine learning?"],
|
||||
task_type="query",
|
||||
)
|
||||
|
||||
assert query_response.embeddings is not None
|
||||
```
|
||||
|
|
|
@ -133,24 +133,15 @@ def test_agent_simple(llama_stack_client, agent_config):
|
|||
assert "I can't" in logs_str
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="this test was disabled for a long time, and now has turned flaky")
|
||||
def test_agent_name(llama_stack_client, text_model_id):
|
||||
agent_name = f"test-agent-{uuid4()}"
|
||||
|
||||
try:
|
||||
agent = Agent(
|
||||
llama_stack_client,
|
||||
model=text_model_id,
|
||||
instructions="You are a helpful assistant",
|
||||
name=agent_name,
|
||||
)
|
||||
except TypeError:
|
||||
agent = Agent(
|
||||
llama_stack_client,
|
||||
model=text_model_id,
|
||||
instructions="You are a helpful assistant",
|
||||
)
|
||||
return
|
||||
|
||||
agent = Agent(
|
||||
llama_stack_client,
|
||||
model=text_model_id,
|
||||
instructions="You are a helpful assistant",
|
||||
name=agent_name,
|
||||
)
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
|
||||
agent.create_turn(
|
||||
|
|
5
tests/integration/batches/__init__.py
Normal file
5
tests/integration/batches/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
122
tests/integration/batches/conftest.py
Normal file
122
tests/integration/batches/conftest.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Shared pytest fixtures for batch tests."""
|
||||
|
||||
import json
|
||||
import time
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from io import BytesIO
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
|
||||
|
||||
class BatchHelper:
|
||||
"""Helper class for creating and managing batch input files."""
|
||||
|
||||
def __init__(self, client):
|
||||
"""Initialize with either a batch_client or openai_client."""
|
||||
self.client = client
|
||||
|
||||
@contextmanager
|
||||
def create_file(self, content: str | list[dict], filename_prefix="batch_input"):
|
||||
"""Context manager for creating and cleaning up batch input files.
|
||||
|
||||
Args:
|
||||
content: Either a list of batch request dictionaries or raw string content
|
||||
filename_prefix: Prefix for the generated filename (or full filename if content is string)
|
||||
|
||||
Yields:
|
||||
The uploaded file object
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
# Handle raw string content (e.g., malformed JSONL, empty files)
|
||||
file_content = content.encode("utf-8")
|
||||
else:
|
||||
# Handle list of batch request dictionaries
|
||||
jsonl_content = "\n".join(json.dumps(req) for req in content)
|
||||
file_content = jsonl_content.encode("utf-8")
|
||||
|
||||
filename = filename_prefix if filename_prefix.endswith(".jsonl") else f"{filename_prefix}.jsonl"
|
||||
|
||||
with BytesIO(file_content) as file_buffer:
|
||||
file_buffer.name = filename
|
||||
uploaded_file = self.client.files.create(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
|
||||
|
||||
try:
|
||||
yield uploaded_file
|
||||
finally:
|
||||
try:
|
||||
self.client.files.delete(uploaded_file.id)
|
||||
except Exception:
|
||||
warnings.warn(
|
||||
f"Failed to cleanup file {uploaded_file.id}: {uploaded_file.filename}",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def wait_for(
|
||||
self,
|
||||
batch_id: str,
|
||||
max_wait_time: int = 60,
|
||||
sleep_interval: int | None = None,
|
||||
expected_statuses: set[str] | None = None,
|
||||
timeout_action: str = "fail",
|
||||
):
|
||||
"""Wait for a batch to reach a terminal status.
|
||||
|
||||
Args:
|
||||
batch_id: The batch ID to monitor
|
||||
max_wait_time: Maximum time to wait in seconds (default: 60 seconds)
|
||||
sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s)
|
||||
expected_statuses: Set of expected terminal statuses (default: {"completed"})
|
||||
timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip)
|
||||
|
||||
Returns:
|
||||
The final batch object
|
||||
|
||||
Raises:
|
||||
pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail"
|
||||
pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status
|
||||
"""
|
||||
if sleep_interval is None:
|
||||
# Default to 1/10th of max_wait_time, with min 1s and max 15s
|
||||
sleep_interval = max(1, min(15, max_wait_time // 10))
|
||||
|
||||
if expected_statuses is None:
|
||||
expected_statuses = {"completed"}
|
||||
|
||||
terminal_statuses = {"completed", "failed", "cancelled", "expired"}
|
||||
unexpected_statuses = terminal_statuses - expected_statuses
|
||||
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
current_batch = self.client.batches.retrieve(batch_id)
|
||||
|
||||
if current_batch.status in expected_statuses:
|
||||
return current_batch
|
||||
elif current_batch.status in unexpected_statuses:
|
||||
error_msg = f"Batch reached unexpected status: {current_batch.status}"
|
||||
if timeout_action == "skip":
|
||||
pytest.skip(error_msg)
|
||||
else:
|
||||
pytest.fail(error_msg)
|
||||
|
||||
time.sleep(sleep_interval)
|
||||
|
||||
timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds"
|
||||
if timeout_action == "skip":
|
||||
pytest.skip(timeout_msg)
|
||||
else:
|
||||
pytest.fail(timeout_msg)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def batch_helper(openai_client):
|
||||
"""Fixture that provides a BatchHelper instance for OpenAI client."""
|
||||
return BatchHelper(openai_client)
|
270
tests/integration/batches/test_batches.py
Normal file
270
tests/integration/batches/test_batches.py
Normal file
|
@ -0,0 +1,270 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Integration tests for the Llama Stack batch processing functionality.
|
||||
|
||||
This module contains comprehensive integration tests for the batch processing API,
|
||||
using the OpenAI-compatible client interface for consistency.
|
||||
|
||||
Test Categories:
|
||||
1. Core Batch Operations:
|
||||
- test_batch_creation_and_retrieval: Comprehensive batch creation, structure validation, and retrieval
|
||||
- test_batch_listing: Basic batch listing functionality
|
||||
- test_batch_immediate_cancellation: Batch cancellation workflow
|
||||
# TODO: cancel during processing
|
||||
|
||||
2. End-to-End Processing:
|
||||
- test_batch_e2e_chat_completions: Full chat completions workflow with output and error validation
|
||||
|
||||
Note: Error conditions and edge cases are primarily tested in test_batches_errors.py
|
||||
for better organization and separation of concerns.
|
||||
|
||||
CLEANUP WARNING: These tests currently create batches that are not automatically
|
||||
cleaned up after test completion. This may lead to resource accumulation over
|
||||
multiple test runs. Only test_batch_immediate_cancellation properly cancels its batch.
|
||||
The test_batch_e2e_chat_completions test does clean up its output and error files.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
|
||||
class TestBatchesIntegration:
|
||||
"""Integration tests for the batches API."""
|
||||
|
||||
def test_batch_creation_and_retrieval(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test comprehensive batch creation and retrieval scenarios."""
|
||||
test_metadata = {
|
||||
"test_type": "comprehensive",
|
||||
"purpose": "creation_and_retrieval_test",
|
||||
"version": "1.0",
|
||||
"tags": "test,batch",
|
||||
}
|
||||
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "request-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "batch_creation_test") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata=test_metadata,
|
||||
)
|
||||
|
||||
assert batch.endpoint == "/v1/chat/completions"
|
||||
assert batch.input_file_id == uploaded_file.id
|
||||
assert batch.completion_window == "24h"
|
||||
assert batch.metadata == test_metadata
|
||||
|
||||
retrieved_batch = openai_client.batches.retrieve(batch.id)
|
||||
|
||||
assert retrieved_batch.id == batch.id
|
||||
assert retrieved_batch.object == batch.object
|
||||
assert retrieved_batch.endpoint == batch.endpoint
|
||||
assert retrieved_batch.input_file_id == batch.input_file_id
|
||||
assert retrieved_batch.completion_window == batch.completion_window
|
||||
assert retrieved_batch.metadata == batch.metadata
|
||||
|
||||
def test_batch_listing(self, openai_client, batch_helper, text_model_id):
|
||||
"""
|
||||
Test batch listing.
|
||||
|
||||
This test creates multiple batches and verifies that they can be listed.
|
||||
It also deletes the input files before execution, which means the batches
|
||||
will appear as failed due to missing input files. This is expected and
|
||||
a good thing, because it means no inference is performed.
|
||||
"""
|
||||
batch_ids = []
|
||||
|
||||
for i in range(2):
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": f"request-{i}",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": f"Hello {i}"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, f"batch_input_{i}") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
batch_ids.append(batch.id)
|
||||
|
||||
batch_list = openai_client.batches.list()
|
||||
|
||||
assert isinstance(batch_list.data, list)
|
||||
|
||||
listed_batch_ids = {b.id for b in batch_list.data}
|
||||
for batch_id in batch_ids:
|
||||
assert batch_id in listed_batch_ids
|
||||
|
||||
def test_batch_immediate_cancellation(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test immediate batch cancellation."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "request-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
# hopefully cancel the batch before it completes
|
||||
cancelling_batch = openai_client.batches.cancel(batch.id)
|
||||
assert cancelling_batch.status in ["cancelling", "cancelled"]
|
||||
assert isinstance(cancelling_batch.cancelling_at, int), (
|
||||
f"cancelling_at should be int, got {type(cancelling_batch.cancelling_at)}"
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(
|
||||
batch.id,
|
||||
max_wait_time=3 * 60, # often takes 10-11 minutes, give it 3 min
|
||||
expected_statuses={"cancelled"},
|
||||
timeout_action="skip",
|
||||
)
|
||||
|
||||
assert final_batch.status == "cancelled"
|
||||
assert isinstance(final_batch.cancelled_at, int), (
|
||||
f"cancelled_at should be int, got {type(final_batch.cancelled_at)}"
|
||||
)
|
||||
|
||||
def test_batch_e2e_chat_completions(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test end-to-end batch processing for chat completions with both successful and failed operations."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "success-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Say hello"}],
|
||||
"max_tokens": 20,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "error-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"rolez": "user", "contentz": "This should fail"}], # Invalid keys to trigger error
|
||||
# note: ollama does not validate max_tokens values or the "role" key, so they won't trigger an error
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"test": "e2e_success_and_errors_test"},
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(
|
||||
batch.id,
|
||||
max_wait_time=3 * 60, # often takes 2-3 minutes
|
||||
expected_statuses={"completed"},
|
||||
timeout_action="skip",
|
||||
)
|
||||
|
||||
# Expecting a completed batch with both successful and failed requests
|
||||
# Batch(id='batch_xxx',
|
||||
# completion_window='24h',
|
||||
# created_at=...,
|
||||
# endpoint='/v1/chat/completions',
|
||||
# input_file_id='file-xxx',
|
||||
# object='batch',
|
||||
# status='completed',
|
||||
# output_file_id='file-xxx',
|
||||
# error_file_id='file-xxx',
|
||||
# request_counts=BatchRequestCounts(completed=1, failed=1, total=2))
|
||||
|
||||
assert final_batch.status == "completed"
|
||||
assert final_batch.request_counts is not None
|
||||
assert final_batch.request_counts.total == 2
|
||||
assert final_batch.request_counts.completed == 1
|
||||
assert final_batch.request_counts.failed == 1
|
||||
|
||||
assert final_batch.output_file_id is not None, "Output file should exist for successful requests"
|
||||
|
||||
output_content = openai_client.files.content(final_batch.output_file_id)
|
||||
if isinstance(output_content, str):
|
||||
output_text = output_content
|
||||
else:
|
||||
output_text = output_content.content.decode("utf-8")
|
||||
|
||||
output_lines = output_text.strip().split("\n")
|
||||
|
||||
for line in output_lines:
|
||||
result = json.loads(line)
|
||||
|
||||
assert "id" in result
|
||||
assert "custom_id" in result
|
||||
assert result["custom_id"] == "success-1"
|
||||
|
||||
assert "response" in result
|
||||
|
||||
assert result["response"]["status_code"] == 200
|
||||
assert "body" in result["response"]
|
||||
assert "choices" in result["response"]["body"]
|
||||
|
||||
assert final_batch.error_file_id is not None, "Error file should exist for failed requests"
|
||||
|
||||
error_content = openai_client.files.content(final_batch.error_file_id)
|
||||
if isinstance(error_content, str):
|
||||
error_text = error_content
|
||||
else:
|
||||
error_text = error_content.content.decode("utf-8")
|
||||
|
||||
error_lines = error_text.strip().split("\n")
|
||||
|
||||
for line in error_lines:
|
||||
result = json.loads(line)
|
||||
|
||||
assert "id" in result
|
||||
assert "custom_id" in result
|
||||
assert result["custom_id"] == "error-1"
|
||||
assert "error" in result
|
||||
error = result["error"]
|
||||
assert error is not None
|
||||
assert "code" in error or "message" in error, "Error should have code or message"
|
||||
|
||||
deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
|
||||
assert deleted_output_file.deleted, f"Output file {final_batch.output_file_id} was not deleted successfully"
|
||||
|
||||
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
|
||||
assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"
|
693
tests/integration/batches/test_batches_errors.py
Normal file
693
tests/integration/batches/test_batches_errors.py
Normal file
|
@ -0,0 +1,693 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Error handling and edge case tests for the Llama Stack batch processing functionality.
|
||||
|
||||
This module focuses exclusively on testing error conditions, validation failures,
|
||||
and edge cases for batch operations to ensure robust error handling and graceful
|
||||
degradation.
|
||||
|
||||
Test Categories:
|
||||
1. File and Input Validation:
|
||||
- test_batch_nonexistent_file_id: Handling invalid file IDs
|
||||
- test_batch_malformed_jsonl: Processing malformed JSONL input files
|
||||
- test_file_malformed_batch_file: Handling malformed files at upload time
|
||||
- test_batch_missing_required_fields: Validation of required request fields
|
||||
|
||||
2. API Endpoint and Model Validation:
|
||||
- test_batch_invalid_endpoint: Invalid endpoint handling during creation
|
||||
- test_batch_error_handling_invalid_model: Error handling with nonexistent models
|
||||
- test_batch_endpoint_mismatch: Validation of endpoint/URL consistency
|
||||
|
||||
3. Batch Lifecycle Error Handling:
|
||||
- test_batch_retrieve_nonexistent: Retrieving non-existent batches
|
||||
- test_batch_cancel_nonexistent: Cancelling non-existent batches
|
||||
- test_batch_cancel_completed: Attempting to cancel completed batches
|
||||
|
||||
4. Parameter and Configuration Validation:
|
||||
- test_batch_invalid_completion_window: Invalid completion window values
|
||||
- test_batch_invalid_metadata_types: Invalid metadata type validation
|
||||
- test_batch_missing_required_body_fields: Validation of required fields in request body
|
||||
|
||||
5. Feature Restriction and Compatibility:
|
||||
- test_batch_streaming_not_supported: Streaming request rejection
|
||||
- test_batch_mixed_streaming_requests: Mixed streaming/non-streaming validation
|
||||
|
||||
Note: Core functionality and OpenAI compatibility tests are located in
|
||||
test_batches_integration.py for better organization and separation of concerns.
|
||||
|
||||
CLEANUP WARNING: These tests create batches to test error conditions but do not
|
||||
automatically clean them up after test completion. While most error tests create
|
||||
batches that fail quickly, some may create valid batches that consume resources.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from openai import BadRequestError, ConflictError, NotFoundError
|
||||
|
||||
|
||||
class TestBatchesErrorHandling:
|
||||
"""Error handling and edge case tests for the batches API using OpenAI client."""
|
||||
|
||||
def test_batch_nonexistent_file_id(self, openai_client, batch_helper):
|
||||
"""Test batch creation with nonexistent input file ID."""
|
||||
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id="file-nonexistent-xyz",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# BatchError(
|
||||
# code='invalid_request',
|
||||
# line=None,
|
||||
# message='Cannot find file ..., or organization ... does not have access to it.',
|
||||
# param='file_id')
|
||||
# ], object='list'),
|
||||
# failed_at=1754566971,
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 1
|
||||
error = final_batch.errors.data[0]
|
||||
assert error.code == "invalid_request"
|
||||
assert "cannot find file" in error.message.lower()
|
||||
|
||||
def test_batch_invalid_endpoint(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test batch creation with invalid endpoint."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "invalid-endpoint",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/invalid/endpoint",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
# Expected -
|
||||
# Error code: 400 - {
|
||||
# 'error': {
|
||||
# 'message': "Invalid value: '/v1/invalid/endpoint'. Supported values are: '/v1/chat/completions', '/v1/completions', '/v1/embeddings', and '/v1/responses'.",
|
||||
# 'type': 'invalid_request_error',
|
||||
# 'param': 'endpoint',
|
||||
# 'code': 'invalid_value'
|
||||
# }
|
||||
# }
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "invalid value" in error_msg
|
||||
assert "/v1/invalid/endpoint" in error_msg
|
||||
assert "supported values" in error_msg
|
||||
assert "endpoint" in error_msg
|
||||
assert "invalid_value" in error_msg
|
||||
|
||||
def test_batch_malformed_jsonl(self, openai_client, batch_helper):
|
||||
"""
|
||||
Test batch with malformed JSONL input.
|
||||
|
||||
The /v1/files endpoint requires valid JSONL format, so we provide a well formed line
|
||||
before a malformed line to ensure we get to the /v1/batches validation stage.
|
||||
"""
|
||||
with batch_helper.create_file(
|
||||
"""{"custom_id": "valid", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test"}}
|
||||
{invalid json here""",
|
||||
"malformed_batch_input.jsonl",
|
||||
) as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# ...,
|
||||
# BatchError(code='invalid_json_line',
|
||||
# line=2,
|
||||
# message='This line is not parseable as valid JSON.',
|
||||
# param=None)
|
||||
# ], object='list'),
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) > 0
|
||||
error = final_batch.errors.data[-1] # get last error because first may be about the "test" model
|
||||
assert error.code == "invalid_json_line"
|
||||
assert error.line == 2
|
||||
assert "not" in error.message.lower()
|
||||
assert "valid json" in error.message.lower()
|
||||
|
||||
@pytest.mark.xfail(reason="Not all file providers validate content")
|
||||
@pytest.mark.parametrize("batch_requests", ["", "{malformed json"], ids=["empty", "malformed"])
|
||||
def test_file_malformed_batch_file(self, openai_client, batch_helper, batch_requests):
|
||||
"""Test file upload with malformed content."""
|
||||
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
with batch_helper.create_file(batch_requests, "malformed_batch_input_file.jsonl"):
|
||||
# /v1/files rejects the file, we don't get to batch creation
|
||||
pass
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "invalid file format" in error_msg
|
||||
assert "jsonl" in error_msg
|
||||
|
||||
def test_batch_retrieve_nonexistent(self, openai_client):
|
||||
"""Test retrieving nonexistent batch."""
|
||||
with pytest.raises(NotFoundError) as exc_info:
|
||||
openai_client.batches.retrieve("batch-nonexistent-xyz")
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert exc_info.value.status_code == 404
|
||||
assert "no batch found" in error_msg or "not found" in error_msg
|
||||
|
||||
def test_batch_cancel_nonexistent(self, openai_client):
|
||||
"""Test cancelling nonexistent batch."""
|
||||
with pytest.raises(NotFoundError) as exc_info:
|
||||
openai_client.batches.cancel("batch-nonexistent-xyz")
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert exc_info.value.status_code == 404
|
||||
assert "no batch found" in error_msg or "not found" in error_msg
|
||||
|
||||
def test_batch_cancel_completed(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test cancelling already completed batch."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "cancel-completed",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Quick test"}],
|
||||
"max_tokens": 5,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "cancel_test_batch_input") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(
|
||||
batch.id,
|
||||
max_wait_time=3 * 60, # often take 10-11 min, give it 3 min
|
||||
expected_statuses={"completed"},
|
||||
timeout_action="skip",
|
||||
)
|
||||
|
||||
deleted_file = openai_client.files.delete(final_batch.output_file_id)
|
||||
assert deleted_file.deleted, f"File {final_batch.output_file_id} was not deleted successfully"
|
||||
|
||||
with pytest.raises(ConflictError) as exc_info:
|
||||
openai_client.batches.cancel(batch.id)
|
||||
|
||||
# Expecting -
|
||||
# Error code: 409 - {
|
||||
# 'error': {
|
||||
# 'message': "Cannot cancel a batch with status 'completed'.",
|
||||
# 'type': 'invalid_request_error',
|
||||
# 'param': None,
|
||||
# 'code': None
|
||||
# }
|
||||
# }
|
||||
#
|
||||
# NOTE: Same for "failed", cancelling "cancelled" batches is allowed
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert exc_info.value.status_code == 409
|
||||
assert "cannot cancel" in error_msg
|
||||
|
||||
def test_batch_missing_required_fields(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test batch with requests missing required fields."""
|
||||
batch_requests = [
|
||||
{
|
||||
# Missing custom_id
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "No custom_id"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "no-method",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "No method"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "no-url",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "No URL"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "no-body",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
},
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "missing_fields_batch_input") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(
|
||||
# data=[
|
||||
# BatchError(
|
||||
# code='missing_required_parameter',
|
||||
# line=1,
|
||||
# message="Missing required parameter: 'custom_id'.",
|
||||
# param='custom_id'
|
||||
# ),
|
||||
# BatchError(
|
||||
# code='missing_required_parameter',
|
||||
# line=2,
|
||||
# message="Missing required parameter: 'method'.",
|
||||
# param='method'
|
||||
# ),
|
||||
# BatchError(
|
||||
# code='missing_required_parameter',
|
||||
# line=3,
|
||||
# message="Missing required parameter: 'url'.",
|
||||
# param='url'
|
||||
# ),
|
||||
# BatchError(
|
||||
# code='missing_required_parameter',
|
||||
# line=4,
|
||||
# message="Missing required parameter: 'body'.",
|
||||
# param='body'
|
||||
# )
|
||||
# ], object='list'),
|
||||
# failed_at=1754566945,
|
||||
# ...)
|
||||
# )
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 4
|
||||
no_custom_id_error = final_batch.errors.data[0]
|
||||
assert no_custom_id_error.code == "missing_required_parameter"
|
||||
assert no_custom_id_error.line == 1
|
||||
assert "missing" in no_custom_id_error.message.lower()
|
||||
assert "custom_id" in no_custom_id_error.message.lower()
|
||||
no_method_error = final_batch.errors.data[1]
|
||||
assert no_method_error.code == "missing_required_parameter"
|
||||
assert no_method_error.line == 2
|
||||
assert "missing" in no_method_error.message.lower()
|
||||
assert "method" in no_method_error.message.lower()
|
||||
no_url_error = final_batch.errors.data[2]
|
||||
assert no_url_error.code == "missing_required_parameter"
|
||||
assert no_url_error.line == 3
|
||||
assert "missing" in no_url_error.message.lower()
|
||||
assert "url" in no_url_error.message.lower()
|
||||
no_body_error = final_batch.errors.data[3]
|
||||
assert no_body_error.code == "missing_required_parameter"
|
||||
assert no_body_error.line == 4
|
||||
assert "missing" in no_body_error.message.lower()
|
||||
assert "body" in no_body_error.message.lower()
|
||||
|
||||
def test_batch_invalid_completion_window(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test batch creation with invalid completion window."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "invalid-completion-window",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
for window in ["1h", "48h", "invalid", ""]:
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window=window,
|
||||
)
|
||||
assert exc_info.value.status_code == 400
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "error" in error_msg
|
||||
assert "completion_window" in error_msg
|
||||
|
||||
def test_batch_streaming_not_supported(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test that streaming responses are not supported in batches."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "streaming-test",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
"stream": True, # Not supported
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "streaming_batch_input") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# BatchError(code='streaming_unsupported',
|
||||
# line=1,
|
||||
# message='Chat Completions: Streaming is not supported in the Batch API.',
|
||||
# param='body.stream')
|
||||
# ], object='list'),
|
||||
# failed_at=1754566965,
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 1
|
||||
error = final_batch.errors.data[0]
|
||||
assert error.code == "streaming_unsupported"
|
||||
assert error.line == 1
|
||||
assert "streaming" in error.message.lower()
|
||||
assert "not supported" in error.message.lower()
|
||||
assert error.param == "body.stream"
|
||||
assert final_batch.failed_at is not None
|
||||
|
||||
def test_batch_mixed_streaming_requests(self, openai_client, batch_helper, text_model_id):
|
||||
"""
|
||||
Test batch with mixed streaming and non-streaming requests.
|
||||
|
||||
This is distinct from test_batch_streaming_not_supported, which tests a single
|
||||
streaming request, to ensure an otherwise valid batch fails when a single
|
||||
streaming request is included.
|
||||
"""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "valid-non-streaming-request",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello without streaming"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "streaming-request",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello with streaming"}],
|
||||
"max_tokens": 10,
|
||||
"stream": True, # Not supported
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "mixed_streaming_batch_input") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# BatchError(
|
||||
# code='streaming_unsupported',
|
||||
# line=2,
|
||||
# message='Chat Completions: Streaming is not supported in the Batch API.',
|
||||
# param='body.stream')
|
||||
# ], object='list'),
|
||||
# failed_at=1754574442,
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 1
|
||||
error = final_batch.errors.data[0]
|
||||
assert error.code == "streaming_unsupported"
|
||||
assert error.line == 2
|
||||
assert "streaming" in error.message.lower()
|
||||
assert "not supported" in error.message.lower()
|
||||
assert error.param == "body.stream"
|
||||
assert final_batch.failed_at is not None
|
||||
|
||||
def test_batch_endpoint_mismatch(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test batch creation with mismatched endpoint and request URL."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "endpoint-mismatch",
|
||||
"method": "POST",
|
||||
"url": "/v1/embeddings", # Different from batch endpoint
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "endpoint_mismatch_batch_input") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions", # Different from request URL
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# BatchError(
|
||||
# code='invalid_url',
|
||||
# line=1,
|
||||
# message='The URL provided for this request does not match the batch endpoint.',
|
||||
# param='url')
|
||||
# ], object='list'),
|
||||
# failed_at=1754566972,
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 1
|
||||
error = final_batch.errors.data[0]
|
||||
assert error.line == 1
|
||||
assert error.code == "invalid_url"
|
||||
assert "does not match" in error.message.lower()
|
||||
assert "endpoint" in error.message.lower()
|
||||
assert final_batch.failed_at is not None
|
||||
|
||||
def test_batch_error_handling_invalid_model(self, openai_client, batch_helper):
|
||||
"""Test batch error handling with invalid model."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "invalid-model",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": "nonexistent-model-xyz",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# BatchError(code='model_not_found',
|
||||
# line=1,
|
||||
# message="The provided model 'nonexistent-model-xyz' is not supported by the Batch API.",
|
||||
# param='body.model')
|
||||
# ], object='list'),
|
||||
# failed_at=1754566978,
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 1
|
||||
error = final_batch.errors.data[0]
|
||||
assert error.line == 1
|
||||
assert error.code == "model_not_found"
|
||||
assert "not supported" in error.message.lower()
|
||||
assert error.param == "body.model"
|
||||
assert final_batch.failed_at is not None
|
||||
|
||||
def test_batch_missing_required_body_fields(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test batch with requests missing required fields in body (model and messages)."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "missing-model",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
# Missing model field
|
||||
"messages": [{"role": "user", "content": "Hello without model"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "missing-messages",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
# Missing messages field
|
||||
"max_tokens": 10,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "missing_body_fields_batch_input") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# BatchError(
|
||||
# code='invalid_request',
|
||||
# line=1,
|
||||
# message='Model parameter is required.',
|
||||
# param='body.model'),
|
||||
# BatchError(
|
||||
# code='invalid_request',
|
||||
# line=2,
|
||||
# message='Messages parameter is required.',
|
||||
# param='body.messages')
|
||||
# ], object='list'),
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 2
|
||||
|
||||
model_error = final_batch.errors.data[0]
|
||||
assert model_error.line == 1
|
||||
assert "model" in model_error.message.lower()
|
||||
assert model_error.param == "body.model"
|
||||
|
||||
messages_error = final_batch.errors.data[1]
|
||||
assert messages_error.line == 2
|
||||
assert "messages" in messages_error.message.lower()
|
||||
assert messages_error.param == "body.messages"
|
||||
|
||||
assert final_batch.failed_at is not None
|
||||
|
||||
def test_batch_invalid_metadata_types(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test batch creation with invalid metadata types (like lists)."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "invalid-metadata-type",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={
|
||||
"tags": ["tag1", "tag2"], # Invalid type, should be a string
|
||||
},
|
||||
)
|
||||
|
||||
# Expecting -
|
||||
# Error code: 400 - {'error':
|
||||
# {'message': "Invalid type for 'metadata.tags': expected a string,
|
||||
# but got an array instead.",
|
||||
# 'type': 'invalid_request_error', 'param': 'metadata.tags',
|
||||
# 'code': 'invalid_type'}}
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "400" in error_msg
|
||||
assert "tags" in error_msg
|
||||
assert "string" in error_msg
|
91
tests/integration/batches/test_batches_idempotency.py
Normal file
91
tests/integration/batches/test_batches_idempotency.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Integration tests for batch idempotency functionality using the OpenAI client library.
|
||||
|
||||
This module tests the idempotency feature in the batches API using the OpenAI-compatible
|
||||
client interface. These tests verify that the idempotency key (idempotency_key) works correctly
|
||||
in a real client-server environment.
|
||||
|
||||
Test Categories:
|
||||
1. Successful Idempotency: Same key returns same batch with identical parameters
|
||||
- test_idempotent_batch_creation_successful: Verifies that requests with the same
|
||||
idempotency key return identical batches, even with different metadata order
|
||||
|
||||
2. Conflict Detection: Same key with conflicting parameters raises HTTP 409 errors
|
||||
- test_idempotency_conflict_with_different_params: Verifies that reusing an idempotency key
|
||||
with truly conflicting parameters (both file ID and metadata values) raises ConflictError
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from openai import ConflictError
|
||||
|
||||
|
||||
class TestBatchesIdempotencyIntegration:
|
||||
"""Integration tests for batch idempotency using OpenAI client."""
|
||||
|
||||
def test_idempotent_batch_creation_successful(self, openai_client):
|
||||
"""Test that identical requests with same idempotency key return the same batch."""
|
||||
batch1 = openai_client.batches.create(
|
||||
input_file_id="bogus-id",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={
|
||||
"test_type": "idempotency_success",
|
||||
"purpose": "integration_test",
|
||||
},
|
||||
extra_body={"idempotency_key": "test-idempotency-token-1"},
|
||||
)
|
||||
|
||||
# sleep to ensure different timestamps
|
||||
time.sleep(1)
|
||||
|
||||
batch2 = openai_client.batches.create(
|
||||
input_file_id="bogus-id",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={
|
||||
"purpose": "integration_test",
|
||||
"test_type": "idempotency_success",
|
||||
}, # Different order
|
||||
extra_body={"idempotency_key": "test-idempotency-token-1"},
|
||||
)
|
||||
|
||||
assert batch1.id == batch2.id
|
||||
assert batch1.input_file_id == batch2.input_file_id
|
||||
assert batch1.endpoint == batch2.endpoint
|
||||
assert batch1.completion_window == batch2.completion_window
|
||||
assert batch1.metadata == batch2.metadata
|
||||
assert batch1.created_at == batch2.created_at
|
||||
|
||||
def test_idempotency_conflict_with_different_params(self, openai_client):
|
||||
"""Test that using same idempotency key with different params raises conflict error."""
|
||||
batch1 = openai_client.batches.create(
|
||||
input_file_id="bogus-id-1",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"test_type": "conflict_test_1"},
|
||||
extra_body={"idempotency_key": "conflict-token"},
|
||||
)
|
||||
|
||||
with pytest.raises(ConflictError) as exc_info:
|
||||
openai_client.batches.create(
|
||||
input_file_id="bogus-id-2", # Different file ID
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"test_type": "conflict_test_2"}, # Different metadata
|
||||
extra_body={"idempotency_key": "conflict-token"}, # Same token
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 409
|
||||
assert "conflict" in str(exc_info.value).lower()
|
||||
|
||||
retrieved_batch = openai_client.batches.retrieve(batch1.id)
|
||||
assert retrieved_batch.id == batch1.id
|
||||
assert retrieved_batch.input_file_id == "bogus-id-1"
|
|
@ -8,20 +8,27 @@ from io import BytesIO
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
|
||||
def test_openai_client_basic_operations(compat_client, client_with_models):
|
||||
# a fixture to skip all these tests if a files provider is not available
|
||||
@pytest.fixture(autouse=True)
|
||||
def skip_if_no_files_provider(llama_stack_client):
|
||||
if not [provider for provider in llama_stack_client.providers.list() if provider.api == "files"]:
|
||||
pytest.skip("No files providers found")
|
||||
|
||||
|
||||
def test_openai_client_basic_operations(openai_client):
|
||||
"""Test basic file operations through OpenAI client."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI):
|
||||
pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient")
|
||||
client = compat_client
|
||||
from openai import NotFoundError
|
||||
|
||||
client = openai_client
|
||||
|
||||
test_content = b"files test content"
|
||||
|
||||
uploaded_file = None
|
||||
|
||||
try:
|
||||
# Upload file using OpenAI client
|
||||
with BytesIO(test_content) as file_buffer:
|
||||
|
@ -31,6 +38,7 @@ def test_openai_client_basic_operations(compat_client, client_with_models):
|
|||
# Verify basic response structure
|
||||
assert uploaded_file.id.startswith("file-")
|
||||
assert hasattr(uploaded_file, "filename")
|
||||
assert uploaded_file.filename == "openai_test.txt"
|
||||
|
||||
# List files
|
||||
files_list = client.files.list()
|
||||
|
@ -43,37 +51,41 @@ def test_openai_client_basic_operations(compat_client, client_with_models):
|
|||
|
||||
# Retrieve file content - OpenAI client returns httpx Response object
|
||||
content_response = client.files.content(uploaded_file.id)
|
||||
# The response is an httpx Response object with .content attribute containing bytes
|
||||
if isinstance(content_response, str):
|
||||
# Llama Stack Client returns a str
|
||||
# TODO: fix Llama Stack Client
|
||||
content = bytes(content_response, "utf-8")
|
||||
else:
|
||||
content = content_response.content
|
||||
assert content == test_content
|
||||
assert content_response.content == test_content
|
||||
|
||||
# Delete file
|
||||
delete_response = client.files.delete(uploaded_file.id)
|
||||
assert delete_response.deleted is True
|
||||
|
||||
except Exception as e:
|
||||
# Cleanup in case of failure
|
||||
try:
|
||||
# Retrieve file should fail
|
||||
with pytest.raises(NotFoundError, match="not found"):
|
||||
client.files.retrieve(uploaded_file.id)
|
||||
|
||||
# File should not be found in listing
|
||||
files_list = client.files.list()
|
||||
file_ids = [f.id for f in files_list.data]
|
||||
assert uploaded_file.id not in file_ids
|
||||
|
||||
# Double delete should fail
|
||||
with pytest.raises(NotFoundError, match="not found"):
|
||||
client.files.delete(uploaded_file.id)
|
||||
except Exception:
|
||||
pass
|
||||
raise e
|
||||
|
||||
finally:
|
||||
# Cleanup in case of failure
|
||||
if uploaded_file is not None:
|
||||
try:
|
||||
client.files.delete(uploaded_file.id)
|
||||
except NotFoundError:
|
||||
pass # ignore 404
|
||||
|
||||
|
||||
@pytest.mark.xfail(message="User isolation broken for current providers, must be fixed.")
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
def test_files_authentication_isolation(mock_get_authenticated_user, compat_client, client_with_models):
|
||||
def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack_client):
|
||||
"""Test that users can only access their own files."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI):
|
||||
pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient")
|
||||
if not isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Authentication tests require LlamaStackAsLibraryClient (library mode)")
|
||||
from llama_stack_client import NotFoundError
|
||||
|
||||
client = compat_client
|
||||
client = llama_stack_client
|
||||
|
||||
# Create two test users
|
||||
user1 = User("user1", {"roles": ["user"], "teams": ["team-a"]})
|
||||
|
@ -117,7 +129,7 @@ def test_files_authentication_isolation(mock_get_authenticated_user, compat_clie
|
|||
|
||||
# User 1 cannot retrieve user2's file
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
with pytest.raises(NotFoundError, match="not found"):
|
||||
client.files.retrieve(user2_file.id)
|
||||
|
||||
# User 1 can access their file content
|
||||
|
@ -131,7 +143,7 @@ def test_files_authentication_isolation(mock_get_authenticated_user, compat_clie
|
|||
|
||||
# User 1 cannot access user2's file content
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
with pytest.raises(NotFoundError, match="not found"):
|
||||
client.files.content(user2_file.id)
|
||||
|
||||
# User 1 can delete their own file
|
||||
|
@ -141,7 +153,7 @@ def test_files_authentication_isolation(mock_get_authenticated_user, compat_clie
|
|||
|
||||
# User 1 cannot delete user2's file
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
with pytest.raises(NotFoundError, match="not found"):
|
||||
client.files.delete(user2_file.id)
|
||||
|
||||
# User 2 can still access their file after user1's file is deleted
|
||||
|
@ -169,14 +181,9 @@ def test_files_authentication_isolation(mock_get_authenticated_user, compat_clie
|
|||
|
||||
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
def test_files_authentication_shared_attributes(mock_get_authenticated_user, compat_client, client_with_models):
|
||||
def test_files_authentication_shared_attributes(mock_get_authenticated_user, llama_stack_client):
|
||||
"""Test access control with users having identical attributes."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI):
|
||||
pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient")
|
||||
if not isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Authentication tests require LlamaStackAsLibraryClient (library mode)")
|
||||
|
||||
client = compat_client
|
||||
client = llama_stack_client
|
||||
|
||||
# Create users with identical attributes (required for default policy)
|
||||
user_a = User("user-a", {"roles": ["user"], "teams": ["shared-team"]})
|
||||
|
@ -231,14 +238,8 @@ def test_files_authentication_shared_attributes(mock_get_authenticated_user, com
|
|||
|
||||
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
def test_files_authentication_anonymous_access(mock_get_authenticated_user, compat_client, client_with_models):
|
||||
"""Test anonymous user behavior when no authentication is present."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI):
|
||||
pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient")
|
||||
if not isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Authentication tests require LlamaStackAsLibraryClient (library mode)")
|
||||
|
||||
client = compat_client
|
||||
def test_files_authentication_anonymous_access(mock_get_authenticated_user, llama_stack_client):
|
||||
client = llama_stack_client
|
||||
|
||||
# Simulate anonymous user (no authentication)
|
||||
mock_get_authenticated_user.return_value = None
|
||||
|
|
|
@ -256,15 +256,25 @@ def instantiate_llama_stack_client(session):
|
|||
provider_data=get_provider_data(),
|
||||
skip_logger_removal=True,
|
||||
)
|
||||
if not client.initialize():
|
||||
raise RuntimeError("Initialization failed")
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def openai_client(client_with_models):
|
||||
base_url = f"{client_with_models.base_url}/v1/openai/v1"
|
||||
def require_server(llama_stack_client):
|
||||
"""
|
||||
Skip test if no server is running.
|
||||
|
||||
We use the llama_stack_client to tell if a server was started or not.
|
||||
|
||||
We use this with openai_client because it relies on a running server.
|
||||
"""
|
||||
if isinstance(llama_stack_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("No server running")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def openai_client(llama_stack_client, require_server):
|
||||
base_url = f"{llama_stack_client.base_url}/v1/openai/v1"
|
||||
return OpenAI(base_url=base_url, api_key="fake")
|
||||
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@
|
|||
#
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import BadRequestError
|
||||
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
|
||||
from llama_stack_client.types import EmbeddingsResponse
|
||||
from llama_stack_client.types.shared.interleaved_content import (
|
||||
ImageContentItem,
|
||||
|
@ -63,6 +63,9 @@ from llama_stack_client.types.shared.interleaved_content import (
|
|||
ImageContentItemImageURL,
|
||||
TextContentItem,
|
||||
)
|
||||
from openai import BadRequestError as OpenAIBadRequestError
|
||||
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
DUMMY_STRING = "hello"
|
||||
DUMMY_STRING2 = "world"
|
||||
|
@ -203,7 +206,14 @@ def test_embedding_truncation_error(
|
|||
):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||
with pytest.raises(BadRequestError):
|
||||
# Using LlamaStackClient from llama_stack_client will raise llama_stack_client.BadRequestError
|
||||
# While using LlamaStackAsLibraryClient from llama_stack.distribution.library_client will raise the error that the backend raises
|
||||
error_type = (
|
||||
OpenAIBadRequestError
|
||||
if isinstance(llama_stack_client, LlamaStackAsLibraryClient)
|
||||
else LlamaStackBadRequestError
|
||||
)
|
||||
with pytest.raises(error_type):
|
||||
llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id,
|
||||
contents=[DUMMY_LONG_TEXT],
|
||||
|
@ -283,7 +293,8 @@ def test_embedding_text_truncation_error(
|
|||
):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||
with pytest.raises(BadRequestError):
|
||||
error_type = ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
|
||||
with pytest.raises(error_type):
|
||||
llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id,
|
||||
contents=[DUMMY_STRING],
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
@ -48,19 +47,6 @@ def _load_all_verification_configs():
|
|||
return {"providers": all_provider_configs}
|
||||
|
||||
|
||||
def case_id_generator(case):
|
||||
"""Generate a test ID from the case's 'case_id' field, or use a default."""
|
||||
case_id = case.get("case_id")
|
||||
if isinstance(case_id, str | int):
|
||||
return re.sub(r"\\W|^(?=\\d)", "_", str(case_id))
|
||||
return None
|
||||
|
||||
|
||||
# Helper to get the base test name from the request object
|
||||
def get_base_test_name(request):
|
||||
return request.node.originalname
|
||||
|
||||
|
||||
# --- End Helper Functions ---
|
||||
|
||||
|
||||
|
@ -127,8 +113,6 @@ def openai_client(base_url, api_key, provider):
|
|||
raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
|
||||
config = parts[1]
|
||||
client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
|
||||
if not client.initialize():
|
||||
raise RuntimeError("Initialization failed")
|
||||
return client
|
||||
|
||||
return OpenAI(
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def load_test_cases(name: str):
|
||||
fixture_dir = Path(__file__).parent / "test_cases"
|
||||
yaml_path = fixture_dir / f"{name}.yaml"
|
||||
with open(yaml_path) as f:
|
||||
return yaml.safe_load(f)
|
262
tests/integration/non_ci/responses/fixtures/test_cases.py
Normal file
262
tests/integration/non_ci/responses/fixtures/test_cases.py
Normal file
|
@ -0,0 +1,262 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ResponsesTestCase(BaseModel):
|
||||
# Input can be a simple string or complex message structure
|
||||
input: str | list[dict[str, Any]]
|
||||
expected: str
|
||||
# Tools as flexible dict structure (gets validated at runtime by the API)
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
# Multi-turn conversations with input/output pairs
|
||||
turns: list[tuple[str | list[dict[str, Any]], str]] | None = None
|
||||
# File search specific fields
|
||||
file_content: str | None = None
|
||||
file_path: str | None = None
|
||||
# Streaming flag
|
||||
stream: bool | None = None
|
||||
|
||||
|
||||
# Basic response test cases
|
||||
basic_test_cases = [
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input="Which planet do humans live on?",
|
||||
expected="earth",
|
||||
),
|
||||
id="earth",
|
||||
),
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input="Which planet has rings around it with a name starting with letter S?",
|
||||
expected="saturn",
|
||||
),
|
||||
id="saturn",
|
||||
),
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "what teams are playing in this image?",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": "https://upload.wikimedia.org/wikipedia/commons/3/3b/LeBron_James_Layup_%28Cleveland_vs_Brooklyn_2018%29.jpg",
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
expected="brooklyn nets",
|
||||
),
|
||||
id="image_input",
|
||||
),
|
||||
]
|
||||
|
||||
# Multi-turn test cases
|
||||
multi_turn_test_cases = [
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input="", # Not used for multi-turn
|
||||
expected="", # Not used for multi-turn
|
||||
turns=[
|
||||
("Which planet do humans live on?", "earth"),
|
||||
("What is the name of the planet from your previous response?", "earth"),
|
||||
],
|
||||
),
|
||||
id="earth",
|
||||
),
|
||||
]
|
||||
|
||||
# Web search test cases
|
||||
web_search_test_cases = [
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input="How many experts does the Llama 4 Maverick model have?",
|
||||
tools=[{"type": "web_search", "search_context_size": "low"}],
|
||||
expected="128",
|
||||
),
|
||||
id="llama_experts",
|
||||
),
|
||||
]
|
||||
|
||||
# File search test cases
|
||||
file_search_test_cases = [
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input="How many experts does the Llama 4 Maverick model have?",
|
||||
tools=[{"type": "file_search"}],
|
||||
expected="128",
|
||||
file_content="Llama 4 Maverick has 128 experts",
|
||||
),
|
||||
id="llama_experts",
|
||||
),
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input="How many experts does the Llama 4 Maverick model have?",
|
||||
tools=[{"type": "file_search"}],
|
||||
expected="128",
|
||||
file_path="pdfs/llama_stack_and_models.pdf",
|
||||
),
|
||||
id="llama_experts_pdf",
|
||||
),
|
||||
]
|
||||
|
||||
# MCP tool test cases
|
||||
mcp_tool_test_cases = [
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input="What is the boiling point of myawesomeliquid in Celsius?",
|
||||
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
|
||||
expected="Hello, world!",
|
||||
),
|
||||
id="boiling_point_tool",
|
||||
),
|
||||
]
|
||||
|
||||
# Custom tool test cases
|
||||
custom_tool_test_cases = [
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input="What's the weather like in San Francisco?",
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for a given location.",
|
||||
"parameters": {
|
||||
"additionalProperties": False,
|
||||
"properties": {
|
||||
"location": {
|
||||
"description": "City and country e.g. Bogotá, Colombia",
|
||||
"type": "string",
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
"type": "object",
|
||||
},
|
||||
}
|
||||
],
|
||||
expected="", # No specific expected output for custom tools
|
||||
),
|
||||
id="sf_weather",
|
||||
),
|
||||
]
|
||||
|
||||
# Image test cases
|
||||
image_test_cases = [
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "Identify the type of animal in this image.",
|
||||
},
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
expected="llama",
|
||||
),
|
||||
id="llama_image",
|
||||
),
|
||||
]
|
||||
|
||||
# Multi-turn image test cases
|
||||
multi_turn_image_test_cases = [
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input="", # Not used for multi-turn
|
||||
expected="", # Not used for multi-turn
|
||||
turns=[
|
||||
(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "What type of animal is in this image? Please respond with a single word that starts with the letter 'L'.",
|
||||
},
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
"llama",
|
||||
),
|
||||
(
|
||||
"What country do you find this animal primarily in? What continent?",
|
||||
"peru",
|
||||
),
|
||||
],
|
||||
),
|
||||
id="llama_image_understanding",
|
||||
),
|
||||
]
|
||||
|
||||
# Multi-turn tool execution test cases
|
||||
multi_turn_tool_execution_test_cases = [
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input="I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Do this as a series of steps, where each step is a separate message. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response.",
|
||||
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
|
||||
expected="yes",
|
||||
),
|
||||
id="user_file_access_check",
|
||||
),
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input="I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me the boiling point in Celsius.",
|
||||
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
|
||||
expected="100°C",
|
||||
),
|
||||
id="experiment_results_lookup",
|
||||
),
|
||||
]
|
||||
|
||||
# Multi-turn tool execution streaming test cases
|
||||
multi_turn_tool_execution_streaming_test_cases = [
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input="Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response.",
|
||||
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
|
||||
expected="no",
|
||||
stream=True,
|
||||
),
|
||||
id="user_permissions_workflow",
|
||||
),
|
||||
pytest.param(
|
||||
ResponsesTestCase(
|
||||
input="I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Return only one tool call per step. Please stream your analysis process.",
|
||||
tools=[{"type": "mcp", "server_label": "localmcp", "server_url": "<FILLED_BY_TEST_RUNNER>"}],
|
||||
expected="85%",
|
||||
stream=True,
|
||||
),
|
||||
id="experiment_analysis_streaming",
|
||||
),
|
||||
]
|
|
@ -1,397 +0,0 @@
|
|||
test_chat_basic:
|
||||
test_name: test_chat_basic
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "earth"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: user
|
||||
output: Earth
|
||||
- case_id: "saturn"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet has rings around it with a name starting with letter
|
||||
S?
|
||||
role: user
|
||||
output: Saturn
|
||||
test_chat_input_validation:
|
||||
test_name: test_chat_input_validation
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "messages_missing"
|
||||
input:
|
||||
messages: []
|
||||
output:
|
||||
error:
|
||||
status_code: 400
|
||||
- case_id: "messages_role_invalid"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: fake_role
|
||||
output:
|
||||
error:
|
||||
status_code: 400
|
||||
- case_id: "tool_choice_invalid"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: user
|
||||
tool_choice: invalid
|
||||
output:
|
||||
error:
|
||||
status_code: 400
|
||||
- case_id: "tool_choice_no_tools"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: user
|
||||
tool_choice: required
|
||||
output:
|
||||
error:
|
||||
status_code: 400
|
||||
- case_id: "tools_type_invalid"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: user
|
||||
tools:
|
||||
- type: invalid
|
||||
output:
|
||||
error:
|
||||
status_code: 400
|
||||
test_chat_image:
|
||||
test_name: test_chat_image
|
||||
test_params:
|
||||
case:
|
||||
- input:
|
||||
messages:
|
||||
- content:
|
||||
- text: What is in this image?
|
||||
type: text
|
||||
- image_url:
|
||||
url: https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg
|
||||
type: image_url
|
||||
role: user
|
||||
output: llama
|
||||
test_chat_structured_output:
|
||||
test_name: test_chat_structured_output
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "calendar"
|
||||
input:
|
||||
messages:
|
||||
- content: Extract the event information.
|
||||
role: system
|
||||
- content: Alice and Bob are going to a science fair on Friday.
|
||||
role: user
|
||||
response_format:
|
||||
json_schema:
|
||||
name: calendar_event
|
||||
schema:
|
||||
properties:
|
||||
date:
|
||||
title: Date
|
||||
type: string
|
||||
name:
|
||||
title: Name
|
||||
type: string
|
||||
participants:
|
||||
items:
|
||||
type: string
|
||||
title: Participants
|
||||
type: array
|
||||
required:
|
||||
- name
|
||||
- date
|
||||
- participants
|
||||
title: CalendarEvent
|
||||
type: object
|
||||
type: json_schema
|
||||
output: valid_calendar_event
|
||||
- case_id: "math"
|
||||
input:
|
||||
messages:
|
||||
- content: You are a helpful math tutor. Guide the user through the solution
|
||||
step by step.
|
||||
role: system
|
||||
- content: how can I solve 8x + 7 = -23
|
||||
role: user
|
||||
response_format:
|
||||
json_schema:
|
||||
name: math_reasoning
|
||||
schema:
|
||||
$defs:
|
||||
Step:
|
||||
properties:
|
||||
explanation:
|
||||
title: Explanation
|
||||
type: string
|
||||
output:
|
||||
title: Output
|
||||
type: string
|
||||
required:
|
||||
- explanation
|
||||
- output
|
||||
title: Step
|
||||
type: object
|
||||
properties:
|
||||
final_answer:
|
||||
title: Final Answer
|
||||
type: string
|
||||
steps:
|
||||
items:
|
||||
$ref: '#/$defs/Step'
|
||||
title: Steps
|
||||
type: array
|
||||
required:
|
||||
- steps
|
||||
- final_answer
|
||||
title: MathReasoning
|
||||
type: object
|
||||
type: json_schema
|
||||
output: valid_math_reasoning
|
||||
test_tool_calling:
|
||||
test_name: test_tool_calling
|
||||
test_params:
|
||||
case:
|
||||
- input:
|
||||
messages:
|
||||
- content: You are a helpful assistant that can use tools to get information.
|
||||
role: system
|
||||
- content: What's the weather like in San Francisco?
|
||||
role: user
|
||||
tools:
|
||||
- function:
|
||||
description: Get current temperature for a given location.
|
||||
name: get_weather
|
||||
parameters:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
location:
|
||||
description: "City and country e.g. Bogot\xE1, Colombia"
|
||||
type: string
|
||||
required:
|
||||
- location
|
||||
type: object
|
||||
type: function
|
||||
output: get_weather_tool_call
|
||||
|
||||
test_chat_multi_turn_tool_calling:
|
||||
test_name: test_chat_multi_turn_tool_calling
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "text_then_weather_tool"
|
||||
input:
|
||||
messages:
|
||||
- - role: user
|
||||
content: "What's the name of the Sun in latin?"
|
||||
- - role: user
|
||||
content: "What's the weather like in San Francisco?"
|
||||
tools:
|
||||
- function:
|
||||
description: Get the current weather
|
||||
name: get_weather
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
location:
|
||||
description: "The city and state (both required), e.g. San Francisco, CA."
|
||||
type: string
|
||||
required: ["location"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': '70 degrees and foggy'}"
|
||||
expected:
|
||||
- num_tool_calls: 0
|
||||
answer: ["sol"]
|
||||
- num_tool_calls: 1
|
||||
tool_name: get_weather
|
||||
tool_arguments:
|
||||
location: "San Francisco, CA"
|
||||
- num_tool_calls: 0
|
||||
answer: ["foggy", "70 degrees"]
|
||||
- case_id: "weather_tool_then_text"
|
||||
input:
|
||||
messages:
|
||||
- - role: user
|
||||
content: "What's the weather like in San Francisco?"
|
||||
tools:
|
||||
- function:
|
||||
description: Get the current weather
|
||||
name: get_weather
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
location:
|
||||
description: "The city and state (both required), e.g. San Francisco, CA."
|
||||
type: string
|
||||
required: ["location"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': '70 degrees and foggy'}"
|
||||
expected:
|
||||
- num_tool_calls: 1
|
||||
tool_name: get_weather
|
||||
tool_arguments:
|
||||
location: "San Francisco, CA"
|
||||
- num_tool_calls: 0
|
||||
answer: ["foggy", "70 degrees"]
|
||||
- case_id: "add_product_tool"
|
||||
input:
|
||||
messages:
|
||||
- - role: user
|
||||
content: "Please add a new product with name 'Widget', price 19.99, in stock, and tags ['new', 'sale'] and give me the product id."
|
||||
tools:
|
||||
- function:
|
||||
description: Add a new product
|
||||
name: addProduct
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: "Name of the product"
|
||||
type: string
|
||||
price:
|
||||
description: "Price of the product"
|
||||
type: number
|
||||
inStock:
|
||||
description: "Availability status of the product."
|
||||
type: boolean
|
||||
tags:
|
||||
description: "List of product tags"
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required: ["name", "price", "inStock"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': 'Successfully added product with id: 123'}"
|
||||
expected:
|
||||
- num_tool_calls: 1
|
||||
tool_name: addProduct
|
||||
tool_arguments:
|
||||
name: "Widget"
|
||||
price: 19.99
|
||||
inStock: true
|
||||
tags:
|
||||
- "new"
|
||||
- "sale"
|
||||
- num_tool_calls: 0
|
||||
answer: ["123", "product id: 123"]
|
||||
- case_id: "get_then_create_event_tool"
|
||||
input:
|
||||
messages:
|
||||
- - role: system
|
||||
content: "Todays date is 2025-03-01."
|
||||
- role: user
|
||||
content: "Do i have any meetings on March 3rd at 10 am? Yes or no?"
|
||||
- - role: user
|
||||
content: "Alright then, Create an event named 'Team Building', scheduled for that time same time, in the 'Main Conference Room' and add Alice, Bob, Charlie to it. Give me the created event id."
|
||||
tools:
|
||||
- function:
|
||||
description: Create a new event
|
||||
name: create_event
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: "Name of the event"
|
||||
type: string
|
||||
date:
|
||||
description: "Date of the event in ISO format"
|
||||
type: string
|
||||
time:
|
||||
description: "Event Time (HH:MM)"
|
||||
type: string
|
||||
location:
|
||||
description: "Location of the event"
|
||||
type: string
|
||||
participants:
|
||||
description: "List of participant names"
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required: ["name", "date", "time", "location", "participants"]
|
||||
type: function
|
||||
- function:
|
||||
description: Get an event by date and time
|
||||
name: get_event
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
date:
|
||||
description: "Date of the event in ISO format"
|
||||
type: string
|
||||
time:
|
||||
description: "Event Time (HH:MM)"
|
||||
type: string
|
||||
required: ["date", "time"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': 'No events found for 2025-03-03 at 10:00'}"
|
||||
- response: "{'response': 'Successfully created new event with id: e_123'}"
|
||||
expected:
|
||||
- num_tool_calls: 1
|
||||
tool_name: get_event
|
||||
tool_arguments:
|
||||
date: "2025-03-03"
|
||||
time: "10:00"
|
||||
- num_tool_calls: 0
|
||||
answer: ["no", "no events found", "no meetings"]
|
||||
- num_tool_calls: 1
|
||||
tool_name: create_event
|
||||
tool_arguments:
|
||||
name: "Team Building"
|
||||
date: "2025-03-03"
|
||||
time: "10:00"
|
||||
location: "Main Conference Room"
|
||||
participants:
|
||||
- "Alice"
|
||||
- "Bob"
|
||||
- "Charlie"
|
||||
- num_tool_calls: 0
|
||||
answer: ["e_123", "event id: e_123"]
|
||||
- case_id: "compare_monthly_expense_tool"
|
||||
input:
|
||||
messages:
|
||||
- - role: system
|
||||
content: "Todays date is 2025-03-01."
|
||||
- role: user
|
||||
content: "what was my monthly expense in Jan of this year?"
|
||||
- - role: user
|
||||
content: "Was it less than Feb of last year? Only answer with yes or no."
|
||||
tools:
|
||||
- function:
|
||||
description: Get monthly expense summary
|
||||
name: getMonthlyExpenseSummary
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
month:
|
||||
description: "Month of the year (1-12)"
|
||||
type: integer
|
||||
year:
|
||||
description: "Year"
|
||||
type: integer
|
||||
required: ["month", "year"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': 'Total expenses for January 2025: $1000'}"
|
||||
- response: "{'response': 'Total expenses for February 2024: $2000'}"
|
||||
expected:
|
||||
- num_tool_calls: 1
|
||||
tool_name: getMonthlyExpenseSummary
|
||||
tool_arguments:
|
||||
month: 1
|
||||
year: 2025
|
||||
- num_tool_calls: 0
|
||||
answer: ["1000", "$1,000", "1,000"]
|
||||
- num_tool_calls: 1
|
||||
tool_name: getMonthlyExpenseSummary
|
||||
tool_arguments:
|
||||
month: 2
|
||||
year: 2024
|
||||
- num_tool_calls: 0
|
||||
answer: ["yes"]
|
|
@ -1,166 +0,0 @@
|
|||
test_response_basic:
|
||||
test_name: test_response_basic
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "earth"
|
||||
input: "Which planet do humans live on?"
|
||||
output: "earth"
|
||||
- case_id: "saturn"
|
||||
input: "Which planet has rings around it with a name starting with letter S?"
|
||||
output: "saturn"
|
||||
- case_id: "image_input"
|
||||
input:
|
||||
- role: user
|
||||
content:
|
||||
- type: input_text
|
||||
text: "what teams are playing in this image?"
|
||||
- role: user
|
||||
content:
|
||||
- type: input_image
|
||||
image_url: "https://upload.wikimedia.org/wikipedia/commons/3/3b/LeBron_James_Layup_%28Cleveland_vs_Brooklyn_2018%29.jpg"
|
||||
output: "brooklyn nets"
|
||||
|
||||
test_response_multi_turn:
|
||||
test_name: test_response_multi_turn
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "earth"
|
||||
turns:
|
||||
- input: "Which planet do humans live on?"
|
||||
output: "earth"
|
||||
- input: "What is the name of the planet from your previous response?"
|
||||
output: "earth"
|
||||
|
||||
test_response_web_search:
|
||||
test_name: test_response_web_search
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "llama_experts"
|
||||
input: "How many experts does the Llama 4 Maverick model have?"
|
||||
tools:
|
||||
- type: web_search
|
||||
search_context_size: "low"
|
||||
output: "128"
|
||||
|
||||
test_response_file_search:
|
||||
test_name: test_response_file_search
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "llama_experts"
|
||||
input: "How many experts does the Llama 4 Maverick model have?"
|
||||
tools:
|
||||
- type: file_search
|
||||
# vector_store_ids param for file_search tool gets added by the test runner
|
||||
file_content: "Llama 4 Maverick has 128 experts"
|
||||
output: "128"
|
||||
- case_id: "llama_experts_pdf"
|
||||
input: "How many experts does the Llama 4 Maverick model have?"
|
||||
tools:
|
||||
- type: file_search
|
||||
# vector_store_ids param for file_search toolgets added by the test runner
|
||||
file_path: "pdfs/llama_stack_and_models.pdf"
|
||||
output: "128"
|
||||
|
||||
test_response_mcp_tool:
|
||||
test_name: test_response_mcp_tool
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "boiling_point_tool"
|
||||
input: "What is the boiling point of myawesomeliquid in Celsius?"
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||
output: "Hello, world!"
|
||||
|
||||
test_response_custom_tool:
|
||||
test_name: test_response_custom_tool
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "sf_weather"
|
||||
input: "What's the weather like in San Francisco?"
|
||||
tools:
|
||||
- type: function
|
||||
name: get_weather
|
||||
description: Get current temperature for a given location.
|
||||
parameters:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
location:
|
||||
description: "City and country e.g. Bogot\xE1, Colombia"
|
||||
type: string
|
||||
required:
|
||||
- location
|
||||
type: object
|
||||
|
||||
test_response_image:
|
||||
test_name: test_response_image
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "llama_image"
|
||||
input:
|
||||
- role: user
|
||||
content:
|
||||
- type: input_text
|
||||
text: "Identify the type of animal in this image."
|
||||
- type: input_image
|
||||
image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
|
||||
output: "llama"
|
||||
|
||||
# the models are really poor at tool calling after seeing images :/
|
||||
test_response_multi_turn_image:
|
||||
test_name: test_response_multi_turn_image
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "llama_image_understanding"
|
||||
turns:
|
||||
- input:
|
||||
- role: user
|
||||
content:
|
||||
- type: input_text
|
||||
text: "What type of animal is in this image? Please respond with a single word that starts with the letter 'L'."
|
||||
- type: input_image
|
||||
image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
|
||||
output: "llama"
|
||||
- input: "What country do you find this animal primarily in? What continent?"
|
||||
output: "peru"
|
||||
|
||||
test_response_multi_turn_tool_execution:
|
||||
test_name: test_response_multi_turn_tool_execution
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "user_file_access_check"
|
||||
input: "I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Do this as a series of steps, where each step is a separate message. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response."
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||
output: "yes"
|
||||
- case_id: "experiment_results_lookup"
|
||||
input: "I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me the boiling point in Celsius."
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||
output: "100°C"
|
||||
|
||||
test_response_multi_turn_tool_execution_streaming:
|
||||
test_name: test_response_multi_turn_tool_execution_streaming
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "user_permissions_workflow"
|
||||
input: "Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response."
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||
stream: true
|
||||
output: "no"
|
||||
- case_id: "experiment_analysis_streaming"
|
||||
input: "I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Return only one tool call per step. Please stream your analysis process."
|
||||
tools:
|
||||
- type: mcp
|
||||
server_label: "localmcp"
|
||||
server_url: "<FILLED_BY_TEST_RUNNER>"
|
||||
stream: true
|
||||
output: "85%"
|
64
tests/integration/non_ci/responses/helpers.py
Normal file
64
tests/integration/non_ci/responses/helpers.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
|
||||
|
||||
def new_vector_store(openai_client, name):
|
||||
"""Create a new vector store, cleaning up any existing one with the same name."""
|
||||
# Ensure we don't reuse an existing vector store
|
||||
vector_stores = openai_client.vector_stores.list()
|
||||
for vector_store in vector_stores:
|
||||
if vector_store.name == name:
|
||||
openai_client.vector_stores.delete(vector_store_id=vector_store.id)
|
||||
|
||||
# Create a new vector store
|
||||
vector_store = openai_client.vector_stores.create(name=name)
|
||||
return vector_store
|
||||
|
||||
|
||||
def upload_file(openai_client, name, file_path):
|
||||
"""Upload a file, cleaning up any existing file with the same name."""
|
||||
# Ensure we don't reuse an existing file
|
||||
files = openai_client.files.list()
|
||||
for file in files:
|
||||
if file.filename == name:
|
||||
openai_client.files.delete(file_id=file.id)
|
||||
|
||||
# Upload a text file with our document content
|
||||
return openai_client.files.create(file=open(file_path, "rb"), purpose="assistants")
|
||||
|
||||
|
||||
def wait_for_file_attachment(compat_client, vector_store_id, file_id):
|
||||
"""Wait for a file to be attached to a vector store."""
|
||||
file_attach_response = compat_client.vector_stores.files.retrieve(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
while file_attach_response.status == "in_progress":
|
||||
time.sleep(0.1)
|
||||
file_attach_response = compat_client.vector_stores.files.retrieve(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
assert file_attach_response.status == "completed", f"Expected file to be attached, got {file_attach_response}"
|
||||
assert not file_attach_response.last_error
|
||||
return file_attach_response
|
||||
|
||||
|
||||
def setup_mcp_tools(tools, mcp_server_info):
|
||||
"""Replace placeholder MCP server URLs with actual server info."""
|
||||
# Create a deep copy to avoid modifying the original test case
|
||||
import copy
|
||||
|
||||
tools_copy = copy.deepcopy(tools)
|
||||
|
||||
for tool in tools_copy:
|
||||
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
|
||||
tool["server_url"] = mcp_server_info["server_url"]
|
||||
return tools_copy
|
145
tests/integration/non_ci/responses/streaming_assertions.py
Normal file
145
tests/integration/non_ci/responses/streaming_assertions.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class StreamingValidator:
|
||||
"""Helper class for validating streaming response events."""
|
||||
|
||||
def __init__(self, chunks: list[Any]):
|
||||
self.chunks = chunks
|
||||
self.event_types = [chunk.type for chunk in chunks]
|
||||
|
||||
def assert_basic_event_sequence(self):
|
||||
"""Verify basic created -> completed event sequence."""
|
||||
assert len(self.chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(self.chunks)}"
|
||||
assert self.chunks[0].type == "response.created", (
|
||||
f"First chunk should be response.created, got {self.chunks[0].type}"
|
||||
)
|
||||
assert self.chunks[-1].type == "response.completed", (
|
||||
f"Last chunk should be response.completed, got {self.chunks[-1].type}"
|
||||
)
|
||||
|
||||
# Verify event order
|
||||
created_index = self.event_types.index("response.created")
|
||||
completed_index = self.event_types.index("response.completed")
|
||||
assert created_index < completed_index, "response.created should come before response.completed"
|
||||
|
||||
def assert_response_consistency(self):
|
||||
"""Verify response ID consistency across events."""
|
||||
response_ids = set()
|
||||
for chunk in self.chunks:
|
||||
if hasattr(chunk, "response_id"):
|
||||
response_ids.add(chunk.response_id)
|
||||
elif hasattr(chunk, "response") and hasattr(chunk.response, "id"):
|
||||
response_ids.add(chunk.response.id)
|
||||
|
||||
assert len(response_ids) == 1, f"All events should reference the same response_id, found: {response_ids}"
|
||||
|
||||
def assert_has_incremental_content(self):
|
||||
"""Verify that content is delivered incrementally via delta events."""
|
||||
delta_events = [
|
||||
i for i, event_type in enumerate(self.event_types) if event_type == "response.output_text.delta"
|
||||
]
|
||||
assert len(delta_events) > 0, "Expected delta events for true incremental streaming, but found none"
|
||||
|
||||
# Verify delta events have content
|
||||
non_empty_deltas = 0
|
||||
delta_content_total = ""
|
||||
|
||||
for delta_idx in delta_events:
|
||||
chunk = self.chunks[delta_idx]
|
||||
if hasattr(chunk, "delta") and chunk.delta:
|
||||
delta_content_total += chunk.delta
|
||||
non_empty_deltas += 1
|
||||
|
||||
assert non_empty_deltas > 0, "Delta events found but none contain content"
|
||||
assert len(delta_content_total) > 0, "Delta events found but total delta content is empty"
|
||||
|
||||
return delta_content_total
|
||||
|
||||
def assert_content_quality(self, expected_content: str):
|
||||
"""Verify the final response contains expected content."""
|
||||
final_chunk = self.chunks[-1]
|
||||
if hasattr(final_chunk, "response"):
|
||||
output_text = final_chunk.response.output_text.lower().strip()
|
||||
assert len(output_text) > 0, "Response should have content"
|
||||
assert expected_content.lower() in output_text, f"Expected '{expected_content}' in response"
|
||||
|
||||
def assert_has_tool_calls(self):
|
||||
"""Verify tool call streaming events are present."""
|
||||
# Check for tool call events
|
||||
delta_events = [
|
||||
chunk
|
||||
for chunk in self.chunks
|
||||
if chunk.type in ["response.function_call_arguments.delta", "response.mcp_call.arguments.delta"]
|
||||
]
|
||||
done_events = [
|
||||
chunk
|
||||
for chunk in self.chunks
|
||||
if chunk.type in ["response.function_call_arguments.done", "response.mcp_call.arguments.done"]
|
||||
]
|
||||
|
||||
assert len(delta_events) > 0, f"Expected tool call delta events, got chunk types: {self.event_types}"
|
||||
assert len(done_events) > 0, f"Expected tool call done events, got chunk types: {self.event_types}"
|
||||
|
||||
# Verify output item events
|
||||
item_added_events = [chunk for chunk in self.chunks if chunk.type == "response.output_item.added"]
|
||||
item_done_events = [chunk for chunk in self.chunks if chunk.type == "response.output_item.done"]
|
||||
|
||||
assert len(item_added_events) > 0, (
|
||||
f"Expected response.output_item.added events, got chunk types: {self.event_types}"
|
||||
)
|
||||
assert len(item_done_events) > 0, (
|
||||
f"Expected response.output_item.done events, got chunk types: {self.event_types}"
|
||||
)
|
||||
|
||||
def assert_has_mcp_events(self):
|
||||
"""Verify MCP-specific streaming events are present."""
|
||||
# Tool execution progress events
|
||||
mcp_in_progress_events = [chunk for chunk in self.chunks if chunk.type == "response.mcp_call.in_progress"]
|
||||
mcp_completed_events = [chunk for chunk in self.chunks if chunk.type == "response.mcp_call.completed"]
|
||||
|
||||
assert len(mcp_in_progress_events) > 0, (
|
||||
f"Expected response.mcp_call.in_progress events, got chunk types: {self.event_types}"
|
||||
)
|
||||
assert len(mcp_completed_events) > 0, (
|
||||
f"Expected response.mcp_call.completed events, got chunk types: {self.event_types}"
|
||||
)
|
||||
|
||||
# MCP list tools events
|
||||
mcp_list_tools_in_progress_events = [
|
||||
chunk for chunk in self.chunks if chunk.type == "response.mcp_list_tools.in_progress"
|
||||
]
|
||||
mcp_list_tools_completed_events = [
|
||||
chunk for chunk in self.chunks if chunk.type == "response.mcp_list_tools.completed"
|
||||
]
|
||||
|
||||
assert len(mcp_list_tools_in_progress_events) > 0, (
|
||||
f"Expected response.mcp_list_tools.in_progress events, got chunk types: {self.event_types}"
|
||||
)
|
||||
assert len(mcp_list_tools_completed_events) > 0, (
|
||||
f"Expected response.mcp_list_tools.completed events, got chunk types: {self.event_types}"
|
||||
)
|
||||
|
||||
def assert_rich_streaming(self, min_chunks: int = 10):
|
||||
"""Verify we have substantial streaming activity."""
|
||||
assert len(self.chunks) > min_chunks, (
|
||||
f"Expected rich streaming with many events, got only {len(self.chunks)} chunks"
|
||||
)
|
||||
|
||||
def validate_event_structure(self):
|
||||
"""Validate the structure of various event types."""
|
||||
for chunk in self.chunks:
|
||||
if chunk.type == "response.created":
|
||||
assert chunk.response.status == "in_progress"
|
||||
elif chunk.type == "response.completed":
|
||||
assert chunk.response.status == "completed"
|
||||
elif hasattr(chunk, "item_id"):
|
||||
assert chunk.item_id, "Events with item_id should have non-empty item_id"
|
||||
elif hasattr(chunk, "sequence_number"):
|
||||
assert isinstance(chunk.sequence_number, int), "sequence_number should be an integer"
|
189
tests/integration/non_ci/responses/test_basic_responses.py
Normal file
189
tests/integration/non_ci/responses/test_basic_responses.py
Normal file
|
@ -0,0 +1,189 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from .fixtures.test_cases import basic_test_cases, image_test_cases, multi_turn_image_test_cases, multi_turn_test_cases
|
||||
from .streaming_assertions import StreamingValidator
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", basic_test_cases)
|
||||
def test_response_non_streaming_basic(compat_client, text_model_id, case):
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
stream=False,
|
||||
)
|
||||
output_text = response.output_text.lower().strip()
|
||||
assert len(output_text) > 0
|
||||
assert case.expected.lower() in output_text
|
||||
|
||||
retrieved_response = compat_client.responses.retrieve(response_id=response.id)
|
||||
assert retrieved_response.output_text == response.output_text
|
||||
|
||||
next_response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="Repeat your previous response in all caps.",
|
||||
previous_response_id=response.id,
|
||||
)
|
||||
next_output_text = next_response.output_text.strip()
|
||||
assert case.expected.upper() in next_output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", basic_test_cases)
|
||||
def test_response_streaming_basic(compat_client, text_model_id, case):
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Track events and timing to verify proper streaming
|
||||
events = []
|
||||
event_times = []
|
||||
response_id = ""
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
for chunk in response:
|
||||
current_time = time.time()
|
||||
event_times.append(current_time - start_time)
|
||||
events.append(chunk)
|
||||
|
||||
if chunk.type == "response.created":
|
||||
# Verify response.created is emitted first and immediately
|
||||
assert len(events) == 1, "response.created should be the first event"
|
||||
assert event_times[0] < 0.1, "response.created should be emitted immediately"
|
||||
assert chunk.response.status == "in_progress"
|
||||
response_id = chunk.response.id
|
||||
|
||||
elif chunk.type == "response.completed":
|
||||
# Verify response.completed comes after response.created
|
||||
assert len(events) >= 2, "response.completed should come after response.created"
|
||||
assert chunk.response.status == "completed"
|
||||
assert chunk.response.id == response_id, "Response ID should be consistent"
|
||||
|
||||
# Verify content quality
|
||||
output_text = chunk.response.output_text.lower().strip()
|
||||
assert len(output_text) > 0, "Response should have content"
|
||||
assert case.expected.lower() in output_text, f"Expected '{case.expected}' in response"
|
||||
|
||||
# Use validator for common checks
|
||||
validator = StreamingValidator(events)
|
||||
validator.assert_basic_event_sequence()
|
||||
validator.assert_response_consistency()
|
||||
|
||||
# Verify stored response matches streamed response
|
||||
retrieved_response = compat_client.responses.retrieve(response_id=response_id)
|
||||
final_event = events[-1]
|
||||
assert retrieved_response.output_text == final_event.response.output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", basic_test_cases)
|
||||
def test_response_streaming_incremental_content(compat_client, text_model_id, case):
|
||||
"""Test that streaming actually delivers content incrementally, not just at the end."""
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Track all events and their content to verify incremental streaming
|
||||
events = []
|
||||
content_snapshots = []
|
||||
event_times = []
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
for chunk in response:
|
||||
current_time = time.time()
|
||||
event_times.append(current_time - start_time)
|
||||
events.append(chunk)
|
||||
|
||||
# Track content at each event based on event type
|
||||
if chunk.type == "response.output_text.delta":
|
||||
# For delta events, track the delta content
|
||||
content_snapshots.append(chunk.delta)
|
||||
elif hasattr(chunk, "response") and hasattr(chunk.response, "output_text"):
|
||||
# For response.created/completed events, track the full output_text
|
||||
content_snapshots.append(chunk.response.output_text)
|
||||
else:
|
||||
content_snapshots.append("")
|
||||
|
||||
validator = StreamingValidator(events)
|
||||
validator.assert_basic_event_sequence()
|
||||
|
||||
# Check if we have incremental content updates
|
||||
event_types = [event.type for event in events]
|
||||
created_index = event_types.index("response.created")
|
||||
completed_index = event_types.index("response.completed")
|
||||
|
||||
# The key test: verify content progression
|
||||
created_content = content_snapshots[created_index]
|
||||
completed_content = content_snapshots[completed_index]
|
||||
|
||||
# Verify that response.created has empty or minimal content
|
||||
assert len(created_content) == 0, f"response.created should have empty content, got: {repr(created_content[:100])}"
|
||||
|
||||
# Verify that response.completed has the full content
|
||||
assert len(completed_content) > 0, "response.completed should have content"
|
||||
assert case.expected.lower() in completed_content.lower(), f"Expected '{case.expected}' in final content"
|
||||
|
||||
# Use validator for incremental content checks
|
||||
delta_content_total = validator.assert_has_incremental_content()
|
||||
|
||||
# Verify that the accumulated delta content matches the final content
|
||||
assert delta_content_total.strip() == completed_content.strip(), (
|
||||
f"Delta content '{delta_content_total}' should match final content '{completed_content}'"
|
||||
)
|
||||
|
||||
# Verify timing: delta events should come between created and completed
|
||||
delta_events = [i for i, event_type in enumerate(event_types) if event_type == "response.output_text.delta"]
|
||||
for delta_idx in delta_events:
|
||||
assert created_index < delta_idx < completed_index, (
|
||||
f"Delta event at index {delta_idx} should be between created ({created_index}) and completed ({completed_index})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", multi_turn_test_cases)
|
||||
def test_response_non_streaming_multi_turn(compat_client, text_model_id, case):
|
||||
previous_response_id = None
|
||||
for turn_input, turn_expected in case.turns:
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=turn_input,
|
||||
previous_response_id=previous_response_id,
|
||||
)
|
||||
previous_response_id = response.id
|
||||
output_text = response.output_text.lower()
|
||||
assert turn_expected.lower() in output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", image_test_cases)
|
||||
def test_response_non_streaming_image(compat_client, text_model_id, case):
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
stream=False,
|
||||
)
|
||||
output_text = response.output_text.lower()
|
||||
assert case.expected.lower() in output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", multi_turn_image_test_cases)
|
||||
def test_response_non_streaming_multi_turn_image(compat_client, text_model_id, case):
|
||||
previous_response_id = None
|
||||
for turn_input, turn_expected in case.turns:
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=turn_input,
|
||||
previous_response_id=previous_response_id,
|
||||
)
|
||||
previous_response_id = response.id
|
||||
output_text = response.output_text.lower()
|
||||
assert turn_expected.lower() in output_text
|
318
tests/integration/non_ci/responses/test_file_search.py
Normal file
318
tests/integration/non_ci/responses/test_file_search.py
Normal file
|
@ -0,0 +1,318 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack import LlamaStackAsLibraryClient
|
||||
|
||||
from .helpers import new_vector_store, upload_file
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text_format",
|
||||
# Not testing json_object because most providers don't actually support it.
|
||||
[
|
||||
{"type": "text"},
|
||||
{
|
||||
"type": "json_schema",
|
||||
"name": "capitals",
|
||||
"description": "A schema for the capital of each country",
|
||||
"schema": {"type": "object", "properties": {"capital": {"type": "string"}}},
|
||||
"strict": True,
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_response_text_format(compat_client, text_model_id, text_format):
|
||||
if isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API text format is not yet supported in library client.")
|
||||
|
||||
stream = False
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What is the capital of France?",
|
||||
stream=stream,
|
||||
text={"format": text_format},
|
||||
)
|
||||
# by_alias=True is needed because otherwise Pydantic renames our "schema" field
|
||||
assert response.text.format.model_dump(exclude_none=True, by_alias=True) == text_format
|
||||
assert "paris" in response.output_text.lower()
|
||||
if text_format["type"] == "json_schema":
|
||||
assert "paris" in json.loads(response.output_text)["capital"].lower()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_factory):
|
||||
"""Create a vector store with multiple files that have different attributes for filtering tests."""
|
||||
if isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API file search is not yet supported in library client.")
|
||||
|
||||
vector_store = new_vector_store(compat_client, "test_vector_store_with_filters")
|
||||
tmp_path = tmp_path_factory.mktemp("filter_test_files")
|
||||
|
||||
# Create multiple files with different attributes
|
||||
files_data = [
|
||||
{
|
||||
"name": "us_marketing_q1.txt",
|
||||
"content": "US promotional campaigns for Q1 2023. Revenue increased by 15% in the US region.",
|
||||
"attributes": {
|
||||
"region": "us",
|
||||
"category": "marketing",
|
||||
"date": 1672531200, # Jan 1, 2023
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "us_engineering_q2.txt",
|
||||
"content": "US technical updates for Q2 2023. New features deployed in the US region.",
|
||||
"attributes": {
|
||||
"region": "us",
|
||||
"category": "engineering",
|
||||
"date": 1680307200, # Apr 1, 2023
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "eu_marketing_q1.txt",
|
||||
"content": "European advertising campaign results for Q1 2023. Strong growth in EU markets.",
|
||||
"attributes": {
|
||||
"region": "eu",
|
||||
"category": "marketing",
|
||||
"date": 1672531200, # Jan 1, 2023
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "asia_sales_q3.txt",
|
||||
"content": "Asia Pacific revenue figures for Q3 2023. Record breaking quarter in Asia.",
|
||||
"attributes": {
|
||||
"region": "asia",
|
||||
"category": "sales",
|
||||
"date": 1688169600, # Jul 1, 2023
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
file_ids = []
|
||||
for file_data in files_data:
|
||||
# Create file
|
||||
file_path = tmp_path / file_data["name"]
|
||||
file_path.write_text(file_data["content"])
|
||||
|
||||
# Upload file
|
||||
file_response = upload_file(compat_client, file_data["name"], str(file_path))
|
||||
file_ids.append(file_response.id)
|
||||
|
||||
# Attach file to vector store with attributes
|
||||
file_attach_response = compat_client.vector_stores.files.create(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file_response.id,
|
||||
attributes=file_data["attributes"],
|
||||
)
|
||||
|
||||
# Wait for attachment
|
||||
while file_attach_response.status == "in_progress":
|
||||
time.sleep(0.1)
|
||||
file_attach_response = compat_client.vector_stores.files.retrieve(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file_response.id,
|
||||
)
|
||||
assert file_attach_response.status == "completed"
|
||||
|
||||
yield vector_store
|
||||
|
||||
# Cleanup: delete vector store and files
|
||||
try:
|
||||
compat_client.vector_stores.delete(vector_store_id=vector_store.id)
|
||||
for file_id in file_ids:
|
||||
try:
|
||||
compat_client.files.delete(file_id=file_id)
|
||||
except Exception:
|
||||
pass # File might already be deleted
|
||||
except Exception:
|
||||
pass # Best effort cleanup
|
||||
|
||||
|
||||
def test_response_file_search_filter_by_region(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with region equality filter."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {"type": "eq", "key": "region", "value": "us"},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What are the updates from the US region?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
# Verify file search was called with US filter
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should only return US files (not EU or Asia files)
|
||||
for result in response.output[0].results:
|
||||
assert "us" in result.text.lower() or "US" in result.text
|
||||
# Ensure non-US regions are NOT returned
|
||||
assert "european" not in result.text.lower()
|
||||
assert "asia" not in result.text.lower()
|
||||
|
||||
|
||||
def test_response_file_search_filter_by_category(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with category equality filter."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {"type": "eq", "key": "category", "value": "marketing"},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="Show me all marketing reports",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should only return marketing files (not engineering or sales)
|
||||
for result in response.output[0].results:
|
||||
# Marketing files should have promotional/advertising content
|
||||
assert "promotional" in result.text.lower() or "advertising" in result.text.lower()
|
||||
# Ensure non-marketing categories are NOT returned
|
||||
assert "technical" not in result.text.lower()
|
||||
assert "revenue figures" not in result.text.lower()
|
||||
|
||||
|
||||
def test_response_file_search_filter_by_date_range(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with date range filter using compound AND."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {
|
||||
"type": "and",
|
||||
"filters": [
|
||||
{
|
||||
"type": "gte",
|
||||
"key": "date",
|
||||
"value": 1672531200, # Jan 1, 2023
|
||||
},
|
||||
{
|
||||
"type": "lt",
|
||||
"key": "date",
|
||||
"value": 1680307200, # Apr 1, 2023
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What happened in Q1 2023?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should only return Q1 files (not Q2 or Q3)
|
||||
for result in response.output[0].results:
|
||||
assert "q1" in result.text.lower()
|
||||
# Ensure non-Q1 quarters are NOT returned
|
||||
assert "q2" not in result.text.lower()
|
||||
assert "q3" not in result.text.lower()
|
||||
|
||||
|
||||
def test_response_file_search_filter_compound_and(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with compound AND filter (region AND category)."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {
|
||||
"type": "and",
|
||||
"filters": [
|
||||
{"type": "eq", "key": "region", "value": "us"},
|
||||
{"type": "eq", "key": "category", "value": "engineering"},
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What are the engineering updates from the US?",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should only return US engineering files
|
||||
assert len(response.output[0].results) >= 1
|
||||
for result in response.output[0].results:
|
||||
assert "us" in result.text.lower() and "technical" in result.text.lower()
|
||||
# Ensure it's not from other regions or categories
|
||||
assert "european" not in result.text.lower() and "asia" not in result.text.lower()
|
||||
assert "promotional" not in result.text.lower() and "revenue" not in result.text.lower()
|
||||
|
||||
|
||||
def test_response_file_search_filter_compound_or(compat_client, text_model_id, vector_store_with_filtered_files):
|
||||
"""Test file search with compound OR filter (marketing OR sales)."""
|
||||
tools = [
|
||||
{
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [vector_store_with_filtered_files.id],
|
||||
"filters": {
|
||||
"type": "or",
|
||||
"filters": [
|
||||
{"type": "eq", "key": "category", "value": "marketing"},
|
||||
{"type": "eq", "key": "category", "value": "sales"},
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="Show me marketing and sales documents",
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].results
|
||||
# Should return marketing and sales files, but NOT engineering
|
||||
categories_found = set()
|
||||
for result in response.output[0].results:
|
||||
text_lower = result.text.lower()
|
||||
if "promotional" in text_lower or "advertising" in text_lower:
|
||||
categories_found.add("marketing")
|
||||
if "revenue figures" in text_lower:
|
||||
categories_found.add("sales")
|
||||
# Ensure engineering files are NOT returned
|
||||
assert "technical" not in text_lower, f"Engineering file should not be returned, but got: {result.text}"
|
||||
|
||||
# Verify we got at least one of the expected categories
|
||||
assert len(categories_found) > 0, "Should have found at least one marketing or sales file"
|
||||
assert categories_found.issubset({"marketing", "sales"}), f"Found unexpected categories: {categories_found}"
|
File diff suppressed because it is too large
Load diff
474
tests/integration/non_ci/responses/test_tool_responses.py
Normal file
474
tests/integration/non_ci/responses/test_tool_responses.py
Normal file
|
@ -0,0 +1,474 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from llama_stack import LlamaStackAsLibraryClient
|
||||
from llama_stack.core.datatypes import AuthenticationRequiredError
|
||||
from tests.common.mcp import dependency_tools, make_mcp_server
|
||||
|
||||
from .fixtures.test_cases import (
|
||||
custom_tool_test_cases,
|
||||
file_search_test_cases,
|
||||
mcp_tool_test_cases,
|
||||
multi_turn_tool_execution_streaming_test_cases,
|
||||
multi_turn_tool_execution_test_cases,
|
||||
web_search_test_cases,
|
||||
)
|
||||
from .helpers import new_vector_store, setup_mcp_tools, upload_file, wait_for_file_attachment
|
||||
from .streaming_assertions import StreamingValidator
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", web_search_test_cases)
|
||||
def test_response_non_streaming_web_search(compat_client, text_model_id, case):
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
tools=case.tools,
|
||||
stream=False,
|
||||
)
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "web_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[1].type == "message"
|
||||
assert response.output[1].status == "completed"
|
||||
assert response.output[1].role == "assistant"
|
||||
assert len(response.output[1].content) > 0
|
||||
assert case.expected.lower() in response.output_text.lower().strip()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", file_search_test_cases)
|
||||
def test_response_non_streaming_file_search(compat_client, text_model_id, tmp_path, case):
|
||||
if isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API file search is not yet supported in library client.")
|
||||
|
||||
vector_store = new_vector_store(compat_client, "test_vector_store")
|
||||
|
||||
if case.file_content:
|
||||
file_name = "test_response_non_streaming_file_search.txt"
|
||||
file_path = tmp_path / file_name
|
||||
file_path.write_text(case.file_content)
|
||||
elif case.file_path:
|
||||
file_path = os.path.join(os.path.dirname(__file__), "fixtures", case.file_path)
|
||||
file_name = os.path.basename(file_path)
|
||||
else:
|
||||
raise ValueError("No file content or path provided for case")
|
||||
|
||||
file_response = upload_file(compat_client, file_name, file_path)
|
||||
|
||||
# Attach our file to the vector store
|
||||
compat_client.vector_stores.files.create(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file_response.id,
|
||||
)
|
||||
|
||||
# Wait for the file to be attached
|
||||
wait_for_file_attachment(compat_client, vector_store.id, file_response.id)
|
||||
|
||||
# Update our tools with the right vector store id
|
||||
tools = case.tools
|
||||
for tool in tools:
|
||||
if tool["type"] == "file_search":
|
||||
tool["vector_store_ids"] = [vector_store.id]
|
||||
|
||||
# Create the response request, which should query our vector store
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
# Verify the file_search_tool was called
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].queries # ensure it's some non-empty list
|
||||
assert response.output[0].results
|
||||
assert case.expected.lower() in response.output[0].results[0].text.lower()
|
||||
assert response.output[0].results[0].score > 0
|
||||
|
||||
# Verify the output_text generated by the response
|
||||
assert case.expected.lower() in response.output_text.lower().strip()
|
||||
|
||||
|
||||
def test_response_non_streaming_file_search_empty_vector_store(compat_client, text_model_id):
|
||||
if isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API file search is not yet supported in library client.")
|
||||
|
||||
vector_store = new_vector_store(compat_client, "test_vector_store")
|
||||
|
||||
# Create the response request, which should query our vector store
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="How many experts does the Llama 4 Maverick model have?",
|
||||
tools=[{"type": "file_search", "vector_store_ids": [vector_store.id]}],
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
# Verify the file_search_tool was called
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].queries # ensure it's some non-empty list
|
||||
assert not response.output[0].results # ensure we don't get any results
|
||||
|
||||
# Verify some output_text was generated by the response
|
||||
assert response.output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", mcp_tool_test_cases)
|
||||
def test_response_non_streaming_mcp_tool(compat_client, text_model_id, case):
|
||||
if not isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("in-process MCP server is only supported in library client")
|
||||
|
||||
with make_mcp_server() as mcp_server_info:
|
||||
tools = setup_mcp_tools(case.tools, mcp_server_info)
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert len(response.output) >= 3
|
||||
list_tools = response.output[0]
|
||||
assert list_tools.type == "mcp_list_tools"
|
||||
assert list_tools.server_label == "localmcp"
|
||||
assert len(list_tools.tools) == 2
|
||||
assert {t.name for t in list_tools.tools} == {
|
||||
"get_boiling_point",
|
||||
"greet_everyone",
|
||||
}
|
||||
|
||||
call = response.output[1]
|
||||
assert call.type == "mcp_call"
|
||||
assert call.name == "get_boiling_point"
|
||||
assert json.loads(call.arguments) == {
|
||||
"liquid_name": "myawesomeliquid",
|
||||
"celsius": True,
|
||||
}
|
||||
assert call.error is None
|
||||
assert "-100" in call.output
|
||||
|
||||
# sometimes the model will call the tool again, so we need to get the last message
|
||||
message = response.output[-1]
|
||||
text_content = message.content[0].text
|
||||
assert "boiling point" in text_content.lower()
|
||||
|
||||
with make_mcp_server(required_auth_token="test-token") as mcp_server_info:
|
||||
tools = setup_mcp_tools(case.tools, mcp_server_info)
|
||||
|
||||
exc_type = (
|
||||
AuthenticationRequiredError
|
||||
if isinstance(compat_client, LlamaStackAsLibraryClient)
|
||||
else (httpx.HTTPStatusError, openai.AuthenticationError)
|
||||
)
|
||||
with pytest.raises(exc_type):
|
||||
compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
for tool in tools:
|
||||
if tool["type"] == "mcp":
|
||||
tool["headers"] = {"Authorization": "Bearer test-token"}
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
assert len(response.output) >= 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", mcp_tool_test_cases)
|
||||
def test_response_sequential_mcp_tool(compat_client, text_model_id, case):
|
||||
if not isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("in-process MCP server is only supported in library client")
|
||||
|
||||
with make_mcp_server() as mcp_server_info:
|
||||
tools = setup_mcp_tools(case.tools, mcp_server_info)
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert len(response.output) >= 3
|
||||
list_tools = response.output[0]
|
||||
assert list_tools.type == "mcp_list_tools"
|
||||
assert list_tools.server_label == "localmcp"
|
||||
assert len(list_tools.tools) == 2
|
||||
assert {t.name for t in list_tools.tools} == {
|
||||
"get_boiling_point",
|
||||
"greet_everyone",
|
||||
}
|
||||
|
||||
call = response.output[1]
|
||||
assert call.type == "mcp_call"
|
||||
assert call.name == "get_boiling_point"
|
||||
assert json.loads(call.arguments) == {
|
||||
"liquid_name": "myawesomeliquid",
|
||||
"celsius": True,
|
||||
}
|
||||
assert call.error is None
|
||||
assert "-100" in call.output
|
||||
|
||||
# sometimes the model will call the tool again, so we need to get the last message
|
||||
message = response.output[-1]
|
||||
text_content = message.content[0].text
|
||||
assert "boiling point" in text_content.lower()
|
||||
|
||||
response2 = compat_client.responses.create(
|
||||
model=text_model_id, input=case.input, tools=tools, stream=False, previous_response_id=response.id
|
||||
)
|
||||
|
||||
assert len(response2.output) >= 1
|
||||
message = response2.output[-1]
|
||||
text_content = message.content[0].text
|
||||
assert "boiling point" in text_content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", custom_tool_test_cases)
|
||||
def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
tools=case.tools,
|
||||
stream=False,
|
||||
)
|
||||
assert len(response.output) == 1
|
||||
assert response.output[0].type == "function_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].name == "get_weather"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", custom_tool_test_cases)
|
||||
def test_response_function_call_ordering_1(compat_client, text_model_id, case):
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=case.input,
|
||||
tools=case.tools,
|
||||
stream=False,
|
||||
)
|
||||
assert len(response.output) == 1
|
||||
assert response.output[0].type == "function_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].name == "get_weather"
|
||||
inputs = []
|
||||
inputs.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": case.input,
|
||||
}
|
||||
)
|
||||
inputs.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"output": "It is raining.",
|
||||
"call_id": response.output[0].call_id,
|
||||
}
|
||||
)
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id, input=inputs, tools=case.tools, stream=False, previous_response_id=response.id
|
||||
)
|
||||
assert len(response.output) == 1
|
||||
|
||||
|
||||
def test_response_function_call_ordering_2(compat_client, text_model_id):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for a given location.",
|
||||
"parameters": {
|
||||
"additionalProperties": False,
|
||||
"properties": {
|
||||
"location": {
|
||||
"description": "City and country e.g. Bogotá, Colombia",
|
||||
"type": "string",
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
"type": "object",
|
||||
},
|
||||
}
|
||||
]
|
||||
inputs = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Is the weather better in San Francisco or Los Angeles?",
|
||||
}
|
||||
]
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=inputs,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
for output in response.output:
|
||||
if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
|
||||
inputs.append(output)
|
||||
for output in response.output:
|
||||
if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
|
||||
weather = "It is raining."
|
||||
if "Los Angeles" in output.arguments:
|
||||
weather = "It is cloudy."
|
||||
inputs.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"output": weather,
|
||||
"call_id": output.call_id,
|
||||
}
|
||||
)
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=inputs,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
assert len(response.output) == 1
|
||||
assert "Los Angeles" in response.output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases)
|
||||
def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
|
||||
"""Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
|
||||
if not isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("in-process MCP server is only supported in library client")
|
||||
|
||||
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
|
||||
tools = setup_mcp_tools(case.tools, mcp_server_info)
|
||||
|
||||
response = compat_client.responses.create(
|
||||
input=case.input,
|
||||
model=text_model_id,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# Verify we have MCP tool calls in the output
|
||||
mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
|
||||
mcp_calls = [output for output in response.output if output.type == "mcp_call"]
|
||||
message_outputs = [output for output in response.output if output.type == "message"]
|
||||
|
||||
# Should have exactly 1 MCP list tools message (at the beginning)
|
||||
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
|
||||
assert mcp_list_tools[0].server_label == "localmcp"
|
||||
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
|
||||
expected_tool_names = {
|
||||
"get_user_id",
|
||||
"get_user_permissions",
|
||||
"check_file_access",
|
||||
"get_experiment_id",
|
||||
"get_experiment_results",
|
||||
}
|
||||
assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names
|
||||
|
||||
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
|
||||
for mcp_call in mcp_calls:
|
||||
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
|
||||
|
||||
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
|
||||
|
||||
final_message = message_outputs[-1]
|
||||
assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
|
||||
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
|
||||
assert len(final_message.content) > 0, "Final message should have content"
|
||||
|
||||
expected_output = case.expected
|
||||
assert expected_output.lower() in response.output_text.lower(), (
|
||||
f"Expected '{expected_output}' to appear in response: {response.output_text}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", multi_turn_tool_execution_streaming_test_cases)
|
||||
def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
|
||||
"""Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
|
||||
if not isinstance(compat_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("in-process MCP server is only supported in library client")
|
||||
|
||||
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
|
||||
tools = setup_mcp_tools(case.tools, mcp_server_info)
|
||||
|
||||
stream = compat_client.responses.create(
|
||||
input=case.input,
|
||||
model=text_model_id,
|
||||
tools=tools,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks = []
|
||||
for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
|
||||
# Use validator for common streaming checks
|
||||
validator = StreamingValidator(chunks)
|
||||
validator.assert_basic_event_sequence()
|
||||
validator.assert_response_consistency()
|
||||
validator.assert_has_tool_calls()
|
||||
validator.assert_has_mcp_events()
|
||||
validator.assert_rich_streaming()
|
||||
|
||||
# Get the final response from the last chunk
|
||||
final_chunk = chunks[-1]
|
||||
if hasattr(final_chunk, "response"):
|
||||
final_response = final_chunk.response
|
||||
|
||||
# Verify multi-turn MCP tool execution results
|
||||
mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
|
||||
mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
|
||||
message_outputs = [output for output in final_response.output if output.type == "message"]
|
||||
|
||||
# Should have exactly 1 MCP list tools message (at the beginning)
|
||||
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
|
||||
assert mcp_list_tools[0].server_label == "localmcp"
|
||||
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
|
||||
expected_tool_names = {
|
||||
"get_user_id",
|
||||
"get_user_permissions",
|
||||
"check_file_access",
|
||||
"get_experiment_id",
|
||||
"get_experiment_results",
|
||||
}
|
||||
assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names
|
||||
|
||||
# Should have at least 1 MCP call (the model should call at least one tool)
|
||||
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
|
||||
|
||||
# All MCP calls should be completed (verifies our tool execution works)
|
||||
for mcp_call in mcp_calls:
|
||||
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
|
||||
|
||||
# Should have at least one final message response
|
||||
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
|
||||
|
||||
# Final message should be from assistant and completed
|
||||
final_message = message_outputs[-1]
|
||||
assert final_message.role == "assistant", (
|
||||
f"Final message should be from assistant, got {final_message.role}"
|
||||
)
|
||||
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
|
||||
assert len(final_message.content) > 0, "Final message should have content"
|
||||
|
||||
# Check that the expected output appears in the response
|
||||
expected_output = case.expected
|
||||
assert expected_output.lower() in final_response.output_text.lower(), (
|
||||
f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
|
||||
)
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
|
@ -19,10 +18,10 @@ from llama_stack.apis.post_training import (
|
|||
LoraFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="post_training")
|
||||
|
||||
|
||||
skip_because_resource_intensive = pytest.mark.skip(
|
||||
|
|
Binary file not shown.
39
tests/integration/recordings/responses/390f0c7dac96.json
Normal file
39
tests/integration/recordings/responses/390f0c7dac96.json
Normal file
|
@ -0,0 +1,39 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/api/generate",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"raw": true,
|
||||
"prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nTest metrics generation 1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"options": {
|
||||
"temperature": 0.0
|
||||
},
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/api/generate",
|
||||
"model": "llama3.2:3b-instruct-fp16"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-11T15:51:18.170868Z",
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 5240614083,
|
||||
"load_duration": 9823416,
|
||||
"prompt_eval_count": 21,
|
||||
"prompt_eval_duration": 21000000,
|
||||
"eval_count": 310,
|
||||
"eval_duration": 5209000000,
|
||||
"response": "This is the start of a test. I'll provide some sample data and you can try to generate metrics based on it.\n\n**Data:**\n\nLet's say we have a dataset of user interactions with an e-commerce website. The data includes:\n\n| User ID | Product Name | Purchase Date | Quantity | Price |\n| --- | --- | --- | --- | --- |\n| 1 | iPhone 13 | 2022-01-01 | 2 | 999.99 |\n| 1 | MacBook Air | 2022-01-05 | 1 | 1299.99 |\n| 2 | Samsung TV | 2022-01-10 | 3 | 899.99 |\n| 3 | iPhone 13 | 2022-01-15 | 1 | 999.99 |\n| 4 | MacBook Pro | 2022-01-20 | 2 | 1799.99 |\n\n**Task:**\n\nYour task is to generate the following metrics based on this data:\n\n1. Average order value (AOV)\n2. Conversion rate\n3. Average revenue per user (ARPU)\n4. Customer lifetime value (CLV)\n\nPlease provide your answers in a format like this:\n\n| Metric | Value |\n| --- | --- |\n| AOV | 1234.56 |\n| Conversion Rate | 0.25 |\n| ARPU | 1000.00 |\n| CLV | 5000.00 |\n\nGo ahead and generate the metrics!",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
56
tests/integration/recordings/responses/3c0bf9ba81b2.json
Normal file
56
tests/integration/recordings/responses/3c0bf9ba81b2.json
Normal file
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Quick test"
|
||||
}
|
||||
],
|
||||
"max_tokens": 5
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b-instruct-fp16"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-651",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I'm ready to help",
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"annotations": null,
|
||||
"audio": null,
|
||||
"function_call": null,
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1755294941,
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"object": "chat.completion",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 5,
|
||||
"prompt_tokens": 27,
|
||||
"total_tokens": 32,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
56
tests/integration/recordings/responses/44a1d9de0602.json
Normal file
56
tests/integration/recordings/responses/44a1d9de0602.json
Normal file
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Say hello"
|
||||
}
|
||||
],
|
||||
"max_tokens": 20
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b-instruct-fp16"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-987",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "Hello! It's nice to meet you. Is there something I can help you with or would you",
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"annotations": null,
|
||||
"audio": null,
|
||||
"function_call": null,
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1755294921,
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"object": "chat.completion",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 20,
|
||||
"prompt_tokens": 27,
|
||||
"total_tokens": 47,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
|
@ -14,7 +14,7 @@
|
|||
"models": [
|
||||
{
|
||||
"model": "nomic-embed-text:latest",
|
||||
"modified_at": "2025-08-05T14:04:07.946926-07:00",
|
||||
"modified_at": "2025-08-18T12:47:56.732989-07:00",
|
||||
"digest": "0a109f422b47e3a30ba2b10eca18548e944e8a23073ee3f3e947efcf3c45e59f",
|
||||
"size": 274302450,
|
||||
"details": {
|
||||
|
|
56
tests/integration/recordings/responses/4de6877d86fa.json
Normal file
56
tests/integration/recordings/responses/4de6877d86fa.json
Normal file
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "OpenAI test 0"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-843",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I don't have any information about an \"OpenAI test 0\". It's possible that you may be referring to a specific experiment or task being performed by OpenAI, but without more context, I can only speculate.\n\nHowever, I can tell you that OpenAI is a research organization that has been involved in various projects and tests related to artificial intelligence. If you could provide more context or clarify what you're referring to, I may be able to help further.\n\nIf you're looking for general information about OpenAI, I can try to provide some background on the organization:\n\nOpenAI is a non-profit research organization that was founded in 2015 with the goal of developing and applying advanced artificial intelligence to benefit humanity. The organization has made significant contributions to the field of AI, including the development of the popular language model, ChatGPT.\n\nIf you could provide more context or clarify what you're looking for, I'll do my best to assist you.",
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"annotations": null,
|
||||
"audio": null,
|
||||
"function_call": null,
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1755891518,
|
||||
"model": "llama3.2:3b",
|
||||
"object": "chat.completion",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 194,
|
||||
"prompt_tokens": 30,
|
||||
"total_tokens": 224,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
|
@ -21,7 +21,7 @@
|
|||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-04T22:55:14.141947Z",
|
||||
"created_at": "2025-08-15T20:24:49.18651486Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
|
@ -39,7 +39,7 @@
|
|||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-04T22:55:14.194979Z",
|
||||
"created_at": "2025-08-15T20:24:49.370611348Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
|
@ -57,7 +57,7 @@
|
|||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-04T22:55:14.248312Z",
|
||||
"created_at": "2025-08-15T20:24:49.557000029Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
|
@ -75,7 +75,7 @@
|
|||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-04T22:55:14.301911Z",
|
||||
"created_at": "2025-08-15T20:24:49.746777116Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
|
@ -93,7 +93,7 @@
|
|||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-04T22:55:14.354437Z",
|
||||
"created_at": "2025-08-15T20:24:49.942233333Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
|
@ -111,7 +111,7 @@
|
|||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-04T22:55:14.406821Z",
|
||||
"created_at": "2025-08-15T20:24:50.126788846Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
|
@ -129,7 +129,7 @@
|
|||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-04T22:55:14.457633Z",
|
||||
"created_at": "2025-08-15T20:24:50.311346131Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
|
@ -147,7 +147,7 @@
|
|||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-04T22:55:14.507857Z",
|
||||
"created_at": "2025-08-15T20:24:50.501507173Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
|
@ -165,7 +165,7 @@
|
|||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-04T22:55:14.558847Z",
|
||||
"created_at": "2025-08-15T20:24:50.692296777Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
|
@ -183,7 +183,7 @@
|
|||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-04T22:55:14.609969Z",
|
||||
"created_at": "2025-08-15T20:24:50.878846539Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
|
@ -201,15 +201,15 @@
|
|||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-04T22:55:14.660997Z",
|
||||
"created_at": "2025-08-15T20:24:51.063200561Z",
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 715356542,
|
||||
"load_duration": 59747500,
|
||||
"total_duration": 33982453650,
|
||||
"load_duration": 2909001805,
|
||||
"prompt_eval_count": 341,
|
||||
"prompt_eval_duration": 128000000,
|
||||
"prompt_eval_duration": 29194357307,
|
||||
"eval_count": 11,
|
||||
"eval_duration": 526000000,
|
||||
"eval_duration": 1878247732,
|
||||
"response": "",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
|
|
56
tests/integration/recordings/responses/5db0c44c83a4.json
Normal file
56
tests/integration/recordings/responses/5db0c44c83a4.json
Normal file
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "OpenAI test 1"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-726",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I'm ready to help with the test. What language would you like to use? Would you like to have a conversation, ask questions, or take a specific type of task?",
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"annotations": null,
|
||||
"audio": null,
|
||||
"function_call": null,
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1755891519,
|
||||
"model": "llama3.2:3b",
|
||||
"object": "chat.completion",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 37,
|
||||
"prompt_tokens": 30,
|
||||
"total_tokens": 67,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
56
tests/integration/recordings/responses/6cb0285a7638.json
Normal file
56
tests/integration/recordings/responses/6cb0285a7638.json
Normal file
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "OpenAI test 4"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-581",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I'm ready to help. What would you like to test? We could try a variety of things, such as:\n\n1. Conversational dialogue\n2. Language understanding\n3. Common sense reasoning\n4. Joke or pun generation\n5. Trivia or knowledge-based questions\n6. Creative writing or storytelling\n7. Summarization or paraphrasing\n\nLet me know which area you'd like to test, or suggest something else that's on your mind!",
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"annotations": null,
|
||||
"audio": null,
|
||||
"function_call": null,
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1755891527,
|
||||
"model": "llama3.2:3b",
|
||||
"object": "chat.completion",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 96,
|
||||
"prompt_tokens": 30,
|
||||
"total_tokens": 126,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load diff
203
tests/integration/recordings/responses/731824c54461.json
Normal file
203
tests/integration/recordings/responses/731824c54461.json
Normal file
|
@ -0,0 +1,203 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/api/generate",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"raw": true,
|
||||
"prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive me a sentence that contains the word: hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"options": {
|
||||
"temperature": 0.0
|
||||
},
|
||||
"stream": true
|
||||
},
|
||||
"endpoint": "/api/generate",
|
||||
"model": "llama3.2:3b-instruct-fp16"
|
||||
},
|
||||
"response": {
|
||||
"body": [
|
||||
{
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-18T19:47:58.267146Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
"load_duration": null,
|
||||
"prompt_eval_count": null,
|
||||
"prompt_eval_duration": null,
|
||||
"eval_count": null,
|
||||
"eval_duration": null,
|
||||
"response": "Hello",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-18T19:47:58.309006Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
"load_duration": null,
|
||||
"prompt_eval_count": null,
|
||||
"prompt_eval_duration": null,
|
||||
"eval_count": null,
|
||||
"eval_duration": null,
|
||||
"response": ",",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-18T19:47:58.351179Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
"load_duration": null,
|
||||
"prompt_eval_count": null,
|
||||
"prompt_eval_duration": null,
|
||||
"eval_count": null,
|
||||
"eval_duration": null,
|
||||
"response": " how",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-18T19:47:58.393262Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
"load_duration": null,
|
||||
"prompt_eval_count": null,
|
||||
"prompt_eval_duration": null,
|
||||
"eval_count": null,
|
||||
"eval_duration": null,
|
||||
"response": " can",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-18T19:47:58.436079Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
"load_duration": null,
|
||||
"prompt_eval_count": null,
|
||||
"prompt_eval_duration": null,
|
||||
"eval_count": null,
|
||||
"eval_duration": null,
|
||||
"response": " I",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-18T19:47:58.478393Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
"load_duration": null,
|
||||
"prompt_eval_count": null,
|
||||
"prompt_eval_duration": null,
|
||||
"eval_count": null,
|
||||
"eval_duration": null,
|
||||
"response": " assist",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-18T19:47:58.520608Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
"load_duration": null,
|
||||
"prompt_eval_count": null,
|
||||
"prompt_eval_duration": null,
|
||||
"eval_count": null,
|
||||
"eval_duration": null,
|
||||
"response": " you",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-18T19:47:58.562885Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
"load_duration": null,
|
||||
"prompt_eval_count": null,
|
||||
"prompt_eval_duration": null,
|
||||
"eval_count": null,
|
||||
"eval_duration": null,
|
||||
"response": " today",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-18T19:47:58.604683Z",
|
||||
"done": false,
|
||||
"done_reason": null,
|
||||
"total_duration": null,
|
||||
"load_duration": null,
|
||||
"prompt_eval_count": null,
|
||||
"prompt_eval_duration": null,
|
||||
"eval_count": null,
|
||||
"eval_duration": null,
|
||||
"response": "?",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-18T19:47:58.646586Z",
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 1011323917,
|
||||
"load_duration": 76575458,
|
||||
"prompt_eval_count": 31,
|
||||
"prompt_eval_duration": 553259250,
|
||||
"eval_count": 10,
|
||||
"eval_duration": 380302792,
|
||||
"response": "",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"is_streaming": true
|
||||
}
|
||||
}
|
39
tests/integration/recordings/responses/7bcb0f86c91b.json
Normal file
39
tests/integration/recordings/responses/7bcb0f86c91b.json
Normal file
|
@ -0,0 +1,39 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/api/generate",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"raw": true,
|
||||
"prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nTest metrics generation 0<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"options": {
|
||||
"temperature": 0.0
|
||||
},
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/api/generate",
|
||||
"model": "llama3.2:3b-instruct-fp16"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"created_at": "2025-08-11T15:51:12.918723Z",
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 8868987792,
|
||||
"load_duration": 2793275292,
|
||||
"prompt_eval_count": 21,
|
||||
"prompt_eval_duration": 250000000,
|
||||
"eval_count": 344,
|
||||
"eval_duration": 5823000000,
|
||||
"response": "Here are some common test metrics used to evaluate the performance of a system:\n\n1. **Accuracy**: The proportion of correct predictions or classifications out of total predictions made.\n2. **Precision**: The ratio of true positives (correctly predicted instances) to the sum of true positives and false positives (incorrectly predicted instances).\n3. **Recall**: The ratio of true positives to the sum of true positives and false negatives (missed instances).\n4. **F1-score**: The harmonic mean of precision and recall, providing a balanced measure of both.\n5. **Mean Squared Error (MSE)**: The average squared difference between predicted and actual values.\n6. **Mean Absolute Error (MAE)**: The average absolute difference between predicted and actual values.\n7. **Root Mean Squared Percentage Error (RMSPE)**: The square root of the mean of the squared percentage differences between predicted and actual values.\n8. **Coefficient of Determination (R-squared, R2)**: Measures how well a model fits the data, with higher values indicating better fit.\n9. **Mean Absolute Percentage Error (MAPE)**: The average absolute percentage difference between predicted and actual values.\n10. **Normalized Mean Squared Error (NMSE)**: Similar to MSE, but normalized by the mean of the actual values.\n\nThese metrics can be used for various types of data, including:\n\n* Regression problems (e.g., predicting continuous values)\n* Classification problems (e.g., predicting categorical labels)\n* Time series forecasting\n* Clustering and dimensionality reduction\n\nWhen choosing a metric, consider the specific problem you're trying to solve, the type of data, and the desired level of precision.",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
56
tests/integration/recordings/responses/bf79a89cc37f.json
Normal file
56
tests/integration/recordings/responses/bf79a89cc37f.json
Normal file
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "OpenAI test 3"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-48",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I'm happy to help, but it seems you want me to engage in a basic conversation as OpenAI's new chat model, right? I can do that!\n\nHere's my response:\n\nHello! How are you today? Is there something specific on your mind that you'd like to talk about or any particular topic you'd like to explore together?\n\nWhat is it that you're curious about?",
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"annotations": null,
|
||||
"audio": null,
|
||||
"function_call": null,
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1755891524,
|
||||
"model": "llama3.2:3b",
|
||||
"object": "chat.completion",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 80,
|
||||
"prompt_tokens": 30,
|
||||
"total_tokens": 110,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
39
tests/integration/recordings/responses/c31a86ea6c58.json
Normal file
39
tests/integration/recordings/responses/c31a86ea6c58.json
Normal file
|
@ -0,0 +1,39 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/api/generate",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"raw": true,
|
||||
"prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nTest metrics generation 0<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"options": {
|
||||
"temperature": 0.0
|
||||
},
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/api/generate",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b",
|
||||
"created_at": "2025-08-11T15:56:06.703788Z",
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 2722294000,
|
||||
"load_duration": 9736083,
|
||||
"prompt_eval_count": 21,
|
||||
"prompt_eval_duration": 113000000,
|
||||
"eval_count": 324,
|
||||
"eval_duration": 2598000000,
|
||||
"response": "Here are some test metrics that can be used to evaluate the performance of a system:\n\n1. **Accuracy**: The proportion of correct predictions made by the model.\n2. **Precision**: The ratio of true positives (correctly predicted instances) to total positive predictions.\n3. **Recall**: The ratio of true positives to the sum of true positives and false negatives (missed instances).\n4. **F1-score**: The harmonic mean of precision and recall, providing a balanced measure of both.\n5. **Mean Squared Error (MSE)**: The average squared difference between predicted and actual values.\n6. **Mean Absolute Error (MAE)**: The average absolute difference between predicted and actual values.\n7. **Root Mean Squared Percentage Error (RMSPE)**: A variation of MSE that expresses the error as a percentage.\n8. **Coefficient of Determination (R-squared, R2)**: Measures how well the model explains the variance in the data.\n9. **Mean Absolute Percentage Error (MAPE)**: The average absolute percentage difference between predicted and actual values.\n10. **Mean Squared Logarithmic Error (MSLE)**: A variation of MSE that is more suitable for skewed distributions.\n\nThese metrics can be used to evaluate different aspects of a system's performance, such as:\n\n* Classification models: accuracy, precision, recall, F1-score\n* Regression models: MSE, MAE, RMSPE, R2, MSLE\n* Time series forecasting: MAPE, RMSPE\n\nNote that the choice of metric depends on the specific problem and data.",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
|
@ -13,12 +13,12 @@
|
|||
"__data__": {
|
||||
"models": [
|
||||
{
|
||||
"model": "llama3.2:3b",
|
||||
"name": "llama3.2:3b",
|
||||
"digest": "a80c4f17acd55265feec403c7aef86be0c25983ab279d83f3bcd3abbcb5b8b72",
|
||||
"expires_at": "2025-08-06T15:57:21.573326-04:00",
|
||||
"size": 4030033920,
|
||||
"size_vram": 4030033920,
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"name": "llama3.2:3b-instruct-fp16",
|
||||
"digest": "195a8c01d91ec3cb1e0aad4624a51f2602c51fa7d96110f8ab5a20c84081804d",
|
||||
"expires_at": "2025-08-18T13:47:44.262256-07:00",
|
||||
"size": 7919570944,
|
||||
"size_vram": 7919570944,
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
|
@ -27,7 +27,7 @@
|
|||
"llama"
|
||||
],
|
||||
"parameter_size": "3.2B",
|
||||
"quantization_level": "Q4_K_M"
|
||||
"quantization_level": "F16"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
56
tests/integration/recordings/responses/dc8120cf0774.json
Normal file
56
tests/integration/recordings/responses/dc8120cf0774.json
Normal file
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "OpenAI test 2"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-516",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I'm happy to help with your question or task. Please go ahead and ask me anything, and I'll do my best to assist you.\n\nNote: I'll be using the latest version of my knowledge cutoff, which is December 2023.\n\nAlso, please keep in mind that I'm a large language model, I can provide information on a broad range of topics, including science, history, technology, culture, and more. However, my ability to understand and respond to specific questions or requests may be limited by the data I've been trained on.",
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"annotations": null,
|
||||
"audio": null,
|
||||
"function_call": null,
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1755891522,
|
||||
"model": "llama3.2:3b",
|
||||
"object": "chat.completion",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": {
|
||||
"completion_tokens": 113,
|
||||
"prompt_tokens": 30,
|
||||
"total_tokens": 143,
|
||||
"completion_tokens_details": null,
|
||||
"prompt_tokens_details": null
|
||||
}
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
109
tests/integration/recordings/responses/decfd950646c.json
Normal file
109
tests/integration/recordings/responses/decfd950646c.json
Normal file
|
@ -0,0 +1,109 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/v1/v1/chat/completions",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather in Tokyo? YOU MUST USE THE get_weather function to get the weather."
|
||||
}
|
||||
],
|
||||
"response_format": {
|
||||
"type": "text"
|
||||
},
|
||||
"stream": true,
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather in a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to get the weather for"
|
||||
}
|
||||
}
|
||||
},
|
||||
"strict": null
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"model": "llama3.2:3b-instruct-fp16"
|
||||
},
|
||||
"response": {
|
||||
"body": [
|
||||
{
|
||||
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-620",
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": "",
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": "call_490d5ur7",
|
||||
"function": {
|
||||
"arguments": "{\"city\":\"Tokyo\"}",
|
||||
"name": "get_weather"
|
||||
},
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1755228972,
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
|
||||
"__data__": {
|
||||
"id": "chatcmpl-620",
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": "",
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1755228972,
|
||||
"model": "llama3.2:3b-instruct-fp16",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "fp_ollama",
|
||||
"usage": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"is_streaming": true
|
||||
}
|
||||
}
|
39
tests/integration/recordings/responses/f6857bcea729.json
Normal file
39
tests/integration/recordings/responses/f6857bcea729.json
Normal file
|
@ -0,0 +1,39 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/api/generate",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"raw": true,
|
||||
"prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nTest metrics generation 2<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"options": {
|
||||
"temperature": 0.0
|
||||
},
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/api/generate",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b",
|
||||
"created_at": "2025-08-11T15:56:13.082679Z",
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 2606245291,
|
||||
"load_duration": 9979708,
|
||||
"prompt_eval_count": 21,
|
||||
"prompt_eval_duration": 23000000,
|
||||
"eval_count": 321,
|
||||
"eval_duration": 2572000000,
|
||||
"response": "Here are some test metrics that can be used to evaluate the performance of a system:\n\n1. **Accuracy**: Measures how close the predicted values are to the actual values.\n2. **Precision**: Measures the proportion of true positives among all positive predictions made by the model.\n3. **Recall**: Measures the proportion of true positives among all actual positive instances.\n4. **F1-score**: The harmonic mean of precision and recall, providing a balanced measure of both.\n5. **Mean Squared Error (MSE)**: Measures the average squared difference between predicted and actual values.\n6. **Mean Absolute Error (MAE)**: Measures the average absolute difference between predicted and actual values.\n7. **Root Mean Squared Percentage Error (RMSPE)**: A variation of MSE that expresses errors as a percentage of the actual value.\n8. **Coefficient of Determination (R-squared, R2)**: Measures how well the model explains the variance in the data.\n9. **Mean Absolute Percentage Error (MAPE)**: Measures the average absolute percentage difference between predicted and actual values.\n10. **Mean Squared Logarithmic Error (MSLE)**: A variation of MSE that is more suitable for skewed distributions.\n\nThese metrics can be used to evaluate different aspects of a system's performance, such as:\n\n* Classification models: accuracy, precision, recall, F1-score\n* Regression models: MSE, MAE, RMSPE, R2\n* Time series forecasting: MAPE, MSLE\n\nNote that the choice of metric depends on the specific problem and data.",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
39
tests/integration/recordings/responses/f80b99430f7e.json
Normal file
39
tests/integration/recordings/responses/f80b99430f7e.json
Normal file
|
@ -0,0 +1,39 @@
|
|||
{
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"url": "http://localhost:11434/api/generate",
|
||||
"headers": {},
|
||||
"body": {
|
||||
"model": "llama3.2:3b",
|
||||
"raw": true,
|
||||
"prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nTest metrics generation 1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"options": {
|
||||
"temperature": 0.0
|
||||
},
|
||||
"stream": false
|
||||
},
|
||||
"endpoint": "/api/generate",
|
||||
"model": "llama3.2:3b"
|
||||
},
|
||||
"response": {
|
||||
"body": {
|
||||
"__type__": "ollama._types.GenerateResponse",
|
||||
"__data__": {
|
||||
"model": "llama3.2:3b",
|
||||
"created_at": "2025-08-11T15:56:10.465932Z",
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 3745686709,
|
||||
"load_duration": 9734584,
|
||||
"prompt_eval_count": 21,
|
||||
"prompt_eval_duration": 23000000,
|
||||
"eval_count": 457,
|
||||
"eval_duration": 3712000000,
|
||||
"response": "Here are some test metrics that can be used to evaluate the performance of a system:\n\n**Primary Metrics**\n\n1. **Response Time**: The time it takes for the system to respond to a request.\n2. **Throughput**: The number of requests processed by the system per unit time (e.g., requests per second).\n3. **Error Rate**: The percentage of requests that result in an error.\n\n**Secondary Metrics**\n\n1. **Average Response Time**: The average response time for all requests.\n2. **Median Response Time**: The middle value of the response times, used to detect outliers.\n3. **99th Percentile Response Time**: The response time at which 99% of requests are completed within this time.\n4. **Request Latency**: The difference between the request arrival time and the response time.\n\n**User Experience Metrics**\n\n1. **User Satisfaction (USAT)**: Measured through surveys or feedback forms to gauge user satisfaction with the system's performance.\n2. **First Response Time**: The time it takes for a user to receive their first response from the system.\n3. **Time Spent in System**: The total amount of time a user spends interacting with the system.\n\n**System Resource Metrics**\n\n1. **CPU Utilization**: The percentage of CPU resources being used by the system.\n2. **Memory Usage**: The amount of memory being used by the system.\n3. **Disk I/O Wait Time**: The average time spent waiting for disk I/O operations to complete.\n\n**Security Metrics**\n\n1. **Authentication Success Rate**: The percentage of successful authentication attempts.\n2. **Authorization Success Rate**: The percentage of successful authorization attempts.\n3. **Error Rate (Security)**: The percentage of security-related errors.\n\n**Other Metrics**\n\n1. **Page Load Time**: The time it takes for a page to load.\n2. **Click-Through Rate (CTR)**: The percentage of users who click on a link or button after seeing an ad or notification.\n3. **Conversion Rate**: The percentage of users who complete a desired action (e.g., fill out a form, make a purchase).\n\nThese metrics can be used to evaluate the performance and effectiveness of various aspects of your system, from user experience to security and resource utilization.",
|
||||
"thinking": null,
|
||||
"context": null
|
||||
}
|
||||
},
|
||||
"is_streaming": false
|
||||
}
|
||||
}
|
|
@ -27,10 +27,10 @@ def data_url_from_image(file_path):
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def code_scanner_shield_id(available_shields):
|
||||
if "CodeScanner" in available_shields:
|
||||
return "CodeScanner"
|
||||
if "code-scanner" in available_shields:
|
||||
return "code-scanner"
|
||||
|
||||
pytest.skip("CodeScanner shield is not available. Skipping.")
|
||||
pytest.skip("code-scanner shield is not available. Skipping.")
|
||||
|
||||
|
||||
def test_unsafe_examples(client_with_models, shield_id):
|
||||
|
@ -163,6 +163,60 @@ def test_safety_with_code_scanner(client_with_models, code_scanner_shield_id, mo
|
|||
assert response.violation.user_message == "Sorry, I found security concerns in the code."
|
||||
|
||||
|
||||
def test_unsafe_input_with_code_scanner_run_moderation(client_with_models, code_scanner_shield_id, model_providers):
|
||||
code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0
|
||||
if not code_safety_enabled:
|
||||
pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")
|
||||
|
||||
insecure_code = [
|
||||
"""
|
||||
Here is an example of code that hashes a password using MD5:
|
||||
|
||||
```
|
||||
import hashlib
|
||||
password = 'securepassword123'.encode()
|
||||
hashed_password = hashlib.md5(password).hexdigest()
|
||||
print("Insecure hashed password:", hashed_password)
|
||||
```
|
||||
"""
|
||||
]
|
||||
moderation_object = client_with_models.moderations.create(
|
||||
input=insecure_code,
|
||||
model=code_scanner_shield_id,
|
||||
)
|
||||
assert moderation_object.results[0].flagged is True, f"Code scanner should have flagged {insecure_code} as insecure"
|
||||
assert all(value is True for value in moderation_object.results[0].categories.values()), (
|
||||
"Code scanner shield should have detected code insecure category"
|
||||
)
|
||||
|
||||
|
||||
def test_safe_input_with_code_scanner_run_moderation(client_with_models, code_scanner_shield_id, model_providers):
|
||||
code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0
|
||||
if not code_safety_enabled:
|
||||
pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")
|
||||
|
||||
secure_code = [
|
||||
"""
|
||||
Extract the first 5 characters from a string:
|
||||
```
|
||||
text = "Hello World"
|
||||
first_five = text[:5]
|
||||
print(first_five) # Output: "Hello"
|
||||
|
||||
# Safe handling for strings shorter than 5 characters
|
||||
def get_first_five(text):
|
||||
return text[:5] if text else ""
|
||||
```
|
||||
"""
|
||||
]
|
||||
moderation_object = client_with_models.moderations.create(
|
||||
input=secure_code,
|
||||
model=code_scanner_shield_id,
|
||||
)
|
||||
|
||||
assert moderation_object.results[0].flagged is False, "Code scanner should not have flagged the code as insecure"
|
||||
|
||||
|
||||
# We can use an instance of the LlamaGuard shield to detect attempts to misuse
|
||||
# the interpreter as this is one of the existing categories it checks for
|
||||
def test_safety_with_code_interpreter_abuse(client_with_models, shield_id):
|
||||
|
|
209
tests/integration/telemetry/test_telemetry_metrics.py
Normal file
209
tests/integration/telemetry/test_telemetry_metrics.py
Normal file
|
@ -0,0 +1,209 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup_telemetry_metrics_data(openai_client, client_with_models, text_model_id):
|
||||
"""Setup fixture that creates telemetry metrics data before tests run."""
|
||||
|
||||
# Skip OpenAI tests if running in library mode
|
||||
if not hasattr(client_with_models, "base_url"):
|
||||
pytest.skip("OpenAI client tests not supported with library client")
|
||||
|
||||
prompt_tokens = []
|
||||
completion_tokens = []
|
||||
total_tokens = []
|
||||
|
||||
# Create OpenAI completions to generate metrics using the proper OpenAI client
|
||||
for i in range(5):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=text_model_id,
|
||||
messages=[{"role": "user", "content": f"OpenAI test {i}"}],
|
||||
stream=False,
|
||||
)
|
||||
prompt_tokens.append(response.usage.prompt_tokens)
|
||||
completion_tokens.append(response.usage.completion_tokens)
|
||||
total_tokens.append(response.usage.total_tokens)
|
||||
|
||||
# Wait for metrics to be logged
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < 30:
|
||||
try:
|
||||
# Try to query metrics to see if they're available
|
||||
metrics_response = client_with_models.telemetry.query_metrics(
|
||||
metric_name="completion_tokens",
|
||||
start_time=int((datetime.now(UTC) - timedelta(minutes=5)).timestamp()),
|
||||
)
|
||||
if len(metrics_response[0].values) > 0:
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(1)
|
||||
|
||||
# Wait additional time to ensure all metrics are processed
|
||||
time.sleep(5)
|
||||
|
||||
# Return the token lists for use in tests
|
||||
return {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_prompt_tokens(client_with_models, text_model_id, setup_telemetry_metrics_data):
|
||||
"""Test that prompt_tokens metrics are queryable."""
|
||||
start_time = int((datetime.now(UTC) - timedelta(minutes=10)).timestamp())
|
||||
|
||||
response = client_with_models.telemetry.query_metrics(
|
||||
metric_name="prompt_tokens",
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
|
||||
assert isinstance(response[0].values, list), "Should return a list of metric series"
|
||||
|
||||
assert response[0].metric == "prompt_tokens"
|
||||
|
||||
# Use the actual values from setup instead of hardcoded values
|
||||
expected_values = setup_telemetry_metrics_data["prompt_tokens"]
|
||||
assert response[0].values[-1].value in expected_values, (
|
||||
f"Expected one of {expected_values}, got {response[0].values[-1].value}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_completion_tokens(client_with_models, text_model_id, setup_telemetry_metrics_data):
|
||||
"""Test that completion_tokens metrics are queryable."""
|
||||
start_time = int((datetime.now(UTC) - timedelta(minutes=10)).timestamp())
|
||||
|
||||
response = client_with_models.telemetry.query_metrics(
|
||||
metric_name="completion_tokens",
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
|
||||
assert isinstance(response[0].values, list), "Should return a list of metric series"
|
||||
|
||||
assert response[0].metric == "completion_tokens"
|
||||
|
||||
# Use the actual values from setup instead of hardcoded values
|
||||
expected_values = setup_telemetry_metrics_data["completion_tokens"]
|
||||
assert response[0].values[-1].value in expected_values, (
|
||||
f"Expected one of {expected_values}, got {response[0].values[-1].value}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_total_tokens(client_with_models, text_model_id, setup_telemetry_metrics_data):
|
||||
"""Test that total_tokens metrics are queryable."""
|
||||
start_time = int((datetime.now(UTC) - timedelta(minutes=10)).timestamp())
|
||||
|
||||
response = client_with_models.telemetry.query_metrics(
|
||||
metric_name="total_tokens",
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
|
||||
assert isinstance(response[0].values, list), "Should return a list of metric series"
|
||||
|
||||
assert response[0].metric == "total_tokens"
|
||||
|
||||
# Use the actual values from setup instead of hardcoded values
|
||||
expected_values = setup_telemetry_metrics_data["total_tokens"]
|
||||
assert response[0].values[-1].value in expected_values, (
|
||||
f"Expected one of {expected_values}, got {response[0].values[-1].value}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_with_time_range(llama_stack_client, text_model_id):
|
||||
"""Test that metrics are queryable with time range."""
|
||||
end_time = int(datetime.now(UTC).timestamp())
|
||||
start_time = end_time - 600 # 10 minutes ago
|
||||
|
||||
response = llama_stack_client.telemetry.query_metrics(
|
||||
metric_name="prompt_tokens",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
|
||||
assert isinstance(response[0].values, list), "Should return a list of metric series"
|
||||
|
||||
assert response[0].metric == "prompt_tokens"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_with_label_matchers(llama_stack_client, text_model_id):
|
||||
"""Test that metrics are queryable with label matchers."""
|
||||
start_time = int((datetime.now(UTC) - timedelta(minutes=10)).timestamp())
|
||||
|
||||
response = llama_stack_client.telemetry.query_metrics(
|
||||
metric_name="prompt_tokens",
|
||||
start_time=start_time,
|
||||
label_matchers=[{"name": "model_id", "value": text_model_id, "operator": "="}],
|
||||
)
|
||||
|
||||
assert isinstance(response[0].values, list), "Should return a list of metric series"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_nonexistent_metric(llama_stack_client):
|
||||
"""Test that querying a nonexistent metric returns empty data."""
|
||||
start_time = int((datetime.now(UTC) - timedelta(minutes=10)).timestamp())
|
||||
|
||||
response = llama_stack_client.telemetry.query_metrics(
|
||||
metric_name="nonexistent_metric",
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
assert isinstance(response, list), "Should return an empty list for nonexistent metric"
|
||||
assert len(response) == 0
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping this test until client is regenerated")
|
||||
def test_query_metrics_with_granularity(llama_stack_client, text_model_id):
|
||||
"""Test that metrics are queryable with different granularity levels."""
|
||||
start_time = int((datetime.now(UTC) - timedelta(minutes=10)).timestamp())
|
||||
|
||||
# Test hourly granularity
|
||||
hourly_response = llama_stack_client.telemetry.query_metrics(
|
||||
metric_name="total_tokens",
|
||||
start_time=start_time,
|
||||
granularity="1h",
|
||||
)
|
||||
|
||||
# Test daily granularity
|
||||
daily_response = llama_stack_client.telemetry.query_metrics(
|
||||
metric_name="total_tokens",
|
||||
start_time=start_time,
|
||||
granularity="1d",
|
||||
)
|
||||
|
||||
# Test no granularity (raw data points)
|
||||
raw_response = llama_stack_client.telemetry.query_metrics(
|
||||
metric_name="total_tokens",
|
||||
start_time=start_time,
|
||||
granularity=None,
|
||||
)
|
||||
|
||||
# All should return valid data
|
||||
assert isinstance(hourly_response[0].values, list), "Hourly granularity should return data"
|
||||
assert isinstance(daily_response[0].values, list), "Daily granularity should return data"
|
||||
assert isinstance(raw_response[0].values, list), "No granularity should return data"
|
||||
|
||||
# Verify that different granularities produce different aggregation levels
|
||||
# (The exact number depends on data distribution, but they should be queryable)
|
||||
assert len(hourly_response[0].values) >= 0, "Hourly granularity should be queryable"
|
||||
assert len(daily_response[0].values) >= 0, "Daily granularity should be queryable"
|
||||
assert len(raw_response[0].values) >= 0, "No granularity should be queryable"
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
|
@ -15,8 +14,9 @@ from openai import BadRequestError as OpenAIBadRequestError
|
|||
|
||||
from llama_stack.apis.vector_io import Chunk
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="vector_io")
|
||||
|
||||
|
||||
def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
|
||||
|
@ -57,6 +57,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
|
|||
"keyword": [
|
||||
"inline::sqlite-vec",
|
||||
"remote::milvus",
|
||||
"inline::milvus",
|
||||
],
|
||||
"hybrid": [
|
||||
"inline::sqlite-vec",
|
||||
|
|
|
@ -5,86 +5,121 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Unit tests for LlamaStackAsLibraryClient initialization error handling.
|
||||
Unit tests for LlamaStackAsLibraryClient automatic initialization.
|
||||
|
||||
These tests ensure that users get proper error messages when they forget to call
|
||||
initialize() on the library client, preventing AttributeError regressions.
|
||||
These tests ensure that the library client is automatically initialized
|
||||
and ready to use immediately after construction.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.library_client import (
|
||||
AsyncLlamaStackAsLibraryClient,
|
||||
LlamaStackAsLibraryClient,
|
||||
)
|
||||
from llama_stack.core.server.routes import RouteImpls
|
||||
|
||||
|
||||
class TestLlamaStackAsLibraryClientInitialization:
|
||||
"""Test proper error handling for uninitialized library clients."""
|
||||
class TestLlamaStackAsLibraryClientAutoInitialization:
|
||||
"""Test automatic initialization of library clients."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_call",
|
||||
[
|
||||
lambda client: client.models.list(),
|
||||
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
||||
lambda client: next(
|
||||
client.chat.completions.create(
|
||||
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
||||
)
|
||||
),
|
||||
],
|
||||
ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
|
||||
)
|
||||
def test_sync_client_proper_error_without_initialization(self, api_call):
|
||||
"""Test that sync client raises ValueError with helpful message when not initialized."""
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
def test_sync_client_auto_initialization(self, monkeypatch):
|
||||
"""Test that sync client is automatically initialized after construction."""
|
||||
# Mock the stack construction to avoid dependency issues
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
api_call(client)
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Client not initialized" in error_msg
|
||||
assert "Please call initialize() first" in error_msg
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_call",
|
||||
[
|
||||
lambda client: client.models.list(),
|
||||
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
||||
],
|
||||
ids=["models.list", "chat.completions.create"],
|
||||
)
|
||||
async def test_async_client_proper_error_without_initialization(self, api_call):
|
||||
"""Test that async client raises ValueError with helpful message when not initialized."""
|
||||
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await api_call(client)
|
||||
client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Client not initialized" in error_msg
|
||||
assert "Please call initialize() first" in error_msg
|
||||
assert client.async_client.route_impls is not None
|
||||
|
||||
async def test_async_client_streaming_error_without_initialization(self):
|
||||
"""Test that async client streaming raises ValueError with helpful message when not initialized."""
|
||||
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||
async def test_async_client_auto_initialization(self, monkeypatch):
|
||||
"""Test that async client can be initialized and works properly."""
|
||||
# Mock the stack construction to avoid dependency issues
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
stream = await client.chat.completions.create(
|
||||
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
||||
)
|
||||
await anext(stream)
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Client not initialized" in error_msg
|
||||
assert "Please call initialize() first" in error_msg
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
def test_route_impls_initialized_to_none(self):
|
||||
"""Test that route_impls is initialized to None to prevent AttributeError."""
|
||||
# Test sync client
|
||||
sync_client = LlamaStackAsLibraryClient("nvidia")
|
||||
assert sync_client.async_client.route_impls is None
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
# Test async client directly
|
||||
async_client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||
assert async_client.route_impls is None
|
||||
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||
|
||||
# Initialize the client
|
||||
result = await client.initialize()
|
||||
assert result is True
|
||||
assert client.route_impls is not None
|
||||
|
||||
def test_initialize_method_backward_compatibility(self, monkeypatch):
|
||||
"""Test that initialize() method still works for backward compatibility."""
|
||||
# Mock the stack construction to avoid dependency issues
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
||||
result = client.initialize()
|
||||
assert result is None
|
||||
|
||||
result2 = client.initialize()
|
||||
assert result2 is None
|
||||
|
||||
async def test_async_initialize_method_idempotent(self, monkeypatch):
|
||||
"""Test that async initialize() method can be called multiple times safely."""
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||
|
||||
result1 = await client.initialize()
|
||||
assert result1 is True
|
||||
|
||||
result2 = await client.initialize()
|
||||
assert result2 is True
|
||||
|
||||
def test_route_impls_automatically_set(self, monkeypatch):
|
||||
"""Test that route_impls is automatically set during construction."""
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
sync_client = LlamaStackAsLibraryClient("ci-tests")
|
||||
assert sync_client.async_client.route_impls is not None
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
|
@ -190,7 +191,7 @@ class TestOpenAIFilesAPI:
|
|||
|
||||
async def test_retrieve_file_not_found(self, files_provider):
|
||||
"""Test retrieving a non-existent file."""
|
||||
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await files_provider.openai_retrieve_file("file-nonexistent")
|
||||
|
||||
async def test_retrieve_file_content_success(self, files_provider, sample_text_file):
|
||||
|
@ -208,7 +209,7 @@ class TestOpenAIFilesAPI:
|
|||
|
||||
async def test_retrieve_file_content_not_found(self, files_provider):
|
||||
"""Test retrieving content of a non-existent file."""
|
||||
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await files_provider.openai_retrieve_file_content("file-nonexistent")
|
||||
|
||||
async def test_delete_file_success(self, files_provider, sample_text_file):
|
||||
|
@ -229,12 +230,12 @@ class TestOpenAIFilesAPI:
|
|||
assert delete_response.deleted is True
|
||||
|
||||
# Verify file no longer exists
|
||||
with pytest.raises(ValueError, match=f"File with id {uploaded_file.id} not found"):
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await files_provider.openai_retrieve_file(uploaded_file.id)
|
||||
|
||||
async def test_delete_file_not_found(self, files_provider):
|
||||
"""Test deleting a non-existent file."""
|
||||
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await files_provider.openai_delete_file("file-nonexistent")
|
||||
|
||||
async def test_file_persistence_across_operations(self, files_provider, sample_text_file):
|
||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseMessage,
|
||||
OpenAIResponseObjectWithInput,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
|
@ -41,7 +42,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
@ -136,9 +137,12 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
input=input_text,
|
||||
model=model,
|
||||
temperature=0.1,
|
||||
stream=True, # Enable streaming to test content part events
|
||||
)
|
||||
|
||||
# Verify
|
||||
# For streaming response, collect all chunks
|
||||
chunks = [chunk async for chunk in result]
|
||||
|
||||
mock_inference_api.openai_chat_completion.assert_called_once_with(
|
||||
model=model,
|
||||
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
|
||||
|
@ -147,11 +151,32 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
stream=True,
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
# Should have content part events for text streaming
|
||||
# Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
|
||||
assert len(chunks) >= 4
|
||||
assert chunks[0].type == "response.created"
|
||||
|
||||
# Check for content part events
|
||||
content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"]
|
||||
content_part_done_events = [c for c in chunks if c.type == "response.content_part.done"]
|
||||
text_delta_events = [c for c in chunks if c.type == "response.output_text.delta"]
|
||||
|
||||
assert len(content_part_added_events) >= 1, "Should have content_part.added event for text"
|
||||
assert len(content_part_done_events) >= 1, "Should have content_part.done event for text"
|
||||
assert len(text_delta_events) >= 1, "Should have text delta events"
|
||||
|
||||
# Verify final event is completion
|
||||
assert chunks[-1].type == "response.completed"
|
||||
|
||||
# When streaming, the final response is in the last chunk
|
||||
final_response = chunks[-1].response
|
||||
assert final_response.model == model
|
||||
assert len(final_response.output) == 1
|
||||
assert isinstance(final_response.output[0], OpenAIResponseMessage)
|
||||
|
||||
openai_responses_impl.responses_store.store_response_object.assert_called_once()
|
||||
assert result.model == model
|
||||
assert len(result.output) == 1
|
||||
assert isinstance(result.output[0], OpenAIResponseMessage)
|
||||
assert result.output[0].content[0].text == "Dublin"
|
||||
assert final_response.output[0].content[0].text == "Dublin"
|
||||
|
||||
|
||||
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
|
||||
|
@ -272,6 +297,8 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
|
|||
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
chunks = [chunk async for chunk in result]
|
||||
|
||||
# Verify event types
|
||||
# Should have: response.created, output_item.added, function_call_arguments.delta,
|
||||
# function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 6
|
||||
|
@ -435,6 +462,53 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
|
|||
assert input[3].content == "fake_input"
|
||||
|
||||
|
||||
async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mock_responses_store):
|
||||
"""Test prepending a previous response which included an mcp tool call to a new response."""
|
||||
input_item_message = OpenAIResponseMessage(
|
||||
id="123",
|
||||
content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")],
|
||||
role="user",
|
||||
)
|
||||
output_tool_call = OpenAIResponseOutputMessageMCPCall(
|
||||
id="ws_123",
|
||||
name="fake-tool",
|
||||
arguments="fake-arguments",
|
||||
server_label="fake-label",
|
||||
)
|
||||
output_message = OpenAIResponseMessage(
|
||||
id="123",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text="fake_tool_call_response")],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
response = OpenAIResponseObjectWithInput(
|
||||
created_at=1,
|
||||
id="resp_123",
|
||||
model="fake_model",
|
||||
output=[output_tool_call, output_message],
|
||||
status="completed",
|
||||
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
|
||||
input=[input_item_message],
|
||||
)
|
||||
mock_responses_store.get_response_object.return_value = response
|
||||
|
||||
input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
|
||||
input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
|
||||
|
||||
assert len(input) == 4
|
||||
# Check for previous input
|
||||
assert isinstance(input[0], OpenAIResponseMessage)
|
||||
assert input[0].content[0].text == "fake_previous_input"
|
||||
# Check for previous output MCP tool call
|
||||
assert isinstance(input[1], OpenAIResponseOutputMessageMCPCall)
|
||||
# Check for previous output web search response
|
||||
assert isinstance(input[2], OpenAIResponseMessage)
|
||||
assert input[2].content[0].text == "fake_tool_call_response"
|
||||
# Check for new input
|
||||
assert isinstance(input[3], OpenAIResponseMessage)
|
||||
assert input[3].content == "fake_input"
|
||||
|
||||
|
||||
async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api):
|
||||
# Setup
|
||||
input_text = "What is the capital of Ireland?"
|
||||
|
|
|
@ -0,0 +1,342 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
convert_chat_choice_to_response_message,
|
||||
convert_response_content_to_chat_content,
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
get_message_type_by_role,
|
||||
is_function_tool_call,
|
||||
)
|
||||
|
||||
|
||||
class TestConvertChatChoiceToResponseMessage:
|
||||
async def test_convert_string_content(self):
|
||||
choice = OpenAIChoice(
|
||||
message=OpenAIAssistantMessageParam(content="Test message"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
result = await convert_chat_choice_to_response_message(choice)
|
||||
|
||||
assert result.role == "assistant"
|
||||
assert result.status == "completed"
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], OpenAIResponseOutputMessageContentOutputText)
|
||||
assert result.content[0].text == "Test message"
|
||||
|
||||
async def test_convert_text_param_content(self):
|
||||
choice = OpenAIChoice(
|
||||
message=OpenAIAssistantMessageParam(
|
||||
content=[OpenAIChatCompletionContentPartTextParam(text="Test text param")]
|
||||
),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await convert_chat_choice_to_response_message(choice)
|
||||
|
||||
assert "does not yet support output content type" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestConvertResponseContentToChatContent:
|
||||
async def test_convert_string_content(self):
|
||||
result = await convert_response_content_to_chat_content("Simple string")
|
||||
assert result == "Simple string"
|
||||
|
||||
async def test_convert_text_content_parts(self):
|
||||
content = [
|
||||
OpenAIResponseInputMessageContentText(text="First part"),
|
||||
OpenAIResponseOutputMessageContentOutputText(text="Second part"),
|
||||
]
|
||||
|
||||
result = await convert_response_content_to_chat_content(content)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
|
||||
assert result[0].text == "First part"
|
||||
assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
|
||||
assert result[1].text == "Second part"
|
||||
|
||||
async def test_convert_image_content(self):
|
||||
content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]
|
||||
|
||||
result = await convert_response_content_to_chat_content(content)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)
|
||||
assert result[0].image_url.url == "https://example.com/image.jpg"
|
||||
assert result[0].image_url.detail == "high"
|
||||
|
||||
|
||||
class TestConvertResponseInputToChatMessages:
|
||||
async def test_convert_string_input(self):
|
||||
result = await convert_response_input_to_chat_messages("User message")
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIUserMessageParam)
|
||||
assert result[0].content == "User message"
|
||||
|
||||
async def test_convert_function_tool_call_output(self):
|
||||
input_items = [
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_123",
|
||||
name="test_function",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
OpenAIResponseInputFunctionToolCallOutput(
|
||||
output="Tool output",
|
||||
call_id="call_123",
|
||||
),
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], OpenAIAssistantMessageParam)
|
||||
assert result[0].tool_calls[0].id == "call_123"
|
||||
assert result[0].tool_calls[0].function.name == "test_function"
|
||||
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
assert isinstance(result[1], OpenAIToolMessageParam)
|
||||
assert result[1].content == "Tool output"
|
||||
assert result[1].tool_call_id == "call_123"
|
||||
|
||||
async def test_convert_function_tool_call(self):
|
||||
input_items = [
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_456",
|
||||
name="test_function",
|
||||
arguments='{"param": "value"}',
|
||||
)
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIAssistantMessageParam)
|
||||
assert len(result[0].tool_calls) == 1
|
||||
assert result[0].tool_calls[0].id == "call_456"
|
||||
assert result[0].tool_calls[0].function.name == "test_function"
|
||||
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
|
||||
async def test_convert_function_call_ordering(self):
|
||||
input_items = [
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_123",
|
||||
name="test_function_a",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_456",
|
||||
name="test_function_b",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
OpenAIResponseInputFunctionToolCallOutput(
|
||||
output="AAA",
|
||||
call_id="call_123",
|
||||
),
|
||||
OpenAIResponseInputFunctionToolCallOutput(
|
||||
output="BBB",
|
||||
call_id="call_456",
|
||||
),
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
assert len(result) == 4
|
||||
assert isinstance(result[0], OpenAIAssistantMessageParam)
|
||||
assert len(result[0].tool_calls) == 1
|
||||
assert result[0].tool_calls[0].id == "call_123"
|
||||
assert result[0].tool_calls[0].function.name == "test_function_a"
|
||||
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
assert isinstance(result[1], OpenAIToolMessageParam)
|
||||
assert result[1].content == "AAA"
|
||||
assert result[1].tool_call_id == "call_123"
|
||||
assert isinstance(result[2], OpenAIAssistantMessageParam)
|
||||
assert len(result[2].tool_calls) == 1
|
||||
assert result[2].tool_calls[0].id == "call_456"
|
||||
assert result[2].tool_calls[0].function.name == "test_function_b"
|
||||
assert result[2].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
assert isinstance(result[3], OpenAIToolMessageParam)
|
||||
assert result[3].content == "BBB"
|
||||
assert result[3].tool_call_id == "call_456"
|
||||
|
||||
async def test_convert_response_message(self):
|
||||
input_items = [
|
||||
OpenAIResponseMessage(
|
||||
role="user",
|
||||
content=[OpenAIResponseInputMessageContentText(text="User text")],
|
||||
)
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIUserMessageParam)
|
||||
# Content should be converted to chat content format
|
||||
assert len(result[0].content) == 1
|
||||
assert result[0].content[0].text == "User text"
|
||||
|
||||
|
||||
class TestConvertResponseTextToChatResponseFormat:
|
||||
async def test_convert_text_format(self):
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatText)
|
||||
assert result.type == "text"
|
||||
|
||||
async def test_convert_json_object_format(self):
|
||||
text = OpenAIResponseText(format={"type": "json_object"})
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatJSONObject)
|
||||
|
||||
async def test_convert_json_schema_format(self):
|
||||
schema_def = {"type": "object", "properties": {"test": {"type": "string"}}}
|
||||
text = OpenAIResponseText(
|
||||
format={
|
||||
"type": "json_schema",
|
||||
"name": "test_schema",
|
||||
"schema": schema_def,
|
||||
}
|
||||
)
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatJSONSchema)
|
||||
assert result.json_schema["name"] == "test_schema"
|
||||
assert result.json_schema["schema"] == schema_def
|
||||
|
||||
async def test_default_text_format(self):
|
||||
text = OpenAIResponseText()
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatText)
|
||||
assert result.type == "text"
|
||||
|
||||
|
||||
class TestGetMessageTypeByRole:
|
||||
async def test_user_role(self):
|
||||
result = await get_message_type_by_role("user")
|
||||
assert result == OpenAIUserMessageParam
|
||||
|
||||
async def test_system_role(self):
|
||||
result = await get_message_type_by_role("system")
|
||||
assert result == OpenAISystemMessageParam
|
||||
|
||||
async def test_assistant_role(self):
|
||||
result = await get_message_type_by_role("assistant")
|
||||
assert result == OpenAIAssistantMessageParam
|
||||
|
||||
async def test_developer_role(self):
|
||||
result = await get_message_type_by_role("developer")
|
||||
assert result == OpenAIDeveloperMessageParam
|
||||
|
||||
async def test_unknown_role(self):
|
||||
result = await get_message_type_by_role("unknown")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestIsFunctionToolCall:
|
||||
def test_is_function_tool_call_true(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name="test_function",
|
||||
arguments="{}",
|
||||
),
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(
|
||||
type="function", name="test_function", parameters={"type": "object", "properties": {}}
|
||||
),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is True
|
||||
|
||||
def test_is_function_tool_call_false_different_name(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name="other_function",
|
||||
arguments="{}",
|
||||
),
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(
|
||||
type="function", name="test_function", parameters={"type": "object", "properties": {}}
|
||||
),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is False
|
||||
|
||||
def test_is_function_tool_call_false_no_function(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=None,
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(
|
||||
type="function", name="test_function", parameters={"type": "object", "properties": {}}
|
||||
),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is False
|
||||
|
||||
def test_is_function_tool_call_false_wrong_type(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name="web_search",
|
||||
arguments="{}",
|
||||
),
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is False
|
54
tests/unit/providers/batches/conftest.py
Normal file
54
tests/unit/providers/batches/conftest.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Shared fixtures for batches provider unit tests."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def provider():
|
||||
"""Create a test provider instance with temporary database."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_batches.db"
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
|
||||
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
|
||||
|
||||
# Create kvstore and mock APIs
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
mock_inference = AsyncMock()
|
||||
mock_files = AsyncMock()
|
||||
mock_models = AsyncMock()
|
||||
|
||||
provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
|
||||
await provider.initialize()
|
||||
|
||||
# unit tests should not require background processing
|
||||
provider.process_batches = False
|
||||
|
||||
yield provider
|
||||
|
||||
await provider.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_batch_data():
|
||||
"""Sample batch data for testing."""
|
||||
return {
|
||||
"input_file_id": "file_abc123",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"metadata": {"test": "true", "priority": "high"},
|
||||
}
|
710
tests/unit/providers/batches/test_reference.py
Normal file
710
tests/unit/providers/batches/test_reference.py
Normal file
|
@ -0,0 +1,710 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Test suite for the reference implementation of the Batches API.
|
||||
|
||||
The tests are categorized and outlined below, keep this updated:
|
||||
|
||||
- Batch creation with various parameters and validation:
|
||||
* test_create_and_retrieve_batch_success (positive)
|
||||
* test_create_batch_without_metadata (positive)
|
||||
* test_create_batch_completion_window (negative)
|
||||
* test_create_batch_invalid_endpoints (negative)
|
||||
* test_create_batch_invalid_metadata (negative)
|
||||
|
||||
- Batch retrieval and error handling for non-existent batches:
|
||||
* test_retrieve_batch_not_found (negative)
|
||||
|
||||
- Batch cancellation with proper status transitions:
|
||||
* test_cancel_batch_success (positive)
|
||||
* test_cancel_batch_invalid_statuses (negative)
|
||||
* test_cancel_batch_not_found (negative)
|
||||
|
||||
- Batch listing with pagination and filtering:
|
||||
* test_list_batches_empty (positive)
|
||||
* test_list_batches_single_batch (positive)
|
||||
* test_list_batches_multiple_batches (positive)
|
||||
* test_list_batches_with_limit (positive)
|
||||
* test_list_batches_with_pagination (positive)
|
||||
* test_list_batches_invalid_after (negative)
|
||||
|
||||
- Data persistence in the underlying key-value store:
|
||||
* test_kvstore_persistence (positive)
|
||||
|
||||
- Batch processing concurrency control:
|
||||
* test_max_concurrent_batches (positive)
|
||||
|
||||
- Input validation testing (direct _validate_input method tests):
|
||||
* test_validate_input_file_not_found (negative)
|
||||
* test_validate_input_file_exists_empty_content (positive)
|
||||
* test_validate_input_file_mixed_valid_invalid_json (mixed)
|
||||
* test_validate_input_invalid_model (negative)
|
||||
* test_validate_input_url_mismatch (negative)
|
||||
* test_validate_input_multiple_errors_per_request (negative)
|
||||
* test_validate_input_invalid_request_format (negative)
|
||||
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
|
||||
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
|
||||
|
||||
The tests use temporary SQLite databases for isolation and mock external
|
||||
dependencies like inference, files, and models APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.batches import BatchObject
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
|
||||
|
||||
class TestReferenceBatchesImpl:
|
||||
"""Test the reference implementation of the Batches API."""
|
||||
|
||||
def _validate_batch_type(self, batch, expected_metadata=None):
|
||||
"""
|
||||
Helper function to validate batch object structure and field types.
|
||||
|
||||
Note: This validates the direct BatchObject from the provider, not the
|
||||
client library response which has a different structure.
|
||||
|
||||
Args:
|
||||
batch: The BatchObject instance to validate.
|
||||
expected_metadata: Optional expected metadata dictionary to validate against.
|
||||
"""
|
||||
assert isinstance(batch.id, str)
|
||||
assert isinstance(batch.completion_window, str)
|
||||
assert isinstance(batch.created_at, int)
|
||||
assert isinstance(batch.endpoint, str)
|
||||
assert isinstance(batch.input_file_id, str)
|
||||
assert batch.object == "batch"
|
||||
assert batch.status in [
|
||||
"validating",
|
||||
"failed",
|
||||
"in_progress",
|
||||
"finalizing",
|
||||
"completed",
|
||||
"expired",
|
||||
"cancelling",
|
||||
"cancelled",
|
||||
]
|
||||
|
||||
if expected_metadata is not None:
|
||||
assert batch.metadata == expected_metadata
|
||||
|
||||
timestamp_fields = [
|
||||
"cancelled_at",
|
||||
"cancelling_at",
|
||||
"completed_at",
|
||||
"expired_at",
|
||||
"expires_at",
|
||||
"failed_at",
|
||||
"finalizing_at",
|
||||
"in_progress_at",
|
||||
]
|
||||
for field in timestamp_fields:
|
||||
field_value = getattr(batch, field, None)
|
||||
if field_value is not None:
|
||||
assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}"
|
||||
|
||||
file_id_fields = ["error_file_id", "output_file_id"]
|
||||
for field in file_id_fields:
|
||||
field_value = getattr(batch, field, None)
|
||||
if field_value is not None:
|
||||
assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}"
|
||||
|
||||
if hasattr(batch, "request_counts") and batch.request_counts is not None:
|
||||
assert isinstance(batch.request_counts.completed, int), (
|
||||
f"request_counts.completed should be int, got {type(batch.request_counts.completed)}"
|
||||
)
|
||||
assert isinstance(batch.request_counts.failed, int), (
|
||||
f"request_counts.failed should be int, got {type(batch.request_counts.failed)}"
|
||||
)
|
||||
assert isinstance(batch.request_counts.total, int), (
|
||||
f"request_counts.total should be int, got {type(batch.request_counts.total)}"
|
||||
)
|
||||
|
||||
if hasattr(batch, "errors") and batch.errors is not None:
|
||||
assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}"
|
||||
|
||||
if hasattr(batch.errors, "data") and batch.errors.data is not None:
|
||||
assert isinstance(batch.errors.data, list), (
|
||||
f"errors.data should be list or None, got {type(batch.errors.data)}"
|
||||
)
|
||||
|
||||
for i, error_item in enumerate(batch.errors.data):
|
||||
assert isinstance(error_item, dict), (
|
||||
f"errors.data[{i}] should be object or dict, got {type(error_item)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "code") and error_item.code is not None:
|
||||
assert isinstance(error_item.code, str), (
|
||||
f"errors.data[{i}].code should be str or None, got {type(error_item.code)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "line") and error_item.line is not None:
|
||||
assert isinstance(error_item.line, int), (
|
||||
f"errors.data[{i}].line should be int or None, got {type(error_item.line)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "message") and error_item.message is not None:
|
||||
assert isinstance(error_item.message, str), (
|
||||
f"errors.data[{i}].message should be str or None, got {type(error_item.message)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "param") and error_item.param is not None:
|
||||
assert isinstance(error_item.param, str), (
|
||||
f"errors.data[{i}].param should be str or None, got {type(error_item.param)}"
|
||||
)
|
||||
|
||||
if hasattr(batch.errors, "object") and batch.errors.object is not None:
|
||||
assert isinstance(batch.errors.object, str), (
|
||||
f"errors.object should be str or None, got {type(batch.errors.object)}"
|
||||
)
|
||||
assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}"
|
||||
|
||||
async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
|
||||
"""Test successful batch creation and retrieval."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
|
||||
|
||||
assert created_batch.id.startswith("batch_")
|
||||
assert len(created_batch.id) > 13
|
||||
assert created_batch.object == "batch"
|
||||
assert created_batch.endpoint == sample_batch_data["endpoint"]
|
||||
assert created_batch.input_file_id == sample_batch_data["input_file_id"]
|
||||
assert created_batch.completion_window == sample_batch_data["completion_window"]
|
||||
assert created_batch.status == "validating"
|
||||
assert created_batch.metadata == sample_batch_data["metadata"]
|
||||
assert isinstance(created_batch.created_at, int)
|
||||
assert created_batch.created_at > 0
|
||||
|
||||
retrieved_batch = await provider.retrieve_batch(created_batch.id)
|
||||
|
||||
self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
|
||||
|
||||
assert retrieved_batch.id == created_batch.id
|
||||
assert retrieved_batch.input_file_id == created_batch.input_file_id
|
||||
assert retrieved_batch.endpoint == created_batch.endpoint
|
||||
assert retrieved_batch.status == created_batch.status
|
||||
assert retrieved_batch.metadata == created_batch.metadata
|
||||
|
||||
async def test_create_batch_without_metadata(self, provider):
|
||||
"""Test batch creation without optional metadata."""
|
||||
batch = await provider.create_batch(
|
||||
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
|
||||
assert batch.metadata is None
|
||||
|
||||
async def test_create_batch_completion_window(self, provider):
|
||||
"""Test batch creation with invalid completion window."""
|
||||
with pytest.raises(ValueError, match="Invalid completion_window"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint",
|
||||
[
|
||||
"/v1/embeddings",
|
||||
"/v1/completions",
|
||||
"/v1/invalid/endpoint",
|
||||
"",
|
||||
],
|
||||
)
|
||||
async def test_create_batch_invalid_endpoints(self, provider, endpoint):
|
||||
"""Test batch creation with various invalid endpoints."""
|
||||
with pytest.raises(ValueError, match="Invalid endpoint"):
|
||||
await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
|
||||
|
||||
async def test_create_batch_invalid_metadata(self, provider):
|
||||
"""Test that batch creation fails with invalid metadata."""
|
||||
with pytest.raises(ValueError, match="should be a valid string"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={123: "invalid_key"}, # Non-string key
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="should be a valid string"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"valid_key": 456}, # Non-string value
|
||||
)
|
||||
|
||||
async def test_retrieve_batch_not_found(self, provider):
|
||||
"""Test error when retrieving non-existent batch."""
|
||||
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||
await provider.retrieve_batch("nonexistent_batch")
|
||||
|
||||
async def test_cancel_batch_success(self, provider, sample_batch_data):
|
||||
"""Test successful batch cancellation."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
assert created_batch.status == "validating"
|
||||
|
||||
cancelled_batch = await provider.cancel_batch(created_batch.id)
|
||||
|
||||
assert cancelled_batch.id == created_batch.id
|
||||
assert cancelled_batch.status in ["cancelling", "cancelled"]
|
||||
assert isinstance(cancelled_batch.cancelling_at, int)
|
||||
assert cancelled_batch.cancelling_at >= created_batch.created_at
|
||||
|
||||
@pytest.mark.parametrize("status", ["failed", "expired", "completed"])
|
||||
async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
|
||||
"""Test error when cancelling batch in final states."""
|
||||
provider.process_batches = False
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
# directly update status in kvstore
|
||||
await provider._update_batch(created_batch.id, status=status)
|
||||
|
||||
with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
|
||||
await provider.cancel_batch(created_batch.id)
|
||||
|
||||
async def test_cancel_batch_not_found(self, provider):
|
||||
"""Test error when cancelling non-existent batch."""
|
||||
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||
await provider.cancel_batch("nonexistent_batch")
|
||||
|
||||
async def test_list_batches_empty(self, provider):
|
||||
"""Test listing batches when none exist."""
|
||||
response = await provider.list_batches()
|
||||
|
||||
assert response.object == "list"
|
||||
assert response.data == []
|
||||
assert response.first_id is None
|
||||
assert response.last_id is None
|
||||
assert response.has_more is False
|
||||
|
||||
async def test_list_batches_single_batch(self, provider, sample_batch_data):
|
||||
"""Test listing batches with single batch."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
response = await provider.list_batches()
|
||||
|
||||
assert len(response.data) == 1
|
||||
self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
|
||||
assert response.data[0].id == created_batch.id
|
||||
assert response.first_id == created_batch.id
|
||||
assert response.last_id == created_batch.id
|
||||
assert response.has_more is False
|
||||
|
||||
async def test_list_batches_multiple_batches(self, provider):
|
||||
"""Test listing multiple batches."""
|
||||
batches = [
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
response = await provider.list_batches()
|
||||
|
||||
assert len(response.data) == 3
|
||||
|
||||
batch_ids = {batch.id for batch in response.data}
|
||||
expected_ids = {batch.id for batch in batches}
|
||||
assert batch_ids == expected_ids
|
||||
assert response.has_more is False
|
||||
|
||||
assert response.first_id in expected_ids
|
||||
assert response.last_id in expected_ids
|
||||
|
||||
async def test_list_batches_with_limit(self, provider):
|
||||
"""Test listing batches with limit parameter."""
|
||||
batches = [
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
response = await provider.list_batches(limit=2)
|
||||
|
||||
assert len(response.data) == 2
|
||||
assert response.has_more is True
|
||||
assert response.first_id == response.data[0].id
|
||||
assert response.last_id == response.data[1].id
|
||||
batch_ids = {batch.id for batch in response.data}
|
||||
expected_ids = {batch.id for batch in batches}
|
||||
assert batch_ids.issubset(expected_ids)
|
||||
|
||||
async def test_list_batches_with_pagination(self, provider):
|
||||
"""Test listing batches with pagination using 'after' parameter."""
|
||||
for i in range(3):
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
|
||||
# Get first page
|
||||
first_page = await provider.list_batches(limit=1)
|
||||
assert len(first_page.data) == 1
|
||||
assert first_page.has_more is True
|
||||
|
||||
# Get second page using 'after'
|
||||
second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
|
||||
assert len(second_page.data) == 1
|
||||
assert second_page.data[0].id != first_page.data[0].id
|
||||
|
||||
# Verify we got the next batch in order
|
||||
all_batches = await provider.list_batches()
|
||||
expected_second_batch_id = all_batches.data[1].id
|
||||
assert second_page.data[0].id == expected_second_batch_id
|
||||
|
||||
async def test_list_batches_invalid_after(self, provider, sample_batch_data):
|
||||
"""Test listing batches with invalid 'after' parameter."""
|
||||
await provider.create_batch(**sample_batch_data)
|
||||
|
||||
response = await provider.list_batches(after="nonexistent_batch")
|
||||
|
||||
# Should return all batches (no filtering when 'after' batch not found)
|
||||
assert len(response.data) == 1
|
||||
|
||||
async def test_kvstore_persistence(self, provider, sample_batch_data):
|
||||
"""Test that batches are properly persisted in kvstore."""
|
||||
batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
stored_data = await provider.kvstore.get(f"batch:{batch.id}")
|
||||
assert stored_data is not None
|
||||
|
||||
stored_batch_dict = json.loads(stored_data)
|
||||
assert stored_batch_dict["id"] == batch.id
|
||||
assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"]
|
||||
|
||||
async def test_validate_input_file_not_found(self, provider):
|
||||
"""Test _validate_input when input file does not exist."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found"))
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="nonexistent_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
assert errors[0].code == "invalid_request"
|
||||
assert errors[0].message == "Cannot find file nonexistent_file."
|
||||
assert errors[0].param == "input_file_id"
|
||||
assert errors[0].line is None
|
||||
|
||||
async def test_validate_input_file_exists_empty_content(self, provider):
|
||||
"""Test _validate_input when file exists but is empty."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b""
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="empty_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 0
|
||||
assert len(requests) == 0
|
||||
|
||||
async def test_validate_input_file_mixed_valid_invalid_json(self, provider):
|
||||
"""Test _validate_input when file contains valid and invalid JSON lines."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
# Line 1: valid JSON with proper body args, Line 2: invalid JSON
|
||||
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="mixed_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
# Should have 1 JSON parsing error from line 2, and 1 valid request from line 1
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 1
|
||||
|
||||
assert errors[0].code == "invalid_json_line"
|
||||
assert errors[0].line == 2
|
||||
assert errors[0].message == "This line is not parseable as valid JSON."
|
||||
|
||||
assert requests[0].custom_id == "req-1"
|
||||
assert requests[0].method == "POST"
|
||||
assert requests[0].url == "/v1/chat/completions"
|
||||
assert requests[0].body["model"] == "test-model"
|
||||
assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}]
|
||||
|
||||
async def test_validate_input_invalid_model(self, provider):
|
||||
"""Test _validate_input when file contains request with non-existent model."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found"))
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="invalid_model_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "model_not_found"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported"
|
||||
assert errors[0].param == "body.model"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,param_path,error_code,error_message",
|
||||
[
|
||||
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
|
||||
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
|
||||
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
|
||||
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
|
||||
("model", "body.model", "invalid_request", "Model parameter is required"),
|
||||
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
|
||||
"""Test _validate_input when file contains request with missing required parameters."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
base_request = {
|
||||
"custom_id": "req-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
|
||||
}
|
||||
|
||||
# Remove the specific parameter being tested
|
||||
if "." in param_path:
|
||||
top_level, nested_param = param_path.split(".", 1)
|
||||
del base_request[top_level][nested_param]
|
||||
else:
|
||||
del base_request[param_name]
|
||||
|
||||
mock_response.body = json.dumps(base_request).encode()
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=f"missing_{param_name}_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == error_code
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
async def test_validate_input_url_mismatch(self, provider):
|
||||
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions", # This doesn't match the URL in the request
|
||||
input_file_id="url_mismatch_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "invalid_url"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == "URL provided for this request does not match the batch endpoint"
|
||||
assert errors[0].param == "url"
|
||||
|
||||
async def test_validate_input_multiple_errors_per_request(self, provider):
|
||||
"""Test _validate_input when a single request has multiple validation errors."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
# Request missing custom_id, has invalid URL, and missing model in body
|
||||
mock_response.body = (
|
||||
b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}'
|
||||
)
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request
|
||||
input_file_id="multiple_errors_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) >= 2 # At least missing custom_id and URL mismatch
|
||||
assert len(requests) == 0
|
||||
|
||||
for error in errors:
|
||||
assert error.line == 1
|
||||
|
||||
error_codes = {error.code for error in errors}
|
||||
assert "missing_required_parameter" in error_codes # missing custom_id
|
||||
assert "invalid_url" in error_codes # URL mismatch
|
||||
|
||||
async def test_validate_input_invalid_request_format(self, provider):
|
||||
"""Test _validate_input when file contains non-object JSON (array, string, number)."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b'["not", "a", "request", "object"]'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="invalid_format_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "invalid_request"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == "Each line must be a JSON dictionary object"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,param_path,invalid_value,error_message",
|
||||
[
|
||||
("custom_id", "custom_id", 12345, "Custom_id must be a string"),
|
||||
("url", "url", 123, "URL must be a string"),
|
||||
("method", "method", ["POST"], "Method must be a string"),
|
||||
("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"),
|
||||
("model", "body.model", 123, "Model must be a string"),
|
||||
("messages", "body.messages", "invalid messages format", "Messages must be an array"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_invalid_parameter_types(
|
||||
self, provider, param_name, param_path, invalid_value, error_message
|
||||
):
|
||||
"""Test _validate_input when file contains request with parameters that have invalid types."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
base_request = {
|
||||
"custom_id": "req-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
|
||||
}
|
||||
|
||||
# Override the specific parameter with invalid value
|
||||
if "." in param_path:
|
||||
top_level, nested_param = param_path.split(".", 1)
|
||||
base_request[top_level][nested_param] = invalid_value
|
||||
else:
|
||||
base_request[param_name] = invalid_value
|
||||
|
||||
mock_response.body = json.dumps(base_request).encode()
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=f"invalid_{param_name}_type_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "invalid_request"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
async def test_max_concurrent_batches(self, provider):
|
||||
"""Test max_concurrent_batches configuration and concurrency control."""
|
||||
import asyncio
|
||||
|
||||
provider._batch_semaphore = asyncio.Semaphore(2)
|
||||
|
||||
provider.process_batches = True # enable because we're testing background processing
|
||||
|
||||
active_batches = 0
|
||||
|
||||
async def add_and_wait(batch_id: str):
|
||||
nonlocal active_batches
|
||||
active_batches += 1
|
||||
await asyncio.sleep(float("inf"))
|
||||
|
||||
# the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl,
|
||||
# so we can replace _process_batch_impl with our mock to control concurrency
|
||||
provider._process_batch_impl = add_and_wait
|
||||
|
||||
for _ in range(3):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.042) # let tasks start
|
||||
|
||||
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
|
128
tests/unit/providers/batches/test_reference_idempotency.py
Normal file
128
tests/unit/providers/batches/test_reference_idempotency.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Tests for idempotency functionality in the reference batches provider.
|
||||
|
||||
This module tests the optional idempotency feature that allows clients to provide
|
||||
an idempotency key (idempotency_key) to ensure that repeated requests with the same key
|
||||
and parameters return the same batch, while requests with the same key but different
|
||||
parameters result in a conflict error.
|
||||
|
||||
Test Categories:
|
||||
1. Core Idempotency: Same parameters with same key return same batch
|
||||
2. Parameter Independence: Different parameters without keys create different batches
|
||||
3. Conflict Detection: Same key with different parameters raises ConflictError
|
||||
|
||||
Tests by Category:
|
||||
|
||||
1. Core Idempotency:
|
||||
- test_idempotent_batch_creation_same_params
|
||||
- test_idempotent_batch_creation_metadata_order_independence
|
||||
|
||||
2. Parameter Independence:
|
||||
- test_non_idempotent_behavior_without_key
|
||||
- test_different_idempotency_keys_create_different_batches
|
||||
|
||||
3. Conflict Detection:
|
||||
- test_same_idempotency_key_different_params_conflict (parametrized: input_file_id, metadata values, metadata None vs {})
|
||||
|
||||
Key Behaviors Tested:
|
||||
- Idempotent batch creation when idempotency_key provided with identical parameters
|
||||
- Metadata order independence for consistent batch ID generation
|
||||
- Non-idempotent behavior when no idempotency_key provided (random UUIDs)
|
||||
- Conflict detection for parameter mismatches with same idempotency key
|
||||
- Deterministic ID generation based solely on idempotency key
|
||||
- Proper error handling with detailed conflict messages including key and error codes
|
||||
- Protection against idempotency key reuse with different request parameters
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.errors import ConflictError
|
||||
|
||||
|
||||
class TestReferenceBatchesIdempotency:
|
||||
"""Test suite for idempotency functionality in the reference implementation."""
|
||||
|
||||
async def test_idempotent_batch_creation_same_params(self, provider, sample_batch_data):
|
||||
"""Test that creating batches with identical parameters returns the same batch when idempotency_key is provided."""
|
||||
|
||||
del sample_batch_data["metadata"]
|
||||
|
||||
batch1 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
metadata={"test": "value1", "other": "value2"},
|
||||
idempotency_key="unique-token-1",
|
||||
)
|
||||
|
||||
# sleep for 1 second to allow created_at timestamps to be different
|
||||
await asyncio.sleep(1)
|
||||
|
||||
batch2 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
metadata={"other": "value2", "test": "value1"}, # Different order
|
||||
idempotency_key="unique-token-1",
|
||||
)
|
||||
|
||||
assert batch1.id == batch2.id
|
||||
assert batch1.input_file_id == batch2.input_file_id
|
||||
assert batch1.metadata == batch2.metadata
|
||||
assert batch1.created_at == batch2.created_at
|
||||
|
||||
async def test_different_idempotency_keys_create_different_batches(self, provider, sample_batch_data):
|
||||
"""Test that different idempotency keys create different batches even with same params."""
|
||||
batch1 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
idempotency_key="token-A",
|
||||
)
|
||||
|
||||
batch2 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
idempotency_key="token-B",
|
||||
)
|
||||
|
||||
assert batch1.id != batch2.id
|
||||
|
||||
async def test_non_idempotent_behavior_without_key(self, provider, sample_batch_data):
|
||||
"""Test that batches without idempotency key create unique batches even with identical parameters."""
|
||||
batch1 = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
batch2 = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
assert batch1.id != batch2.id
|
||||
assert batch1.input_file_id == batch2.input_file_id
|
||||
assert batch1.endpoint == batch2.endpoint
|
||||
assert batch1.completion_window == batch2.completion_window
|
||||
assert batch1.metadata == batch2.metadata
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,first_value,second_value",
|
||||
[
|
||||
("input_file_id", "file_001", "file_002"),
|
||||
("metadata", {"test": "value1"}, {"test": "value2"}),
|
||||
("metadata", None, {}),
|
||||
],
|
||||
)
|
||||
async def test_same_idempotency_key_different_params_conflict(
|
||||
self, provider, sample_batch_data, param_name, first_value, second_value
|
||||
):
|
||||
"""Test that same idempotency_key with different parameters raises conflict error."""
|
||||
sample_batch_data["idempotency_key"] = "same-token"
|
||||
|
||||
sample_batch_data[param_name] = first_value
|
||||
|
||||
batch1 = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
with pytest.raises(ConflictError, match="Idempotency key.*was previously used with different parameters"):
|
||||
sample_batch_data[param_name] = second_value
|
||||
await provider.create_batch(**sample_batch_data)
|
||||
|
||||
retrieved_batch = await provider.retrieve_batch(batch1.id)
|
||||
assert retrieved_batch.id == batch1.id
|
||||
assert getattr(retrieved_batch, param_name) == first_value
|
251
tests/unit/providers/files/test_s3_files.py
Normal file
251
tests/unit/providers/files/test_s3_files.py
Normal file
|
@ -0,0 +1,251 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from botocore.exceptions import ClientError
|
||||
from moto import mock_aws
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.providers.remote.files.s3 import (
|
||||
S3FilesImplConfig,
|
||||
get_adapter_impl,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
|
||||
self.content = content
|
||||
self.filename = filename
|
||||
self.content_type = content_type
|
||||
|
||||
async def read(self):
|
||||
return self.content
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_config(tmp_path):
|
||||
db_path = tmp_path / "s3_files_metadata.db"
|
||||
|
||||
return S3FilesImplConfig(
|
||||
bucket_name="test-bucket",
|
||||
region="not-a-region",
|
||||
auto_create_bucket=True,
|
||||
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_client():
|
||||
"""Create a mocked S3 client for testing."""
|
||||
# we use `with mock_aws()` because @mock_aws decorator does not support being a generator
|
||||
with mock_aws():
|
||||
# must yield or the mock will be reset before it is used
|
||||
yield boto3.client("s3")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def s3_provider(s3_config, s3_client):
|
||||
"""Create an S3 files provider with mocked S3 for testing."""
|
||||
provider = await get_adapter_impl(s3_config, {})
|
||||
yield provider
|
||||
await provider.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_text_file():
|
||||
content = b"Hello, this is a test file for the S3 Files API!"
|
||||
return MockUploadFile(content, "sample_text_file.txt")
|
||||
|
||||
|
||||
class TestS3FilesImpl:
|
||||
"""Test suite for S3 Files implementation."""
|
||||
|
||||
async def test_upload_file(self, s3_provider, sample_text_file, s3_client, s3_config):
|
||||
"""Test successful file upload."""
|
||||
sample_text_file.filename = "test_upload_file"
|
||||
result = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
assert result.filename == sample_text_file.filename
|
||||
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
|
||||
assert result.bytes == len(sample_text_file.content)
|
||||
assert result.id.startswith("file-")
|
||||
|
||||
# Verify file exists in S3 backend
|
||||
response = s3_client.head_object(Bucket=s3_config.bucket_name, Key=result.id)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
async def test_list_files_empty(self, s3_provider):
|
||||
"""Test listing files when no files exist."""
|
||||
result = await s3_provider.openai_list_files()
|
||||
|
||||
assert len(result.data) == 0
|
||||
assert not result.has_more
|
||||
assert result.first_id == ""
|
||||
assert result.last_id == ""
|
||||
|
||||
async def test_retrieve_file(self, s3_provider, sample_text_file):
|
||||
"""Test retrieving file metadata."""
|
||||
sample_text_file.filename = "test_retrieve_file"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
retrieved = await s3_provider.openai_retrieve_file(uploaded.id)
|
||||
|
||||
assert retrieved.id == uploaded.id
|
||||
assert retrieved.filename == uploaded.filename
|
||||
assert retrieved.purpose == uploaded.purpose
|
||||
assert retrieved.bytes == uploaded.bytes
|
||||
|
||||
async def test_retrieve_file_content(self, s3_provider, sample_text_file):
|
||||
"""Test retrieving file content."""
|
||||
sample_text_file.filename = "test_retrieve_file_content"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
response = await s3_provider.openai_retrieve_file_content(uploaded.id)
|
||||
|
||||
assert response.body == sample_text_file.content
|
||||
assert response.headers["Content-Disposition"] == f'attachment; filename="{sample_text_file.filename}"'
|
||||
|
||||
async def test_delete_file(self, s3_provider, sample_text_file, s3_config, s3_client):
|
||||
"""Test deleting a file."""
|
||||
sample_text_file.filename = "test_delete_file"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
delete_response = await s3_provider.openai_delete_file(uploaded.id)
|
||||
|
||||
assert delete_response.id == uploaded.id
|
||||
assert delete_response.deleted is True
|
||||
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_retrieve_file(uploaded.id)
|
||||
|
||||
# Verify file is gone from S3 backend
|
||||
with pytest.raises(ClientError) as exc_info:
|
||||
s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
|
||||
assert exc_info.value.response["Error"]["Code"] == "404"
|
||||
|
||||
async def test_list_files(self, s3_provider, sample_text_file):
|
||||
"""Test listing files after uploading some."""
|
||||
sample_text_file.filename = "test_list_files_with_content_file1"
|
||||
file1 = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2")
|
||||
file2 = await s3_provider.openai_upload_file(
|
||||
file=file2_content,
|
||||
purpose=OpenAIFilePurpose.BATCH,
|
||||
)
|
||||
|
||||
result = await s3_provider.openai_list_files()
|
||||
|
||||
assert len(result.data) == 2
|
||||
file_ids = {f.id for f in result.data}
|
||||
assert file1.id in file_ids
|
||||
assert file2.id in file_ids
|
||||
|
||||
async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file):
|
||||
"""Test listing files with purpose filter."""
|
||||
sample_text_file.filename = "test_list_files_with_purpose_filter_file1"
|
||||
file1 = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2")
|
||||
await s3_provider.openai_upload_file(
|
||||
file=file2_content,
|
||||
purpose=OpenAIFilePurpose.BATCH,
|
||||
)
|
||||
|
||||
result = await s3_provider.openai_list_files(purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == file1.id
|
||||
assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
|
||||
|
||||
async def test_nonexistent_file_retrieval(self, s3_provider):
|
||||
"""Test retrieving a non-existent file raises error."""
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_retrieve_file("file-nonexistent")
|
||||
|
||||
async def test_nonexistent_file_content_retrieval(self, s3_provider):
|
||||
"""Test retrieving content of a non-existent file raises error."""
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_retrieve_file_content("file-nonexistent")
|
||||
|
||||
async def test_nonexistent_file_deletion(self, s3_provider):
|
||||
"""Test deleting a non-existent file raises error."""
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_delete_file("file-nonexistent")
|
||||
|
||||
async def test_upload_file_without_filename(self, s3_provider, sample_text_file):
|
||||
"""Test uploading a file without a filename uses the fallback."""
|
||||
del sample_text_file.filename
|
||||
result = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
|
||||
assert result.bytes == len(sample_text_file.content)
|
||||
|
||||
retrieved = await s3_provider.openai_retrieve_file(result.id)
|
||||
assert retrieved.filename == result.filename
|
||||
|
||||
async def test_file_operations_when_s3_object_deleted(self, s3_provider, sample_text_file, s3_config, s3_client):
|
||||
"""Test file operations when S3 object is deleted but metadata exists (negative test)."""
|
||||
sample_text_file.filename = "test_orphaned_metadata"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
# Directly delete the S3 object from the backend
|
||||
s3_client.delete_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
|
||||
|
||||
with pytest.raises(ResourceNotFoundError, match="not found") as exc_info:
|
||||
await s3_provider.openai_retrieve_file_content(uploaded.id)
|
||||
assert uploaded.id in str(exc_info).lower()
|
||||
|
||||
listed_files = await s3_provider.openai_list_files()
|
||||
assert uploaded.id not in [file.id for file in listed_files.data]
|
||||
|
||||
async def test_upload_file_s3_put_object_failure(self, s3_provider, sample_text_file, s3_config, s3_client):
|
||||
"""Test that put_object failure results in exception and no orphaned metadata."""
|
||||
sample_text_file.filename = "test_s3_put_object_failure"
|
||||
|
||||
def failing_put_object(*args, **kwargs):
|
||||
raise ClientError(
|
||||
error_response={"Error": {"Code": "SolarRadiation", "Message": "Bloop"}}, operation_name="PutObject"
|
||||
)
|
||||
|
||||
with patch.object(s3_provider.client, "put_object", side_effect=failing_put_object):
|
||||
with pytest.raises(RuntimeError, match="Failed to upload file to S3"):
|
||||
await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
files_list = await s3_provider.openai_list_files()
|
||||
assert len(files_list.data) == 0, "No file metadata should remain after failed upload"
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import logging # allow-direct-logging
|
||||
import threading
|
||||
import time
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
convert_message_to_openai_dict_new,
|
||||
openai_messages_to_messages,
|
||||
)
|
||||
|
||||
|
@ -182,3 +183,42 @@ def test_user_message_accepts_images():
|
|||
assert len(msg.content) == 2
|
||||
assert msg.content[0].text == "Describe this image:"
|
||||
assert msg.content[1].image_url.url == "http://example.com/image.jpg"
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict_new_user_message():
|
||||
"""Test convert_message_to_openai_dict_new with UserMessage."""
|
||||
message = UserMessage(content="Hello, world!", role="user")
|
||||
result = await convert_message_to_openai_dict_new(message)
|
||||
|
||||
assert result["role"] == "user"
|
||||
assert result["content"] == "Hello, world!"
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
|
||||
"""Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
|
||||
message = CompletionMessage(
|
||||
content="I'll help you find the weather.",
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="call_123",
|
||||
tool_name="get_weather",
|
||||
arguments={"city": "Sligo"},
|
||||
arguments_json='{"city": "Sligo"}',
|
||||
)
|
||||
],
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
result = await convert_message_to_openai_dict_new(message)
|
||||
|
||||
# This would have failed with "Cannot instantiate typing.Union" before the fix
|
||||
assert result["role"] == "assistant"
|
||||
assert result["content"] == "I'll help you find the weather."
|
||||
assert "tool_calls" in result
|
||||
assert result["tool_calls"] is not None
|
||||
assert len(result["tool_calls"]) == 1
|
||||
|
||||
tool_call = result["tool_calls"][0]
|
||||
assert tool_call.id == "call_123"
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function.name == "get_weather"
|
||||
assert tool_call.function.arguments == '{"city": "Sligo"}'
|
||||
|
|
105
tests/unit/server/test_cors.py
Normal file
105
tests/unit/server/test_cors.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.datatypes import CORSConfig, process_cors_config
|
||||
|
||||
|
||||
def test_cors_config_defaults():
|
||||
config = CORSConfig()
|
||||
|
||||
assert config.allow_origins == []
|
||||
assert config.allow_origin_regex is None
|
||||
assert config.allow_methods == ["OPTIONS"]
|
||||
assert config.allow_headers == []
|
||||
assert config.allow_credentials is False
|
||||
assert config.expose_headers == []
|
||||
assert config.max_age == 600
|
||||
|
||||
|
||||
def test_cors_config_explicit_config():
|
||||
config = CORSConfig(
|
||||
allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"]
|
||||
)
|
||||
|
||||
assert config.allow_origins == ["https://example.com"]
|
||||
assert config.allow_credentials is True
|
||||
assert config.max_age == 3600
|
||||
assert config.allow_methods == ["GET", "POST"]
|
||||
|
||||
|
||||
def test_cors_config_regex():
|
||||
config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
|
||||
|
||||
assert config.allow_origins == []
|
||||
assert config.allow_origin_regex == r"https?://localhost:\d+"
|
||||
|
||||
|
||||
def test_cors_config_wildcard_credentials_error():
|
||||
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
|
||||
CORSConfig(allow_origins=["*"], allow_credentials=True)
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
|
||||
CORSConfig(allow_origins=["https://example.com", "*"], allow_credentials=True)
|
||||
|
||||
|
||||
def test_process_cors_config_false():
|
||||
result = process_cors_config(False)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_process_cors_config_true():
|
||||
result = process_cors_config(True)
|
||||
|
||||
assert isinstance(result, CORSConfig)
|
||||
assert result.allow_origins == []
|
||||
assert result.allow_origin_regex == r"https?://localhost:\d+"
|
||||
assert result.allow_credentials is False
|
||||
expected_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||
for method in expected_methods:
|
||||
assert method in result.allow_methods
|
||||
|
||||
|
||||
def test_process_cors_config_passthrough():
|
||||
original = CORSConfig(allow_origins=["https://example.com"], allow_methods=["GET"])
|
||||
result = process_cors_config(original)
|
||||
|
||||
assert result is original
|
||||
|
||||
|
||||
def test_process_cors_config_invalid_type():
|
||||
with pytest.raises(ValueError, match="Expected bool or CORSConfig, got str"):
|
||||
process_cors_config("invalid")
|
||||
|
||||
|
||||
def test_cors_config_model_dump():
|
||||
cors_config = CORSConfig(
|
||||
allow_origins=["https://example.com"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Content-Type"],
|
||||
allow_credentials=True,
|
||||
max_age=3600,
|
||||
)
|
||||
|
||||
config_dict = cors_config.model_dump()
|
||||
|
||||
assert config_dict["allow_origins"] == ["https://example.com"]
|
||||
assert config_dict["allow_methods"] == ["GET", "POST"]
|
||||
assert config_dict["allow_headers"] == ["Content-Type"]
|
||||
assert config_dict["allow_credentials"] is True
|
||||
assert config_dict["max_age"] == 3600
|
||||
|
||||
expected_keys = {
|
||||
"allow_origins",
|
||||
"allow_origin_regex",
|
||||
"allow_methods",
|
||||
"allow_headers",
|
||||
"allow_credentials",
|
||||
"expose_headers",
|
||||
"max_age",
|
||||
}
|
||||
assert set(config_dict.keys()) == expected_keys
|
Loading…
Add table
Add a link
Reference in a new issue