Merge branch 'main' of https://github.com/meta-llama/llama-stack into register_custom_model

Author: raspawar · 2025-04-24 21:44:32 +05:30
Commit: 0990f60dad
74 changed files with 4854 additions and 1869 deletions


@@ -0,0 +1,9 @@
version: '2'
distribution_spec:
description: Custom distro for CI tests
providers:
inference:
- remote::custom_ollama
image_type: container
image_name: ci-test
external_providers_dir: /tmp/providers.d


@@ -1,6 +1,6 @@
adapter:
adapter_type: custom_ollama
pip_packages: ["ollama", "aiohttp"]
pip_packages: ["ollama", "aiohttp", "tests/external-provider/llama-stack-provider-ollama"]
config_class: llama_stack_provider_ollama.config.OllamaImplConfig
module: llama_stack_provider_ollama
api_dependencies: []


@@ -1,14 +1,10 @@
version: '2'
image_name: ollama
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- datasetio
- vector_io
providers:
inference:
@@ -24,19 +20,6 @@ providers:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
@@ -44,14 +27,6 @@ providers:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
@@ -67,17 +42,6 @@ providers:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search


@@ -115,6 +115,70 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
assert "I can't" in logs_str
def test_agent_name(llama_stack_client, text_model_id):
agent_name = f"test-agent-{uuid4()}"
try:
agent = Agent(
llama_stack_client,
model=text_model_id,
instructions="You are a helpful assistant",
name=agent_name,
)
except TypeError:
agent = Agent(
llama_stack_client,
model=text_model_id,
instructions="You are a helpful assistant",
)
return
session_id = agent.create_session(f"test-session-{uuid4()}")
agent.create_turn(
messages=[
{
"role": "user",
"content": "Give me a sentence that contains the word: hello",
}
],
session_id=session_id,
stream=False,
)
all_spans = []
for span in llama_stack_client.telemetry.query_spans(
attribute_filters=[
{"key": "session_id", "op": "eq", "value": session_id},
],
attributes_to_return=["input", "output", "agent_name", "agent_id", "session_id"],
):
all_spans.append(span.attributes)
agent_name_spans = []
for span in llama_stack_client.telemetry.query_spans(
attribute_filters=[],
attributes_to_return=["agent_name"],
):
if "agent_name" in span.attributes:
agent_name_spans.append(span.attributes)
agent_logs = []
for span in llama_stack_client.telemetry.query_spans(
attribute_filters=[
{"key": "agent_name", "op": "eq", "value": agent_name},
],
attributes_to_return=["input", "output", "agent_name"],
):
if "output" in span.attributes and span.attributes["output"] != "no shields":
agent_logs.append(span.attributes)
assert len(agent_logs) == 1
assert agent_logs[0]["agent_name"] == agent_name
assert "Give me a sentence that contains the word: hello" in agent_logs[0]["input"]
assert "hello" in agent_logs[0]["output"].lower()
def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
common_params = dict(
model="meta-llama/Llama-3.2-3B-Instruct",


@@ -31,6 +31,7 @@ def data_url_from_file(file_path: str) -> str:
return data_url
@pytest.mark.skip(reason="flaky. Couldn't find 'llamastack/simpleqa' on the Hugging Face Hub")
@pytest.mark.parametrize(
"purpose, source, provider_id, limit",
[


@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.cli.stack._build import (
_run_stack_build_command_from_build_config,
)
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec
from llama_stack.distribution.utils.image_types import LlamaStackImageType
def test_container_build_passes_path(monkeypatch, tmp_path):
called_with = {}
def spy_build_image(cfg, build_file_path, image_name, template_or_config, run_config=None):
called_with["path"] = template_or_config
called_with["run_config"] = run_config
return 0
monkeypatch.setattr(
"llama_stack.cli.stack._build.build_image",
spy_build_image,
raising=True,
)
cfg = BuildConfig(
image_type=LlamaStackImageType.CONTAINER.value,
distribution_spec=DistributionSpec(providers={}, description=""),
)
_run_stack_build_command_from_build_config(cfg, image_name="dummy")
assert "path" in called_with
assert isinstance(called_with["path"], str)
assert Path(called_with["path"]).exists()
assert called_with["run_config"] is None


@@ -216,35 +216,48 @@ class TestNvidiaPostTraining(unittest.TestCase):
)
def test_get_training_job_status(self):
self.mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": "completed",
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
}
customizer_status_to_job_status = [
("running", "in_progress"),
("completed", "completed"),
("failed", "failed"),
("cancelled", "cancelled"),
("pending", "scheduled"),
("unknown", "scheduled"),
]
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
for customizer_status, expected_status in customizer_status_to_job_status:
with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
self.mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": customizer_status,
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
}
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert status.status.value == "completed"
assert status.steps_completed == 1210
assert status.epochs_completed == 2
assert status.percentage_done == 100.0
assert status.best_epoch == 2
assert status.train_loss == 1.718016266822815
assert status.val_loss == 1.8661999702453613
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
)
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert status.status.value == expected_status
assert status.steps_completed == 1210
assert status.epochs_completed == 2
assert status.percentage_done == 100.0
assert status.best_epoch == 2
assert status.train_loss == 1.718016266822815
assert status.val_loss == 1.8661999702453613
self._assert_request(
self.mock_make_request,
"GET",
f"/v1/customization/jobs/{job_id}/status",
expected_params={"job_id": job_id},
)
def test_get_training_jobs(self):
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"


@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import pytest
from llama_stack.distribution.server.server import create_sse_event, sse_generator
@pytest.mark.asyncio
async def test_sse_generator_basic():
# An AsyncIterator wrapped in an Awaitable, just like our web methods
async def async_event_gen():
async def event_gen():
yield "Test event 1"
yield "Test event 2"
return event_gen()
sse_gen = sse_generator(async_event_gen())
assert sse_gen is not None
# Test that the events are streamed correctly
seen_events = []
async for event in sse_gen:
seen_events.append(event)
assert len(seen_events) == 2
assert seen_events[0] == create_sse_event("Test event 1")
assert seen_events[1] == create_sse_event("Test event 2")
@pytest.mark.asyncio
async def test_sse_generator_client_disconnected():
# An AsyncIterator wrapped in an Awaitable, just like our web methods
async def async_event_gen():
async def event_gen():
yield "Test event 1"
# Simulate a client disconnect before emitting event 2
raise asyncio.CancelledError()
return event_gen()
sse_gen = sse_generator(async_event_gen())
assert sse_gen is not None
# Start reading the events, ensuring this doesn't raise an exception
seen_events = []
async for event in sse_gen:
seen_events.append(event)
assert len(seen_events) == 1
assert seen_events[0] == create_sse_event("Test event 1")
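
For context, the pattern these tests exercise can be sketched as follows: `sse_generator` awaits the wrapped awaitable to obtain the underlying async iterator, turns each item into an SSE event, and treats a client disconnect (`asyncio.CancelledError`) as a normal end of stream. This is an illustrative sketch only, not the actual implementation in `llama_stack.distribution.server.server`, and the exact formatting performed by `create_sse_event` is assumed:

```python
import asyncio
from typing import Any, AsyncIterator, Awaitable


async def sse_generator_sketch(wrapped: Awaitable[AsyncIterator[Any]]):
    """Illustrative sketch of an SSE wrapper; not the server's actual code."""
    try:
        event_gen = await wrapped  # unwrap the Awaitable to get the async iterator
        async for item in event_gen:
            # Assumed formatting; the real create_sse_event may serialize differently.
            yield f"data: {item}\n\n"
    except asyncio.CancelledError:
        # Client disconnected mid-stream: stop quietly instead of propagating.
        return
```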


@@ -8,29 +8,44 @@ This framework allows you to run the same set of verification tests against diff
## Features
The verification suite currently tests:
The verification suite currently tests the following in both streaming and non-streaming modes:
- Basic chat completions (streaming and non-streaming)
- Basic chat completions
- Image input capabilities
- Structured JSON output formatting
- Tool calling functionality
## Report
The latest report can be found at [REPORT.md](REPORT.md).
To update the report, ensure you have the API keys set,
```bash
export OPENAI_API_KEY=<your_openai_api_key>
export FIREWORKS_API_KEY=<your_fireworks_api_key>
export TOGETHER_API_KEY=<your_together_api_key>
```
then run
```bash
uv run --with-editable ".[dev]" python tests/verifications/generate_report.py --run-tests
```
## Running Tests
To run the verification tests, use pytest with the following parameters:
```bash
cd llama-stack
pytest tests/verifications/openai --provider=<provider-name>
pytest tests/verifications/openai_api --provider=<provider-name>
```
Example:
```bash
# Run all tests
pytest tests/verifications/openai --provider=together
pytest tests/verifications/openai_api --provider=together
# Only run tests with Llama 4 models
pytest tests/verifications/openai --provider=together -k 'Llama-4'
pytest tests/verifications/openai_api --provider=together -k 'Llama-4'
```
### Parameters
@@ -41,23 +56,22 @@ pytest tests/verifications/openai --provider=together -k 'Llama-4'
## Supported Providers
The verification suite currently supports:
- OpenAI
- Fireworks
- Together
- Groq
- Cerebras
The verification suite supports any provider with an OpenAI-compatible endpoint.
See `tests/verifications/conf/` for the list of supported providers.
To run on a new provider, simply add a new yaml file to the `conf/` directory with the provider config. See `tests/verifications/conf/together.yaml` for an example.
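
For illustration, these provider configs are plain YAML files whose keys match the ones added elsewhere in this change (`base_url`, `api_key_var`, `models`, `model_display_names`, `test_exclusions`). A rough sketch of reading such a config, with a hypothetical file name and example field handling that are not part of the framework itself:

```python
import os

import yaml  # pyyaml

# Hypothetical provider config; keys mirror those in conf/meta_reference.yaml.
with open("tests/verifications/conf/my_provider.yaml") as f:  # example file name
    conf = yaml.safe_load(f)

base_url = conf["base_url"]                          # OpenAI-compatible endpoint URL
api_key = os.environ.get(conf["api_key_var"], "")    # env var named by the config
models = conf["models"]                              # model IDs to run the suite against
display_names = conf.get("model_display_names", {})  # pretty names shown in REPORT.md
exclusions = conf.get("test_exclusions", {})         # per-model tests to skip
print(base_url, models, display_names, exclusions)
```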
## Adding New Test Cases
To add new test cases, create appropriate JSON files in the `openai/fixtures/test_cases/` directory following the existing patterns.
To add new test cases, create appropriate YAML files in the `openai_api/fixtures/test_cases/` directory following the existing patterns.
## Structure
- `__init__.py` - Marks the directory as a Python package
- `conftest.py` - Global pytest configuration and fixtures
- `openai/` - Tests specific to OpenAI-compatible APIs
- `conf/` - Provider-specific configuration files
- `openai_api/` - Tests specific to OpenAI-compatible APIs
- `fixtures/` - Test fixtures and utilities
- `fixtures.py` - Provider-specific fixtures
- `load.py` - Utilities for loading test cases


@@ -1,6 +1,6 @@
# Test Results Report
*Generated on: 2025-04-14 18:11:37*
*Generated on: 2025-04-17 12:42:33*
*This report was generated by running `python tests/verifications/generate_report.py`*
@@ -15,22 +15,74 @@
| Provider | Pass Rate | Tests Passed | Total Tests |
| --- | --- | --- | --- |
| Together | 48.7% | 37 | 76 |
| Fireworks | 47.4% | 36 | 76 |
| Openai | 100.0% | 52 | 52 |
| Meta_reference | 100.0% | 28 | 28 |
| Together | 50.0% | 40 | 80 |
| Fireworks | 50.0% | 40 | 80 |
| Openai | 100.0% | 56 | 56 |
## Meta_reference
*Tests run on: 2025-04-17 12:37:11*
```bash
# Run all tests for this provider:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=meta_reference -v
# Example: Run only the 'stream=False' case of test_chat_multi_turn_multiple_images:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=meta_reference -k "test_chat_multi_turn_multiple_images and stream=False"
```
**Model Key (Meta_reference)**
| Display Name | Full Model ID |
| --- | --- |
| Llama-4-Scout-Instruct | `meta-llama/Llama-4-Scout-17B-16E-Instruct` |
| Test | Llama-4-Scout-Instruct |
| --- | --- |
| test_chat_multi_turn_multiple_images (stream=False) | ✅ |
| test_chat_multi_turn_multiple_images (stream=True) | ✅ |
| test_chat_non_streaming_basic (earth) | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ |
| test_chat_non_streaming_image | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ |
| test_chat_non_streaming_structured_output (calendar) | ✅ |
| test_chat_non_streaming_structured_output (math) | ✅ |
| test_chat_non_streaming_tool_calling | ✅ |
| test_chat_non_streaming_tool_choice_none | ✅ |
| test_chat_non_streaming_tool_choice_required | ✅ |
| test_chat_streaming_basic (earth) | ✅ |
| test_chat_streaming_basic (saturn) | ✅ |
| test_chat_streaming_image | ✅ |
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ |
| test_chat_streaming_structured_output (calendar) | ✅ |
| test_chat_streaming_structured_output (math) | ✅ |
| test_chat_streaming_tool_calling | ✅ |
| test_chat_streaming_tool_choice_none | ✅ |
| test_chat_streaming_tool_choice_required | ✅ |
## Together
*Tests run on: 2025-04-14 18:08:14*
*Tests run on: 2025-04-17 12:27:45*
```bash
# Run all tests for this provider:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=together -v
# Example: Run only the 'earth' case of test_chat_non_streaming_basic:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=together -k "test_chat_non_streaming_basic and earth"
# Example: Run only the 'stream=False' case of test_chat_multi_turn_multiple_images:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=together -k "test_chat_multi_turn_multiple_images and stream=False"
```
@@ -45,11 +97,13 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=togethe
| Test | Llama-3.3-70B-Instruct | Llama-4-Maverick-Instruct | Llama-4-Scout-Instruct |
| --- | --- | --- | --- |
| test_chat_multi_turn_multiple_images (stream=False) | ⚪ | ✅ | ✅ |
| test_chat_multi_turn_multiple_images (stream=True) | ⚪ | ❌ | ❌ |
| test_chat_non_streaming_basic (earth) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_image | ⚪ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ❌ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ | ✅ |
@@ -74,14 +128,14 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=togethe
## Fireworks
*Tests run on: 2025-04-14 18:04:06*
*Tests run on: 2025-04-17 12:29:53*
```bash
# Run all tests for this provider:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=fireworks -v
# Example: Run only the 'earth' case of test_chat_non_streaming_basic:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=fireworks -k "test_chat_non_streaming_basic and earth"
# Example: Run only the 'stream=False' case of test_chat_multi_turn_multiple_images:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=fireworks -k "test_chat_multi_turn_multiple_images and stream=False"
```
@@ -96,6 +150,8 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=firewor
| Test | Llama-3.3-70B-Instruct | Llama-4-Maverick-Instruct | Llama-4-Scout-Instruct |
| --- | --- | --- | --- |
| test_chat_multi_turn_multiple_images (stream=False) | ⚪ | ✅ | ✅ |
| test_chat_multi_turn_multiple_images (stream=True) | ⚪ | ✅ | ✅ |
| test_chat_non_streaming_basic (earth) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_image | ⚪ | ✅ | ✅ |
@@ -125,14 +181,14 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=firewor
## Openai
*Tests run on: 2025-04-14 18:09:51*
*Tests run on: 2025-04-17 12:34:08*
```bash
# Run all tests for this provider:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai -v
# Example: Run only the 'earth' case of test_chat_non_streaming_basic:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai -k "test_chat_non_streaming_basic and earth"
# Example: Run only the 'stream=False' case of test_chat_multi_turn_multiple_images:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai -k "test_chat_multi_turn_multiple_images and stream=False"
```
@@ -146,6 +202,8 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai
| Test | gpt-4o | gpt-4o-mini |
| --- | --- | --- |
| test_chat_multi_turn_multiple_images (stream=False) | ✅ | ✅ |
| test_chat_multi_turn_multiple_images (stream=True) | ✅ | ✅ |
| test_chat_non_streaming_basic (earth) | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ |
| test_chat_non_streaming_image | ✅ | ✅ |


@@ -8,3 +8,4 @@ test_exclusions:
llama-3.3-70b:
- test_chat_non_streaming_image
- test_chat_streaming_image
- test_chat_multi_turn_multiple_images


@@ -12,3 +12,4 @@ test_exclusions:
fireworks/llama-v3p3-70b-instruct:
- test_chat_non_streaming_image
- test_chat_streaming_image
- test_chat_multi_turn_multiple_images


@@ -12,3 +12,4 @@ test_exclusions:
accounts/fireworks/models/llama-v3p3-70b-instruct:
- test_chat_non_streaming_image
- test_chat_streaming_image
- test_chat_multi_turn_multiple_images


@@ -12,3 +12,4 @@ test_exclusions:
groq/llama-3.3-70b-versatile:
- test_chat_non_streaming_image
- test_chat_streaming_image
- test_chat_multi_turn_multiple_images


@@ -12,3 +12,4 @@ test_exclusions:
llama-3.3-70b-versatile:
- test_chat_non_streaming_image
- test_chat_streaming_image
- test_chat_multi_turn_multiple_images


@@ -0,0 +1,8 @@
# LLAMA_STACK_PORT=5002 llama stack run meta-reference-gpu --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct --env INFERENCE_CHECKPOINT_DIR=<path_to_ckpt>
base_url: http://localhost:5002/v1/openai/v1
api_key_var: foo
models:
- meta-llama/Llama-4-Scout-17B-16E-Instruct
model_display_names:
meta-llama/Llama-4-Scout-17B-16E-Instruct: Llama-4-Scout-Instruct
test_exclusions: {}


@@ -12,3 +12,4 @@ test_exclusions:
together/meta-llama/Llama-3.3-70B-Instruct-Turbo:
- test_chat_non_streaming_image
- test_chat_streaming_image
- test_chat_multi_turn_multiple_images


@@ -12,3 +12,4 @@ test_exclusions:
meta-llama/Llama-3.3-70B-Instruct-Turbo:
- test_chat_non_streaming_image
- test_chat_streaming_image
- test_chat_multi_turn_multiple_images


@@ -3,14 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "pytest-json-report",
# "pyyaml",
# ]
# ///
"""
Test Report Generator
@@ -67,16 +59,11 @@ RESULTS_DIR.mkdir(exist_ok=True)
# Maximum number of test result files to keep per provider
MAX_RESULTS_PER_PROVIDER = 1
PROVIDER_ORDER = [
DEFAULT_PROVIDERS = [
"meta_reference",
"together",
"fireworks",
"groq",
"cerebras",
"openai",
"together-llama-stack",
"fireworks-llama-stack",
"groq-llama-stack",
"openai-llama-stack",
]
VERIFICATION_CONFIG = _load_all_verification_configs()
@@ -142,6 +129,14 @@ def run_tests(provider, keyword=None):
return None
def run_multiple_tests(providers_to_run: list[str], keyword: str | None):
"""Runs tests for a list of providers."""
print(f"Running tests for providers: {', '.join(providers_to_run)}")
for provider in providers_to_run:
run_tests(provider.strip(), keyword=keyword)
print("Finished running tests.")
def parse_results(
result_file,
) -> Tuple[DefaultDict[str, DefaultDict[str, Dict[str, bool]]], DefaultDict[str, Set[str]], Set[str], str]:
@@ -250,20 +245,6 @@ def parse_results(
return parsed_results, providers_in_file, tests_in_file, run_timestamp_str
def get_all_result_files_by_provider():
"""Get all test result files, keyed by provider."""
provider_results = {}
result_files = list(RESULTS_DIR.glob("*.json"))
for file in result_files:
provider = file.stem
if provider:
provider_results[provider] = file
return provider_results
def generate_report(
results_dict: Dict[str, Any],
providers: Dict[str, Set[str]],
@@ -276,6 +257,7 @@ def generate_report(
Args:
results_dict: Aggregated results [provider][model][test_name] -> status.
providers: Dict of all providers and their models {provider: {models}}.
The order of keys in this dict determines the report order.
all_tests: Set of all test names found.
provider_timestamps: Dict of provider to timestamp when tests were run
output_file: Optional path to save the report.
@@ -353,22 +335,17 @@ def generate_report(
passed_tests += 1
provider_totals[provider] = (provider_passed, provider_total)
# Add summary table (use passed-in providers dict)
# Add summary table (use the order from the providers dict keys)
report.append("| Provider | Pass Rate | Tests Passed | Total Tests |")
report.append("| --- | --- | --- | --- |")
for provider in [p for p in PROVIDER_ORDER if p in providers]: # Check against keys of passed-in dict
passed, total = provider_totals.get(provider, (0, 0))
pass_rate = f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
report.append(f"| {provider.capitalize()} | {pass_rate} | {passed} | {total} |")
for provider in [p for p in providers if p not in PROVIDER_ORDER]: # Check against keys of passed-in dict
# Iterate through providers in the order they appear in the input dict
for provider in providers_sorted.keys():
passed, total = provider_totals.get(provider, (0, 0))
pass_rate = f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
report.append(f"| {provider.capitalize()} | {pass_rate} | {passed} | {total} |")
report.append("\n")
for provider in sorted(
providers_sorted.keys(), key=lambda p: (PROVIDER_ORDER.index(p) if p in PROVIDER_ORDER else float("inf"), p)
):
for provider in providers_sorted.keys():
provider_models = providers_sorted[provider] # Use sorted models
if not provider_models:
continue
@@ -461,60 +438,62 @@ def main():
"--providers",
type=str,
nargs="+",
help="Specify providers to test (comma-separated or space-separated, default: all)",
help="Specify providers to include/test (comma-separated or space-separated, default: uses DEFAULT_PROVIDERS)",
)
parser.add_argument("--output", type=str, help="Output file location (default: tests/verifications/REPORT.md)")
parser.add_argument("--k", type=str, help="Keyword expression to filter tests (passed to pytest -k)")
args = parser.parse_args()
all_results = {}
# Initialize collections to aggregate results in main
aggregated_providers = defaultdict(set)
final_providers_order = {} # Dictionary to store results, preserving processing order
aggregated_tests = set()
provider_timestamps = {}
if args.run_tests:
# Get list of available providers from command line or use detected providers
if args.providers:
# Handle both comma-separated and space-separated lists
test_providers = []
for provider_arg in args.providers:
# Split by comma if commas are present
if "," in provider_arg:
test_providers.extend(provider_arg.split(","))
else:
test_providers.append(provider_arg)
else:
# Default providers to test
test_providers = PROVIDER_ORDER
for provider in test_providers:
provider = provider.strip() # Remove any whitespace
result_file = run_tests(provider, keyword=args.k)
if result_file:
# Parse and aggregate results
parsed_results, providers_in_file, tests_in_file, run_timestamp = parse_results(result_file)
all_results.update(parsed_results)
for prov, models in providers_in_file.items():
aggregated_providers[prov].update(models)
if run_timestamp:
provider_timestamps[prov] = run_timestamp
aggregated_tests.update(tests_in_file)
# 1. Determine the desired list and order of providers
if args.providers:
desired_providers = []
for provider_arg in args.providers:
desired_providers.extend([p.strip() for p in provider_arg.split(",")])
else:
# Use existing results
provider_result_files = get_all_result_files_by_provider()
desired_providers = DEFAULT_PROVIDERS # Use default order/list
for result_file in provider_result_files.values():
# Parse and aggregate results
parsed_results, providers_in_file, tests_in_file, run_timestamp = parse_results(result_file)
all_results.update(parsed_results)
for prov, models in providers_in_file.items():
aggregated_providers[prov].update(models)
if run_timestamp:
provider_timestamps[prov] = run_timestamp
aggregated_tests.update(tests_in_file)
# 2. Run tests if requested (using the desired provider list)
if args.run_tests:
run_multiple_tests(desired_providers, args.k)
generate_report(all_results, aggregated_providers, aggregated_tests, provider_timestamps, args.output)
for provider in desired_providers:
# Construct the expected result file path directly
result_file = RESULTS_DIR / f"{provider}.json"
if result_file.exists(): # Check if the specific file exists
print(f"Loading results for {provider} from {result_file}")
try:
parsed_data = parse_results(result_file)
parsed_results, providers_in_file, tests_in_file, run_timestamp = parsed_data
all_results.update(parsed_results)
aggregated_tests.update(tests_in_file)
# Add models for this provider, ensuring it's added in the correct report order
if provider in providers_in_file:
if provider not in final_providers_order:
final_providers_order[provider] = set()
final_providers_order[provider].update(providers_in_file[provider])
if run_timestamp != "Unknown":
provider_timestamps[provider] = run_timestamp
else:
print(
f"Warning: Provider '{provider}' found in desired list but not within its result file data ({result_file})."
)
except Exception as e:
print(f"Error parsing results for provider {provider} from {result_file}: {e}")
else:
# Only print warning if we expected results (i.e., provider was in the desired list)
print(f"Result file for desired provider '{provider}' not found at {result_file}. Skipping.")
# 5. Generate the report using the filtered & ordered results
print(f"Final Provider Order for Report: {list(final_providers_order.keys())}")
generate_report(all_results, final_providers_order, aggregated_tests, provider_timestamps, args.output)
if __name__ == "__main__":

Three binary image files added (108 KiB, 148 KiB, and 139 KiB); contents not shown.


@@ -15,6 +15,52 @@ test_chat_basic:
S?
role: user
output: Saturn
test_chat_input_validation:
test_name: test_chat_input_validation
test_params:
case:
- case_id: "messages_missing"
input:
messages: []
output:
error:
status_code: 400
- case_id: "messages_role_invalid"
input:
messages:
- content: Which planet do humans live on?
role: fake_role
output:
error:
status_code: 400
- case_id: "tool_choice_invalid"
input:
messages:
- content: Which planet do humans live on?
role: user
tool_choice: invalid
output:
error:
status_code: 400
- case_id: "tool_choice_no_tools"
input:
messages:
- content: Which planet do humans live on?
role: user
tool_choice: required
output:
error:
status_code: 400
- case_id: "tools_type_invalid"
input:
messages:
- content: Which planet do humans live on?
role: user
tools:
- type: invalid
output:
error:
status_code: 400
test_chat_image:
test_name: test_chat_image
test_params:


@@ -4,19 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import copy
import json
import re
from pathlib import Path
from typing import Any
import pytest
from openai import APIError
from pydantic import BaseModel
from tests.verifications.openai_api.fixtures.fixtures import _load_all_verification_configs
from tests.verifications.openai_api.fixtures.fixtures import (
_load_all_verification_configs,
)
from tests.verifications.openai_api.fixtures.load import load_test_cases
chat_completion_test_cases = load_test_cases("chat_completion")
THIS_DIR = Path(__file__).parent
def case_id_generator(case):
"""Generate a test ID from the case's 'case_id' field, or use a default."""
@@ -69,6 +76,21 @@ def get_base_test_name(request):
return request.node.originalname
@pytest.fixture
def multi_image_data():
files = [
THIS_DIR / "fixtures/images/vision_test_1.jpg",
THIS_DIR / "fixtures/images/vision_test_2.jpg",
THIS_DIR / "fixtures/images/vision_test_3.jpg",
]
encoded_files = []
for file in files:
with open(file, "rb") as image_file:
base64_data = base64.b64encode(image_file.read()).decode("utf-8")
encoded_files.append(f"data:image/jpeg;base64,{base64_data}")
return encoded_files
# --- Test Functions ---
@@ -115,6 +137,50 @@ def test_chat_streaming_basic(request, openai_client, model, provider, verificat
assert case["output"].lower() in content.lower()
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
ids=case_id_generator,
)
def test_chat_non_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
with pytest.raises(APIError) as e:
openai_client.chat.completions.create(
model=model,
messages=case["input"]["messages"],
stream=False,
tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
tools=case["input"]["tools"] if "tools" in case["input"] else None,
)
assert case["output"]["error"]["status_code"] == e.value.status_code
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
ids=case_id_generator,
)
def test_chat_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
with pytest.raises(APIError) as e:
response = openai_client.chat.completions.create(
model=model,
messages=case["input"]["messages"],
stream=True,
tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
tools=case["input"]["tools"] if "tools" in case["input"] else None,
)
for _chunk in response:
pass
assert str(case["output"]["error"]["status_code"]) in e.value.message
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_chat_image"]["test_params"]["case"],
@@ -272,7 +338,6 @@ def test_chat_non_streaming_tool_choice_required(request, openai_client, model,
tool_choice="required", # Force tool call
stream=False,
)
print(response)
assert response.choices[0].message.role == "assistant"
assert len(response.choices[0].message.tool_calls) > 0, "Expected tool call when tool_choice='required'"
@@ -532,6 +597,86 @@ def test_chat_streaming_multi_turn_tool_calling(request, openai_client, model, p
)
@pytest.mark.parametrize("stream", [False, True], ids=["stream=False", "stream=True"])
def test_chat_multi_turn_multiple_images(
request, openai_client, model, provider, verification_config, multi_image_data, stream
):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
messages_turn1 = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": multi_image_data[0],
},
},
{
"type": "image_url",
"image_url": {
"url": multi_image_data[1],
},
},
{
"type": "text",
"text": "What furniture is in the first image that is not in the second image?",
},
],
},
]
# First API call
response1 = openai_client.chat.completions.create(
model=model,
messages=messages_turn1,
stream=stream,
)
if stream:
message_content1 = ""
for chunk in response1:
message_content1 += chunk.choices[0].delta.content or ""
else:
message_content1 = response1.choices[0].message.content
assert len(message_content1) > 0
assert any(expected in message_content1.lower().strip() for expected in {"chair", "table"}), message_content1
# Prepare messages for the second turn
messages_turn2 = messages_turn1 + [
{"role": "assistant", "content": message_content1},
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": multi_image_data[2],
},
},
{"type": "text", "text": "What is in this image that is also in the first image?"},
],
},
]
# Second API call
response2 = openai_client.chat.completions.create(
model=model,
messages=messages_turn2,
stream=stream,
)
if stream:
message_content2 = ""
for chunk in response2:
message_content2 += chunk.choices[0].delta.content or ""
else:
message_content2 = response2.choices[0].message.content
assert len(message_content2) > 0
assert any(expected in message_content2.lower().strip() for expected in {"bed"}), message_content2
# --- Helper functions (structured output validation) ---

Four additional file diffs are suppressed (two because the files are too large, two because one or more lines are too long).