forked from phoenix-oss/llama-stack-mirror
feat(verification): various improvements (#1921)
# What does this PR do? - provider and their models now live in config.yaml - better distinguish different cases within a test - add model key to surface provider's model_id - include example command to rerun single test case ## Test Plan <img width="1173" alt="image" src="https://github.com/user-attachments/assets/b414baf0-c768-451f-8c3b-c2905cf36fac" />
This commit is contained in:
parent
09a83b1ec1
commit
14146e4b3f
22 changed files with 4449 additions and 8810 deletions
5
tests/verifications/openai_api/__init__.py
Normal file
5
tests/verifications/openai_api/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
5
tests/verifications/openai_api/fixtures/__init__.py
Normal file
5
tests/verifications/openai_api/fixtures/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
105
tests/verifications/openai_api/fixtures/fixtures.py
Normal file
105
tests/verifications/openai_api/fixtures/fixtures.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
# --- Helper Function to Load Config ---
|
||||
def _load_all_verification_configs():
|
||||
"""Load and aggregate verification configs from the conf/ directory."""
|
||||
# Note: Path is relative to *this* file (fixtures.py)
|
||||
conf_dir = Path(__file__).parent.parent.parent / "conf"
|
||||
if not conf_dir.is_dir():
|
||||
# Use pytest.fail if called during test collection, otherwise raise error
|
||||
# For simplicity here, we'll raise an error, assuming direct calls
|
||||
# are less likely or can handle it.
|
||||
raise FileNotFoundError(f"Verification config directory not found at {conf_dir}")
|
||||
|
||||
all_provider_configs = {}
|
||||
yaml_files = list(conf_dir.glob("*.yaml"))
|
||||
if not yaml_files:
|
||||
raise FileNotFoundError(f"No YAML configuration files found in {conf_dir}")
|
||||
|
||||
for config_path in yaml_files:
|
||||
provider_name = config_path.stem
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
provider_config = yaml.safe_load(f)
|
||||
if provider_config:
|
||||
all_provider_configs[provider_name] = provider_config
|
||||
else:
|
||||
# Log warning if possible, or just skip empty files silently
|
||||
print(f"Warning: Config file {config_path} is empty or invalid.")
|
||||
except Exception as e:
|
||||
raise IOError(f"Error loading config file {config_path}: {e}") from e
|
||||
|
||||
return {"providers": all_provider_configs}
|
||||
|
||||
|
||||
# --- End Helper Function ---
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def verification_config():
|
||||
"""Pytest fixture to provide the loaded verification config."""
|
||||
try:
|
||||
return _load_all_verification_configs()
|
||||
except (FileNotFoundError, IOError) as e:
|
||||
pytest.fail(str(e)) # Fail test collection if config loading fails
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider(request, verification_config):
|
||||
provider = request.config.getoption("--provider")
|
||||
base_url = request.config.getoption("--base-url")
|
||||
|
||||
if provider and base_url and verification_config["providers"][provider]["base_url"] != base_url:
|
||||
raise ValueError(f"Provider {provider} is not supported for base URL {base_url}")
|
||||
|
||||
if not provider:
|
||||
if not base_url:
|
||||
raise ValueError("Provider and base URL are not provided")
|
||||
for provider, metadata in verification_config["providers"].items():
|
||||
if metadata["base_url"] == base_url:
|
||||
provider = provider
|
||||
break
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_url(request, provider, verification_config):
|
||||
return request.config.getoption("--base-url") or verification_config["providers"][provider]["base_url"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key(request, provider, verification_config):
|
||||
provider_conf = verification_config.get("providers", {}).get(provider, {})
|
||||
api_key_env_var = provider_conf.get("api_key_var")
|
||||
|
||||
key_from_option = request.config.getoption("--api-key")
|
||||
key_from_env = os.getenv(api_key_env_var) if api_key_env_var else None
|
||||
|
||||
final_key = key_from_option or key_from_env
|
||||
return final_key
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_mapping(provider, providers_model_mapping):
|
||||
return providers_model_mapping[provider]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_client(base_url, api_key):
|
||||
return OpenAI(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
)
|
16
tests/verifications/openai_api/fixtures/load.py
Normal file
16
tests/verifications/openai_api/fixtures/load.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def load_test_cases(name: str):
|
||||
fixture_dir = Path(__file__).parent / "test_cases"
|
||||
yaml_path = fixture_dir / f"{name}.yaml"
|
||||
with open(yaml_path, "r") as f:
|
||||
return yaml.safe_load(f)
|
|
@ -0,0 +1,133 @@
|
|||
test_chat_basic:
|
||||
test_name: test_chat_basic
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "earth"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: user
|
||||
output: Earth
|
||||
- case_id: "saturn"
|
||||
input:
|
||||
messages:
|
||||
- content: Which planet has rings around it with a name starting with letter
|
||||
S?
|
||||
role: user
|
||||
output: Saturn
|
||||
test_chat_image:
|
||||
test_name: test_chat_image
|
||||
test_params:
|
||||
case:
|
||||
- input:
|
||||
messages:
|
||||
- content:
|
||||
- text: What is in this image?
|
||||
type: text
|
||||
- image_url:
|
||||
url: https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg
|
||||
type: image_url
|
||||
role: user
|
||||
output: llama
|
||||
test_chat_structured_output:
|
||||
test_name: test_chat_structured_output
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "calendar"
|
||||
input:
|
||||
messages:
|
||||
- content: Extract the event information.
|
||||
role: system
|
||||
- content: Alice and Bob are going to a science fair on Friday.
|
||||
role: user
|
||||
response_format:
|
||||
json_schema:
|
||||
name: calendar_event
|
||||
schema:
|
||||
properties:
|
||||
date:
|
||||
title: Date
|
||||
type: string
|
||||
name:
|
||||
title: Name
|
||||
type: string
|
||||
participants:
|
||||
items:
|
||||
type: string
|
||||
title: Participants
|
||||
type: array
|
||||
required:
|
||||
- name
|
||||
- date
|
||||
- participants
|
||||
title: CalendarEvent
|
||||
type: object
|
||||
type: json_schema
|
||||
output: valid_calendar_event
|
||||
- case_id: "math"
|
||||
input:
|
||||
messages:
|
||||
- content: You are a helpful math tutor. Guide the user through the solution
|
||||
step by step.
|
||||
role: system
|
||||
- content: how can I solve 8x + 7 = -23
|
||||
role: user
|
||||
response_format:
|
||||
json_schema:
|
||||
name: math_reasoning
|
||||
schema:
|
||||
$defs:
|
||||
Step:
|
||||
properties:
|
||||
explanation:
|
||||
title: Explanation
|
||||
type: string
|
||||
output:
|
||||
title: Output
|
||||
type: string
|
||||
required:
|
||||
- explanation
|
||||
- output
|
||||
title: Step
|
||||
type: object
|
||||
properties:
|
||||
final_answer:
|
||||
title: Final Answer
|
||||
type: string
|
||||
steps:
|
||||
items:
|
||||
$ref: '#/$defs/Step'
|
||||
title: Steps
|
||||
type: array
|
||||
required:
|
||||
- steps
|
||||
- final_answer
|
||||
title: MathReasoning
|
||||
type: object
|
||||
type: json_schema
|
||||
output: valid_math_reasoning
|
||||
test_tool_calling:
|
||||
test_name: test_tool_calling
|
||||
test_params:
|
||||
case:
|
||||
- input:
|
||||
messages:
|
||||
- content: You are a helpful assistant that can use tools to get information.
|
||||
role: system
|
||||
- content: What's the weather like in San Francisco?
|
||||
role: user
|
||||
tools:
|
||||
- function:
|
||||
description: Get current temperature for a given location.
|
||||
name: get_weather
|
||||
parameters:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
location:
|
||||
description: "City and country e.g. Bogot\xE1, Colombia"
|
||||
type: string
|
||||
required:
|
||||
- location
|
||||
type: object
|
||||
type: function
|
||||
output: get_weather_tool_call
|
271
tests/verifications/openai_api/test_chat_completion.py
Normal file
271
tests/verifications/openai_api/test_chat_completion.py
Normal file
|
@ -0,0 +1,271 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tests.verifications.openai_api.fixtures.fixtures import _load_all_verification_configs
|
||||
from tests.verifications.openai_api.fixtures.load import load_test_cases
|
||||
|
||||
chat_completion_test_cases = load_test_cases("chat_completion")
|
||||
|
||||
|
||||
def case_id_generator(case):
|
||||
"""Generate a test ID from the case's 'case_id' field, or use a default."""
|
||||
case_id = case.get("case_id")
|
||||
if isinstance(case_id, (str, int)):
|
||||
return re.sub(r"\\W|^(?=\\d)", "_", str(case_id))
|
||||
return None
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
"""Dynamically parametrize tests based on the selected provider and config."""
|
||||
if "model" in metafunc.fixturenames:
|
||||
provider = metafunc.config.getoption("provider")
|
||||
if not provider:
|
||||
print("Warning: --provider not specified. Skipping model parametrization.")
|
||||
metafunc.parametrize("model", [])
|
||||
return
|
||||
|
||||
try:
|
||||
config_data = _load_all_verification_configs()
|
||||
except (FileNotFoundError, IOError) as e:
|
||||
print(f"ERROR loading verification configs: {e}")
|
||||
config_data = {"providers": {}}
|
||||
|
||||
provider_config = config_data.get("providers", {}).get(provider)
|
||||
if provider_config:
|
||||
models = provider_config.get("models", [])
|
||||
if models:
|
||||
metafunc.parametrize("model", models)
|
||||
else:
|
||||
print(f"Warning: No models found for provider '{provider}' in config.")
|
||||
metafunc.parametrize("model", []) # Parametrize empty if no models found
|
||||
else:
|
||||
print(f"Warning: Provider '{provider}' not found in config. No models parametrized.")
|
||||
metafunc.parametrize("model", []) # Parametrize empty if provider not found
|
||||
|
||||
|
||||
def should_skip_test(verification_config, provider, model, test_name_base):
|
||||
"""Check if a test should be skipped based on config exclusions."""
|
||||
provider_config = verification_config.get("providers", {}).get(provider)
|
||||
if not provider_config:
|
||||
return False # No config for provider, don't skip
|
||||
|
||||
exclusions = provider_config.get("test_exclusions", {}).get(model, [])
|
||||
return test_name_base in exclusions
|
||||
|
||||
|
||||
# Helper to get the base test name from the request object
|
||||
def get_base_test_name(request):
|
||||
return request.node.originalname
|
||||
|
||||
|
||||
# --- Test Functions ---
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_chat_basic"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_non_streaming_basic(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
stream=False,
|
||||
)
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert case["output"].lower() in response.choices[0].message.content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_chat_basic"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_streaming_basic(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
stream=True,
|
||||
)
|
||||
content = ""
|
||||
for chunk in response:
|
||||
content += chunk.choices[0].delta.content or ""
|
||||
|
||||
# TODO: add detailed type validation
|
||||
|
||||
assert case["output"].lower() in content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_chat_image"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_non_streaming_image(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
stream=False,
|
||||
)
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert case["output"].lower() in response.choices[0].message.content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_chat_image"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_streaming_image(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
stream=True,
|
||||
)
|
||||
content = ""
|
||||
for chunk in response:
|
||||
content += chunk.choices[0].delta.content or ""
|
||||
|
||||
# TODO: add detailed type validation
|
||||
|
||||
assert case["output"].lower() in content.lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_chat_structured_output"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_non_streaming_structured_output(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
response_format=case["input"]["response_format"],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
maybe_json_content = response.choices[0].message.content
|
||||
|
||||
validate_structured_output(maybe_json_content, case["output"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_chat_structured_output"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_streaming_structured_output(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
response_format=case["input"]["response_format"],
|
||||
stream=True,
|
||||
)
|
||||
maybe_json_content = ""
|
||||
for chunk in response:
|
||||
maybe_json_content += chunk.choices[0].delta.content or ""
|
||||
validate_structured_output(maybe_json_content, case["output"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_non_streaming_tool_calling(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
tools=case["input"]["tools"],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert len(response.choices[0].message.tool_calls) > 0
|
||||
assert case["output"] == "get_weather_tool_call"
|
||||
assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
|
||||
# TODO: add detailed type validation
|
||||
|
||||
|
||||
# --- Helper functions (structured output validation) ---
|
||||
|
||||
|
||||
def get_structured_output(maybe_json_content: str, schema_name: str) -> Any | None:
|
||||
if schema_name == "valid_calendar_event":
|
||||
|
||||
class CalendarEvent(BaseModel):
|
||||
name: str
|
||||
date: str
|
||||
participants: list[str]
|
||||
|
||||
try:
|
||||
calendar_event = CalendarEvent.model_validate_json(maybe_json_content)
|
||||
return calendar_event
|
||||
except Exception:
|
||||
return None
|
||||
elif schema_name == "valid_math_reasoning":
|
||||
|
||||
class Step(BaseModel):
|
||||
explanation: str
|
||||
output: str
|
||||
|
||||
class MathReasoning(BaseModel):
|
||||
steps: list[Step]
|
||||
final_answer: str
|
||||
|
||||
try:
|
||||
math_reasoning = MathReasoning.model_validate_json(maybe_json_content)
|
||||
return math_reasoning
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_structured_output(maybe_json_content: str, schema_name: str) -> None:
|
||||
structured_output = get_structured_output(maybe_json_content, schema_name)
|
||||
assert structured_output is not None
|
||||
if schema_name == "valid_calendar_event":
|
||||
assert structured_output.name is not None
|
||||
assert structured_output.date is not None
|
||||
assert len(structured_output.participants) == 2
|
||||
elif schema_name == "valid_math_reasoning":
|
||||
assert len(structured_output.final_answer) > 0
|
Loading…
Add table
Add a link
Reference in a new issue