forked from phoenix-oss/llama-stack-mirror
feat: adds test suite to verify provider's OAI compat endpoints (#1901)
# What does this PR do? ## Test Plan pytest verifications/openai/test_chat_completion.py --provider together
This commit is contained in:
parent
7d9adf22ad
commit
bcbc56baa2
14 changed files with 9404 additions and 0 deletions
65
tests/verifications/README.md
Normal file
65
tests/verifications/README.md
Normal file
|
@ -0,0 +1,65 @@
|
|||
# Llama Stack Verifications
|
||||
|
||||
Llama Stack Verifications provide standardized test suites to ensure API compatibility and behavior consistency across different LLM providers. These tests help verify that different models and providers implement the expected interfaces and behaviors correctly.
|
||||
|
||||
## Overview
|
||||
|
||||
This framework allows you to run the same set of verification tests against different LLM providers' OpenAI-compatible endpoints (Fireworks, Together, Groq, Cerebras, etc., and OpenAI itself) to ensure they meet the expected behavior and interface standards.
|
||||
|
||||
## Features
|
||||
|
||||
The verification suite currently tests:
|
||||
|
||||
- Basic chat completions (streaming and non-streaming)
|
||||
- Image input capabilities
|
||||
- Structured JSON output formatting
|
||||
- Tool calling functionality
|
||||
|
||||
## Running Tests
|
||||
|
||||
To run the verification tests, use pytest with the following parameters:
|
||||
|
||||
```bash
|
||||
cd llama-stack
|
||||
pytest tests/verifications/openai --provider=<provider-name>
|
||||
```
|
||||
|
||||
Example:
|
||||
```bash
|
||||
# Run all tests
|
||||
pytest tests/verifications/openai --provider=together
|
||||
|
||||
# Only run tests with Llama 4 models
|
||||
pytest tests/verifications/openai --provider=together -k 'Llama-4'
|
||||
```
|
||||
|
||||
### Parameters
|
||||
|
||||
- `--provider`: The provider name (openai, fireworks, together, groq, cerebras, etc.)
|
||||
- `--base-url`: The base URL for the provider's API (optional - defaults to the standard URL for the specified provider)
|
||||
- `--api-key`: Your API key for the provider (optional - defaults to the standard API_KEY name for the specified provider)
|
||||
|
||||
## Supported Providers
|
||||
|
||||
The verification suite currently supports:
|
||||
- OpenAI
|
||||
- Fireworks
|
||||
- Together
|
||||
- Groq
|
||||
- Cerebras
|
||||
|
||||
## Adding New Test Cases
|
||||
|
||||
To add new test cases, create appropriate JSON files in the `openai/fixtures/test_cases/` directory following the existing patterns.
|
||||
|
||||
|
||||
## Structure
|
||||
|
||||
- `__init__.py` - Marks the directory as a Python package
|
||||
- `conftest.py` - Global pytest configuration and fixtures
|
||||
- `openai/` - Tests specific to OpenAI-compatible APIs
|
||||
- `fixtures/` - Test fixtures and utilities
|
||||
- `fixtures.py` - Provider-specific fixtures
|
||||
- `load.py` - Utilities for loading test cases
|
||||
- `test_cases/` - JSON test case definitions
|
||||
- `test_chat_completion.py` - Tests for chat completion APIs
|
88
tests/verifications/REPORT.md
Normal file
88
tests/verifications/REPORT.md
Normal file
|
@ -0,0 +1,88 @@
|
|||
# Test Results Report
|
||||
|
||||
*Generated on: 2025-04-08 21:14:02*
|
||||
|
||||
*This report was generated by running `python tests/verifications/generate_report.py`*
|
||||
|
||||
## Legend
|
||||
|
||||
- ✅ - Test passed
|
||||
- ❌ - Test failed
|
||||
- ⚪ - Test not applicable or not run for this model
|
||||
|
||||
|
||||
## Summary
|
||||
|
||||
| Provider | Pass Rate | Tests Passed | Total Tests |
|
||||
| --- | --- | --- | --- |
|
||||
| Together | 67.7% | 21 | 31 |
|
||||
| Fireworks | 90.3% | 28 | 31 |
|
||||
| Openai | 100.0% | 22 | 22 |
|
||||
|
||||
|
||||
|
||||
## Together
|
||||
|
||||
*Tests run on: 2025-04-08 16:19:59*
|
||||
|
||||
```bash
|
||||
pytest tests/verifications/openai/test_chat_completion.py --provider=together -v
|
||||
```
|
||||
|
||||
| Test | Llama-3.3-70B-Instruct | Llama-4-Maverick-17B-128E-Instruct | Llama-4-Scout-17B-16E-Instruct |
|
||||
| --- | --- | --- | --- |
|
||||
| test_chat_non_streaming_basic (case 0) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_basic (case 1) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_image (case 0) | ⚪ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_structured_output (case 0) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_structured_output (case 1) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_tool_calling (case 0) | ✅ | ✅ | ✅ |
|
||||
| test_chat_streaming_basic (case 0) | ✅ | ❌ | ❌ |
|
||||
| test_chat_streaming_basic (case 1) | ✅ | ❌ | ❌ |
|
||||
| test_chat_streaming_image (case 0) | ⚪ | ❌ | ❌ |
|
||||
| test_chat_streaming_structured_output (case 0) | ✅ | ❌ | ❌ |
|
||||
| test_chat_streaming_structured_output (case 1) | ✅ | ❌ | ❌ |
|
||||
|
||||
## Fireworks
|
||||
|
||||
*Tests run on: 2025-04-08 16:18:28*
|
||||
|
||||
```bash
|
||||
pytest tests/verifications/openai/test_chat_completion.py --provider=fireworks -v
|
||||
```
|
||||
|
||||
| Test | Llama-3.3-70B-Instruct | Llama-4-Maverick-17B-128E-Instruct | Llama-4-Scout-17B-16E-Instruct |
|
||||
| --- | --- | --- | --- |
|
||||
| test_chat_non_streaming_basic (case 0) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_basic (case 1) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_image (case 0) | ⚪ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_structured_output (case 0) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_structured_output (case 1) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_tool_calling (case 0) | ✅ | ❌ | ❌ |
|
||||
| test_chat_streaming_basic (case 0) | ✅ | ✅ | ✅ |
|
||||
| test_chat_streaming_basic (case 1) | ✅ | ✅ | ✅ |
|
||||
| test_chat_streaming_image (case 0) | ⚪ | ✅ | ✅ |
|
||||
| test_chat_streaming_structured_output (case 0) | ✅ | ✅ | ✅ |
|
||||
| test_chat_streaming_structured_output (case 1) | ❌ | ✅ | ✅ |
|
||||
|
||||
## Openai
|
||||
|
||||
*Tests run on: 2025-04-08 16:22:02*
|
||||
|
||||
```bash
|
||||
pytest tests/verifications/openai/test_chat_completion.py --provider=openai -v
|
||||
```
|
||||
|
||||
| Test | gpt-4o | gpt-4o-mini |
|
||||
| --- | --- | --- |
|
||||
| test_chat_non_streaming_basic (case 0) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_basic (case 1) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_image (case 0) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_structured_output (case 0) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_structured_output (case 1) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_tool_calling (case 0) | ✅ | ✅ |
|
||||
| test_chat_streaming_basic (case 0) | ✅ | ✅ |
|
||||
| test_chat_streaming_basic (case 1) | ✅ | ✅ |
|
||||
| test_chat_streaming_image (case 0) | ✅ | ✅ |
|
||||
| test_chat_streaming_structured_output (case 0) | ✅ | ✅ |
|
||||
| test_chat_streaming_structured_output (case 1) | ✅ | ✅ |
|
5
tests/verifications/__init__.py
Normal file
5
tests/verifications/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
28
tests/verifications/conftest.py
Normal file
28
tests/verifications/conftest.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--base-url",
|
||||
action="store",
|
||||
help="Base URL for OpenAI compatible API",
|
||||
)
|
||||
parser.addoption(
|
||||
"--api-key",
|
||||
action="store",
|
||||
help="API key",
|
||||
)
|
||||
parser.addoption(
|
||||
"--provider",
|
||||
action="store",
|
||||
help="Provider to use for testing",
|
||||
)
|
||||
|
||||
|
||||
pytest_plugins = [
|
||||
"tests.verifications.openai.fixtures.fixtures",
|
||||
]
|
485
tests/verifications/generate_report.py
Executable file
485
tests/verifications/generate_report.py
Executable file
|
@ -0,0 +1,485 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Test Report Generator
|
||||
|
||||
Requirements:
|
||||
pip install pytest-json-report
|
||||
|
||||
Usage:
|
||||
# Generate a report using existing test results
|
||||
python tests/verifications/generate_report.py
|
||||
|
||||
# Run tests and generate a report
|
||||
python tests/verifications/generate_report.py --run-tests
|
||||
|
||||
# Run tests for specific providers
|
||||
python tests/verifications/generate_report.py --run-tests --providers fireworks openai
|
||||
|
||||
# Save the report to a custom location
|
||||
python tests/verifications/generate_report.py --output custom_report.md
|
||||
|
||||
# Clean up old test result files
|
||||
python tests/verifications/generate_report.py --cleanup
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
# Define the root directory for test results
|
||||
RESULTS_DIR = Path(__file__).parent / "test_results"
|
||||
RESULTS_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Maximum number of test result files to keep per provider
|
||||
MAX_RESULTS_PER_PROVIDER = 1
|
||||
|
||||
# Custom order of providers
|
||||
PROVIDER_ORDER = ["together", "fireworks", "groq", "cerebras", "openai"]
|
||||
|
||||
# Dictionary to store providers and their models (will be populated dynamically)
|
||||
PROVIDERS = defaultdict(set)
|
||||
|
||||
# Tests will be dynamically extracted from results
|
||||
ALL_TESTS = set()
|
||||
|
||||
|
||||
def run_tests(provider):
|
||||
"""Run pytest for a specific provider and save results"""
|
||||
print(f"Running tests for provider: {provider}")
|
||||
|
||||
timestamp = int(time.time())
|
||||
result_file = RESULTS_DIR / f"{provider}_{timestamp}.json"
|
||||
temp_json_file = RESULTS_DIR / f"temp_{provider}_{timestamp}.json"
|
||||
|
||||
# Run pytest with JSON output
|
||||
cmd = [
|
||||
"python",
|
||||
"-m",
|
||||
"pytest",
|
||||
"tests/verifications/openai/test_chat_completion.py",
|
||||
f"--provider={provider}",
|
||||
"-v",
|
||||
"--json-report",
|
||||
f"--json-report-file={temp_json_file}",
|
||||
]
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
print(f"Pytest exit code: {result.returncode}")
|
||||
|
||||
# Check if the JSON file was created
|
||||
if temp_json_file.exists():
|
||||
# Read the JSON file and save it to our results format
|
||||
with open(temp_json_file, "r") as f:
|
||||
test_results = json.load(f)
|
||||
|
||||
# Save results to our own format with a trailing newline
|
||||
with open(result_file, "w") as f:
|
||||
json.dump(test_results, f, indent=2)
|
||||
f.write("\n") # Add a trailing newline for precommit
|
||||
|
||||
# Clean up temp file
|
||||
temp_json_file.unlink()
|
||||
|
||||
print(f"Test results saved to {result_file}")
|
||||
return result_file
|
||||
else:
|
||||
print(f"Error: JSON report file not created for {provider}")
|
||||
print(f"Command stdout: {result.stdout}")
|
||||
print(f"Command stderr: {result.stderr}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error running tests for {provider}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def parse_results(result_file):
|
||||
"""Parse the test results file and extract pass/fail by model and test"""
|
||||
if not os.path.exists(result_file):
|
||||
print(f"Results file does not exist: {result_file}")
|
||||
return {}
|
||||
|
||||
with open(result_file, "r") as f:
|
||||
results = json.load(f)
|
||||
|
||||
# Initialize results dictionary
|
||||
parsed_results = defaultdict(lambda: defaultdict(dict))
|
||||
provider = os.path.basename(result_file).split("_")[0]
|
||||
|
||||
# Debug: Print summary of test results
|
||||
print(f"Test results summary for {provider}:")
|
||||
print(f"Total tests: {results.get('summary', {}).get('total', 0)}")
|
||||
print(f"Passed: {results.get('summary', {}).get('passed', 0)}")
|
||||
print(f"Failed: {results.get('summary', {}).get('failed', 0)}")
|
||||
print(f"Error: {results.get('summary', {}).get('error', 0)}")
|
||||
print(f"Skipped: {results.get('summary', {}).get('skipped', 0)}")
|
||||
|
||||
# Extract test results
|
||||
if "tests" not in results or not results["tests"]:
|
||||
print(f"No test results found in {result_file}")
|
||||
return parsed_results
|
||||
|
||||
# Map for normalizing model names
|
||||
model_name_map = {
|
||||
"Llama-3.3-8B-Instruct": "Llama-3.3-8B-Instruct",
|
||||
"Llama-3.3-70B-Instruct": "Llama-3.3-70B-Instruct",
|
||||
"Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
|
||||
"Llama-4-Scout-17B-16E": "Llama-4-Scout-17B-16E-Instruct",
|
||||
"Llama-4-Scout-17B-16E-Instruct": "Llama-4-Scout-17B-16E-Instruct",
|
||||
"Llama-4-Maverick-17B-128E": "Llama-4-Maverick-17B-128E-Instruct",
|
||||
"Llama-4-Maverick-17B-128E-Instruct": "Llama-4-Maverick-17B-128E-Instruct",
|
||||
"gpt-4o": "gpt-4o",
|
||||
"gpt-4o-mini": "gpt-4o-mini",
|
||||
}
|
||||
|
||||
# Keep track of all models found for this provider
|
||||
provider_models = set()
|
||||
|
||||
# Track all unique test cases for each base test
|
||||
test_case_counts = defaultdict(int)
|
||||
|
||||
# First pass: count the number of cases for each test
|
||||
for test in results["tests"]:
|
||||
test_id = test.get("nodeid", "")
|
||||
|
||||
if "call" in test:
|
||||
test_name = test_id.split("::")[1].split("[")[0]
|
||||
input_output_match = re.search(r"\[input_output(\d+)-", test_id)
|
||||
if input_output_match:
|
||||
test_case_counts[test_name] += 1
|
||||
|
||||
# Second pass: process the tests with case numbers only for tests with multiple cases
|
||||
for test in results["tests"]:
|
||||
test_id = test.get("nodeid", "")
|
||||
outcome = test.get("outcome", "")
|
||||
|
||||
# Only process tests that have been executed (not setup errors)
|
||||
if "call" in test:
|
||||
# Regular test that actually ran
|
||||
test_name = test_id.split("::")[1].split("[")[0]
|
||||
|
||||
# Extract input_output parameter to differentiate between test cases
|
||||
input_output_match = re.search(r"\[input_output(\d+)-", test_id)
|
||||
input_output_index = input_output_match.group(1) if input_output_match else ""
|
||||
|
||||
# Create a more detailed test name with case number only if there are multiple cases
|
||||
detailed_test_name = test_name
|
||||
if input_output_index and test_case_counts[test_name] > 1:
|
||||
detailed_test_name = f"{test_name} (case {input_output_index})"
|
||||
|
||||
# Track all unique test names
|
||||
ALL_TESTS.add(detailed_test_name)
|
||||
|
||||
# Extract model name from test_id using a more robust pattern
|
||||
model_match = re.search(r"\[input_output\d+-([^\]]+)\]", test_id)
|
||||
if model_match:
|
||||
raw_model = model_match.group(1)
|
||||
model = model_name_map.get(raw_model, raw_model)
|
||||
|
||||
# Add to set of known models for this provider
|
||||
provider_models.add(model)
|
||||
|
||||
# Also update the global PROVIDERS dictionary
|
||||
PROVIDERS[provider].add(model)
|
||||
|
||||
# Store the result
|
||||
if outcome == "passed":
|
||||
parsed_results[provider][model][detailed_test_name] = True
|
||||
else:
|
||||
parsed_results[provider][model][detailed_test_name] = False
|
||||
|
||||
print(f"Parsed test result: {detailed_test_name} for model {model}: {outcome}")
|
||||
elif outcome == "error" and "setup" in test and test.get("setup", {}).get("outcome") == "failed":
|
||||
# This is a setup failure, which likely means a configuration issue
|
||||
# Extract the base test name and model name
|
||||
parts = test_id.split("::")
|
||||
if len(parts) > 1:
|
||||
test_name = parts[1].split("[")[0]
|
||||
|
||||
# Extract input_output parameter to differentiate between test cases
|
||||
input_output_match = re.search(r"\[input_output(\d+)-", test_id)
|
||||
input_output_index = input_output_match.group(1) if input_output_match else ""
|
||||
|
||||
# Create a more detailed test name with case number only if there are multiple cases
|
||||
detailed_test_name = test_name
|
||||
if input_output_index and test_case_counts[test_name] > 1:
|
||||
detailed_test_name = f"{test_name} (case {input_output_index})"
|
||||
|
||||
if detailed_test_name in ALL_TESTS:
|
||||
# Use a more robust pattern for model extraction
|
||||
model_match = re.search(r"\[input_output\d+-([^\]]+)\]", test_id)
|
||||
if model_match:
|
||||
raw_model = model_match.group(1)
|
||||
model = model_name_map.get(raw_model, raw_model)
|
||||
|
||||
# Add to set of known models for this provider
|
||||
provider_models.add(model)
|
||||
|
||||
# Also update the global PROVIDERS dictionary
|
||||
PROVIDERS[provider].add(model)
|
||||
|
||||
# Mark setup failures as false (failed)
|
||||
parsed_results[provider][model][detailed_test_name] = False
|
||||
print(f"Parsed setup failure: {detailed_test_name} for model {model}")
|
||||
|
||||
# Debug: Print parsed results
|
||||
if not parsed_results[provider]:
|
||||
print(f"Warning: No test results parsed for provider {provider}")
|
||||
else:
|
||||
for model, tests in parsed_results[provider].items():
|
||||
print(f"Model {model}: {len(tests)} test results")
|
||||
|
||||
return parsed_results
|
||||
|
||||
|
||||
def cleanup_old_results():
|
||||
"""Clean up old test result files, keeping only the newest N per provider"""
|
||||
for provider in PROVIDERS.keys():
|
||||
# Get all result files for this provider
|
||||
provider_files = list(RESULTS_DIR.glob(f"{provider}_*.json"))
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
provider_files.sort(key=lambda x: int(x.stem.split("_")[1]), reverse=True)
|
||||
|
||||
# Remove old files beyond the max to keep
|
||||
if len(provider_files) > MAX_RESULTS_PER_PROVIDER:
|
||||
for old_file in provider_files[MAX_RESULTS_PER_PROVIDER:]:
|
||||
try:
|
||||
old_file.unlink()
|
||||
print(f"Removed old result file: {old_file}")
|
||||
except Exception as e:
|
||||
print(f"Error removing file {old_file}: {e}")
|
||||
|
||||
|
||||
def get_latest_results_by_provider():
|
||||
"""Get the latest test result file for each provider"""
|
||||
provider_results = {}
|
||||
|
||||
# Get all result files
|
||||
result_files = list(RESULTS_DIR.glob("*.json"))
|
||||
|
||||
# Extract all provider names from filenames
|
||||
all_providers = set()
|
||||
for file in result_files:
|
||||
# File format is provider_timestamp.json
|
||||
parts = file.stem.split("_")
|
||||
if len(parts) >= 2:
|
||||
all_providers.add(parts[0])
|
||||
|
||||
# Group by provider
|
||||
for provider in all_providers:
|
||||
provider_files = [f for f in result_files if f.name.startswith(f"{provider}_")]
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
provider_files.sort(key=lambda x: int(x.stem.split("_")[1]), reverse=True)
|
||||
|
||||
if provider_files:
|
||||
provider_results[provider] = provider_files[0]
|
||||
|
||||
return provider_results
|
||||
|
||||
|
||||
def generate_report(results_dict, output_file=None):
|
||||
"""Generate the markdown report"""
|
||||
if output_file is None:
|
||||
# Default to creating the report in the same directory as this script
|
||||
output_file = Path(__file__).parent / "REPORT.md"
|
||||
else:
|
||||
output_file = Path(output_file)
|
||||
|
||||
# Get the timestamp from result files
|
||||
provider_timestamps = {}
|
||||
provider_results = get_latest_results_by_provider()
|
||||
for provider, result_file in provider_results.items():
|
||||
# Extract timestamp from filename (format: provider_timestamp.json)
|
||||
try:
|
||||
timestamp_str = result_file.stem.split("_")[1]
|
||||
timestamp = int(timestamp_str)
|
||||
formatted_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
|
||||
provider_timestamps[provider] = formatted_time
|
||||
except (IndexError, ValueError):
|
||||
provider_timestamps[provider] = "Unknown"
|
||||
|
||||
# Convert provider model sets to sorted lists
|
||||
for provider in PROVIDERS:
|
||||
PROVIDERS[provider] = sorted(PROVIDERS[provider])
|
||||
|
||||
# Sort tests alphabetically
|
||||
sorted_tests = sorted(ALL_TESTS)
|
||||
|
||||
report = ["# Test Results Report\n"]
|
||||
report.append(f"*Generated on: {time.strftime('%Y-%m-%d %H:%M:%S')}*\n")
|
||||
report.append("*This report was generated by running `python tests/verifications/generate_report.py`*\n")
|
||||
|
||||
# Icons for pass/fail
|
||||
pass_icon = "✅"
|
||||
fail_icon = "❌"
|
||||
na_icon = "⚪"
|
||||
|
||||
# Add emoji legend
|
||||
report.append("## Legend\n")
|
||||
report.append(f"- {pass_icon} - Test passed")
|
||||
report.append(f"- {fail_icon} - Test failed")
|
||||
report.append(f"- {na_icon} - Test not applicable or not run for this model")
|
||||
report.append("\n")
|
||||
|
||||
# Add a summary section
|
||||
report.append("## Summary\n")
|
||||
|
||||
# Count total tests and passes
|
||||
total_tests = 0
|
||||
passed_tests = 0
|
||||
provider_totals = {}
|
||||
|
||||
# Prepare summary data
|
||||
for provider in PROVIDERS.keys():
|
||||
provider_passed = 0
|
||||
provider_total = 0
|
||||
|
||||
if provider in results_dict:
|
||||
provider_models = PROVIDERS[provider]
|
||||
for model in provider_models:
|
||||
if model in results_dict[provider]:
|
||||
model_results = results_dict[provider][model]
|
||||
for test in sorted_tests:
|
||||
if test in model_results:
|
||||
provider_total += 1
|
||||
total_tests += 1
|
||||
if model_results[test]:
|
||||
provider_passed += 1
|
||||
passed_tests += 1
|
||||
|
||||
provider_totals[provider] = (provider_passed, provider_total)
|
||||
|
||||
# Add summary table
|
||||
report.append("| Provider | Pass Rate | Tests Passed | Total Tests |")
|
||||
report.append("| --- | --- | --- | --- |")
|
||||
|
||||
# Use the custom order for summary table
|
||||
for provider in [p for p in PROVIDER_ORDER if p in PROVIDERS]:
|
||||
passed, total = provider_totals.get(provider, (0, 0))
|
||||
pass_rate = f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
|
||||
report.append(f"| {provider.capitalize()} | {pass_rate} | {passed} | {total} |")
|
||||
|
||||
# Add providers not in the custom order
|
||||
for provider in [p for p in PROVIDERS if p not in PROVIDER_ORDER]:
|
||||
passed, total = provider_totals.get(provider, (0, 0))
|
||||
pass_rate = f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
|
||||
report.append(f"| {provider.capitalize()} | {pass_rate} | {passed} | {total} |")
|
||||
|
||||
report.append("\n")
|
||||
|
||||
# Process each provider in the custom order, then any additional providers
|
||||
for provider in sorted(
|
||||
PROVIDERS.keys(), key=lambda p: (PROVIDER_ORDER.index(p) if p in PROVIDER_ORDER else float("inf"), p)
|
||||
):
|
||||
if not PROVIDERS[provider]:
|
||||
# Skip providers with no models
|
||||
continue
|
||||
|
||||
report.append(f"\n## {provider.capitalize()}\n")
|
||||
|
||||
# Add timestamp when test was run
|
||||
if provider in provider_timestamps:
|
||||
report.append(f"*Tests run on: {provider_timestamps[provider]}*\n")
|
||||
|
||||
# Add test command for reproducing results
|
||||
test_cmd = f"pytest tests/verifications/openai/test_chat_completion.py --provider={provider} -v"
|
||||
report.append(f"```bash\n{test_cmd}\n```\n")
|
||||
|
||||
# Get the relevant models for this provider
|
||||
provider_models = PROVIDERS[provider]
|
||||
|
||||
# Create table header with models as columns
|
||||
header = "| Test | " + " | ".join(provider_models) + " |"
|
||||
separator = "| --- | " + " | ".join(["---"] * len(provider_models)) + " |"
|
||||
|
||||
report.append(header)
|
||||
report.append(separator)
|
||||
|
||||
# Get results for this provider
|
||||
provider_results = results_dict.get(provider, {})
|
||||
|
||||
# Add rows for each test
|
||||
for test in sorted_tests:
|
||||
row = f"| {test} |"
|
||||
|
||||
# Add results for each model in this test
|
||||
for model in provider_models:
|
||||
if model in provider_results and test in provider_results[model]:
|
||||
result = pass_icon if provider_results[model][test] else fail_icon
|
||||
else:
|
||||
result = na_icon
|
||||
row += f" {result} |"
|
||||
|
||||
report.append(row)
|
||||
|
||||
# Write to file
|
||||
with open(output_file, "w") as f:
|
||||
f.write("\n".join(report))
|
||||
f.write("\n")
|
||||
|
||||
print(f"Report generated: {output_file}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Generate test report")
|
||||
parser.add_argument("--run-tests", action="store_true", help="Run tests before generating report")
|
||||
parser.add_argument(
|
||||
"--providers",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Specify providers to test (comma-separated or space-separated, default: all)",
|
||||
)
|
||||
parser.add_argument("--output", type=str, help="Output file location (default: tests/verifications/REPORT.md)")
|
||||
args = parser.parse_args()
|
||||
|
||||
all_results = {}
|
||||
|
||||
if args.run_tests:
|
||||
# Get list of available providers from command line or use detected providers
|
||||
if args.providers:
|
||||
# Handle both comma-separated and space-separated lists
|
||||
test_providers = []
|
||||
for provider_arg in args.providers:
|
||||
# Split by comma if commas are present
|
||||
if "," in provider_arg:
|
||||
test_providers.extend(provider_arg.split(","))
|
||||
else:
|
||||
test_providers.append(provider_arg)
|
||||
else:
|
||||
# Default providers to test
|
||||
test_providers = PROVIDER_ORDER
|
||||
|
||||
for provider in test_providers:
|
||||
provider = provider.strip() # Remove any whitespace
|
||||
result_file = run_tests(provider)
|
||||
if result_file:
|
||||
provider_results = parse_results(result_file)
|
||||
all_results.update(provider_results)
|
||||
else:
|
||||
# Use existing results
|
||||
provider_result_files = get_latest_results_by_provider()
|
||||
|
||||
for result_file in provider_result_files.values():
|
||||
provider_results = parse_results(result_file)
|
||||
all_results.update(provider_results)
|
||||
|
||||
# Generate the report
|
||||
generate_report(all_results, args.output)
|
||||
|
||||
cleanup_old_results()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
5
tests/verifications/openai/__init__.py
Normal file
5
tests/verifications/openai/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
5
tests/verifications/openai/fixtures/__init__.py
Normal file
5
tests/verifications/openai/fixtures/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
97
tests/verifications/openai/fixtures/fixtures.py
Normal file
97
tests/verifications/openai/fixtures/fixtures.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def providers_model_mapping():
|
||||
"""
|
||||
Mapping from model names used in test cases to provider's model names.
|
||||
"""
|
||||
return {
|
||||
"fireworks": {
|
||||
"Llama-3.3-70B-Instruct": "accounts/fireworks/models/llama-v3p1-70b-instruct",
|
||||
"Llama-3.2-11B-Vision-Instruct": "accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
|
||||
"Llama-4-Scout-17B-16E-Instruct": "accounts/fireworks/models/llama4-scout-instruct-basic",
|
||||
"Llama-4-Maverick-17B-128E-Instruct": "accounts/fireworks/models/llama4-maverick-instruct-basic",
|
||||
},
|
||||
"together": {
|
||||
"Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
"Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
|
||||
"Llama-4-Scout-17B-16E-Instruct": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"Llama-4-Maverick-17B-128E-Instruct": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
},
|
||||
"groq": {
|
||||
"Llama-3.3-70B-Instruct": "llama-3.3-70b-versatile",
|
||||
"Llama-3.2-11B-Vision-Instruct": "llama-3.2-11b-vision-preview",
|
||||
"Llama-4-Scout-17B-16E-Instruct": "llama-4-scout-17b-16e-instruct",
|
||||
"Llama-4-Maverick-17B-128E-Instruct": "llama-4-maverick-17b-128e-instruct",
|
||||
},
|
||||
"cerebras": {
|
||||
"Llama-3.3-70B-Instruct": "llama-3.3-70b",
|
||||
},
|
||||
"openai": {
|
||||
"gpt-4o": "gpt-4o",
|
||||
"gpt-4o-mini": "gpt-4o-mini",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_metadata():
|
||||
return {
|
||||
"fireworks": ("https://api.fireworks.ai/inference/v1", "FIREWORKS_API_KEY"),
|
||||
"together": ("https://api.together.xyz/v1", "TOGETHER_API_KEY"),
|
||||
"groq": ("https://api.groq.com/openai/v1", "GROQ_API_KEY"),
|
||||
"cerebras": ("https://api.cerebras.ai/v1", "CEREBRAS_API_KEY"),
|
||||
"openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider(request, provider_metadata):
|
||||
provider = request.config.getoption("--provider")
|
||||
base_url = request.config.getoption("--base-url")
|
||||
|
||||
if provider and base_url and provider_metadata[provider][0] != base_url:
|
||||
raise ValueError(f"Provider {provider} is not supported for base URL {base_url}")
|
||||
|
||||
if not provider:
|
||||
if not base_url:
|
||||
raise ValueError("Provider and base URL are not provided")
|
||||
for provider, metadata in provider_metadata.items():
|
||||
if metadata[0] == base_url:
|
||||
provider = provider
|
||||
break
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_url(request, provider, provider_metadata):
|
||||
return request.config.getoption("--base-url") or provider_metadata[provider][0]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key(request, provider, provider_metadata):
|
||||
return request.config.getoption("--api-key") or os.getenv(provider_metadata[provider][1])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_mapping(provider, providers_model_mapping):
|
||||
return providers_model_mapping[provider]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_client(base_url, api_key):
|
||||
return OpenAI(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
)
|
16
tests/verifications/openai/fixtures/load.py
Normal file
16
tests/verifications/openai/fixtures/load.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def load_test_cases(name: str):
|
||||
fixture_dir = Path(__file__).parent / "test_cases"
|
||||
yaml_path = fixture_dir / f"{name}.yaml"
|
||||
with open(yaml_path, "r") as f:
|
||||
return yaml.safe_load(f)
|
|
@ -0,0 +1,162 @@
|
|||
test_chat_basic:
|
||||
test_name: test_chat_basic
|
||||
test_params:
|
||||
input_output:
|
||||
- input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: user
|
||||
output: Earth
|
||||
- input:
|
||||
messages:
|
||||
- content: Which planet has rings around it with a name starting with letter
|
||||
S?
|
||||
role: user
|
||||
output: Saturn
|
||||
model:
|
||||
- Llama-3.3-8B-Instruct
|
||||
- Llama-3.3-70B-Instruct
|
||||
- Llama-4-Scout-17B-16E
|
||||
- Llama-4-Scout-17B-16E-Instruct
|
||||
- Llama-4-Maverick-17B-128E
|
||||
- Llama-4-Maverick-17B-128E-Instruct
|
||||
- gpt-4o
|
||||
- gpt-4o-mini
|
||||
test_chat_image:
|
||||
test_name: test_chat_image
|
||||
test_params:
|
||||
input_output:
|
||||
- input:
|
||||
messages:
|
||||
- content:
|
||||
- text: What is in this image?
|
||||
type: text
|
||||
- image_url:
|
||||
url: https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg
|
||||
type: image_url
|
||||
role: user
|
||||
output: llama
|
||||
model:
|
||||
- Llama-4-Scout-17B-16E
|
||||
- Llama-4-Scout-17B-16E-Instruct
|
||||
- Llama-4-Maverick-17B-128E
|
||||
- Llama-4-Maverick-17B-128E-Instruct
|
||||
- gpt-4o
|
||||
- gpt-4o-mini
|
||||
test_chat_structured_output:
|
||||
test_name: test_chat_structured_output
|
||||
test_params:
|
||||
input_output:
|
||||
- input:
|
||||
messages:
|
||||
- content: Extract the event information.
|
||||
role: system
|
||||
- content: Alice and Bob are going to a science fair on Friday.
|
||||
role: user
|
||||
response_format:
|
||||
json_schema:
|
||||
name: calendar_event
|
||||
schema:
|
||||
properties:
|
||||
date:
|
||||
title: Date
|
||||
type: string
|
||||
name:
|
||||
title: Name
|
||||
type: string
|
||||
participants:
|
||||
items:
|
||||
type: string
|
||||
title: Participants
|
||||
type: array
|
||||
required:
|
||||
- name
|
||||
- date
|
||||
- participants
|
||||
title: CalendarEvent
|
||||
type: object
|
||||
type: json_schema
|
||||
output: valid_calendar_event
|
||||
- input:
|
||||
messages:
|
||||
- content: You are a helpful math tutor. Guide the user through the solution
|
||||
step by step.
|
||||
role: system
|
||||
- content: how can I solve 8x + 7 = -23
|
||||
role: user
|
||||
response_format:
|
||||
json_schema:
|
||||
name: math_reasoning
|
||||
schema:
|
||||
$defs:
|
||||
Step:
|
||||
properties:
|
||||
explanation:
|
||||
title: Explanation
|
||||
type: string
|
||||
output:
|
||||
title: Output
|
||||
type: string
|
||||
required:
|
||||
- explanation
|
||||
- output
|
||||
title: Step
|
||||
type: object
|
||||
properties:
|
||||
final_answer:
|
||||
title: Final Answer
|
||||
type: string
|
||||
steps:
|
||||
items:
|
||||
$ref: '#/$defs/Step'
|
||||
title: Steps
|
||||
type: array
|
||||
required:
|
||||
- steps
|
||||
- final_answer
|
||||
title: MathReasoning
|
||||
type: object
|
||||
type: json_schema
|
||||
output: valid_math_reasoning
|
||||
model:
|
||||
- Llama-3.3-8B-Instruct
|
||||
- Llama-3.3-70B-Instruct
|
||||
- Llama-4-Scout-17B-16E
|
||||
- Llama-4-Scout-17B-16E-Instruct
|
||||
- Llama-4-Maverick-17B-128E
|
||||
- Llama-4-Maverick-17B-128E-Instruct
|
||||
- gpt-4o
|
||||
- gpt-4o-mini
|
||||
test_tool_calling:
|
||||
test_name: test_tool_calling
|
||||
test_params:
|
||||
input_output:
|
||||
- input:
|
||||
messages:
|
||||
- content: You are a helpful assistant that can use tools to get information.
|
||||
role: system
|
||||
- content: What's the weather like in San Francisco?
|
||||
role: user
|
||||
tools:
|
||||
- function:
|
||||
description: Get current temperature for a given location.
|
||||
name: get_weather
|
||||
parameters:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
location:
|
||||
description: "City and country e.g. Bogot\xE1, Colombia"
|
||||
type: string
|
||||
required:
|
||||
- location
|
||||
type: object
|
||||
type: function
|
||||
output: get_weather_tool_call
|
||||
model:
|
||||
- Llama-3.3-70B-Instruct
|
||||
- Llama-4-Scout-17B-16E
|
||||
- Llama-4-Scout-17B-16E-Instruct
|
||||
- Llama-4-Maverick-17B-128E
|
||||
- Llama-4-Maverick-17B-128E-Instruct
|
||||
- gpt-4o
|
||||
- gpt-4o-mini
|
202
tests/verifications/openai/test_chat_completion.py
Normal file
202
tests/verifications/openai/test_chat_completion.py
Normal file
|
@ -0,0 +1,202 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tests.verifications.openai.fixtures.load import load_test_cases
|
||||
|
||||
chat_completion_test_cases = load_test_cases("chat_completion")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def correct_model_name(model, provider, providers_model_mapping):
|
||||
"""Return the provider-specific model name based on the generic model name."""
|
||||
mapping = providers_model_mapping[provider]
|
||||
if model not in mapping:
|
||||
pytest.skip(f"Provider {provider} does not support model {model}")
|
||||
return mapping[model]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_basic"]["test_params"]["model"])
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_chat_basic"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_non_streaming_basic(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
stream=False,
|
||||
)
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert input_output["output"].lower() in response.choices[0].message.content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_basic"]["test_params"]["model"])
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_chat_basic"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_streaming_basic(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
stream=True,
|
||||
)
|
||||
content = ""
|
||||
for chunk in response:
|
||||
content += chunk.choices[0].delta.content or ""
|
||||
|
||||
# TODO: add detailed type validation
|
||||
|
||||
assert input_output["output"].lower() in content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_image"]["test_params"]["model"])
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_chat_image"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_non_streaming_image(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
stream=False,
|
||||
)
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert input_output["output"].lower() in response.choices[0].message.content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_image"]["test_params"]["model"])
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_chat_image"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_streaming_image(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
stream=True,
|
||||
)
|
||||
content = ""
|
||||
for chunk in response:
|
||||
content += chunk.choices[0].delta.content or ""
|
||||
|
||||
# TODO: add detailed type validation
|
||||
|
||||
assert input_output["output"].lower() in content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
chat_completion_test_cases["test_chat_structured_output"]["test_params"]["model"],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_chat_structured_output"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_non_streaming_structured_output(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
response_format=input_output["input"]["response_format"],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
maybe_json_content = response.choices[0].message.content
|
||||
|
||||
validate_structured_output(maybe_json_content, input_output["output"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
chat_completion_test_cases["test_chat_structured_output"]["test_params"]["model"],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_chat_structured_output"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_streaming_structured_output(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
response_format=input_output["input"]["response_format"],
|
||||
stream=True,
|
||||
)
|
||||
maybe_json_content = ""
|
||||
for chunk in response:
|
||||
maybe_json_content += chunk.choices[0].delta.content or ""
|
||||
validate_structured_output(maybe_json_content, input_output["output"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
chat_completion_test_cases["test_tool_calling"]["test_params"]["model"],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_tool_calling"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_non_streaming_tool_calling(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
tools=input_output["input"]["tools"],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert len(response.choices[0].message.tool_calls) > 0
|
||||
assert input_output["output"] == "get_weather_tool_call"
|
||||
assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
|
||||
# TODO: add detailed type validation
|
||||
|
||||
|
||||
def get_structured_output(maybe_json_content: str, schema_name: str) -> Any | None:
|
||||
if schema_name == "valid_calendar_event":
|
||||
|
||||
class CalendarEvent(BaseModel):
|
||||
name: str
|
||||
date: str
|
||||
participants: list[str]
|
||||
|
||||
try:
|
||||
calendar_event = CalendarEvent.model_validate_json(maybe_json_content)
|
||||
return calendar_event
|
||||
except Exception:
|
||||
return None
|
||||
elif schema_name == "valid_math_reasoning":
|
||||
|
||||
class Step(BaseModel):
|
||||
explanation: str
|
||||
output: str
|
||||
|
||||
class MathReasoning(BaseModel):
|
||||
steps: list[Step]
|
||||
final_answer: str
|
||||
|
||||
try:
|
||||
math_reasoning = MathReasoning.model_validate_json(maybe_json_content)
|
||||
return math_reasoning
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_structured_output(maybe_json_content: str, schema_name: str) -> None:
|
||||
structured_output = get_structured_output(maybe_json_content, schema_name)
|
||||
assert structured_output is not None
|
||||
if schema_name == "valid_calendar_event":
|
||||
assert structured_output.name is not None
|
||||
assert structured_output.date is not None
|
||||
assert len(structured_output.participants) == 2
|
||||
elif schema_name == "valid_math_reasoning":
|
||||
assert len(structured_output.final_answer) > 0
|
2744
tests/verifications/test_results/fireworks_1744154308.json
Normal file
2744
tests/verifications/test_results/fireworks_1744154308.json
Normal file
File diff suppressed because it is too large
Load diff
2672
tests/verifications/test_results/openai_1744154522.json
Normal file
2672
tests/verifications/test_results/openai_1744154522.json
Normal file
File diff suppressed because it is too large
Load diff
2830
tests/verifications/test_results/together_1744154399.json
Normal file
2830
tests/verifications/test_results/together_1744154399.json
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue