forked from phoenix-oss/llama-stack-mirror
# What does this PR do? Generates a test report in Markdown that contains two main pieces of information: 1) a custom report on inference provider -> API / functionalities 2) [TO BE ADDED] a test log for easy debugging ## Test Plan For local testing, run the test script from the command line. A test report is then generated at tests/report.html `pytest /Users/sxyi/llama-stack/llama_stack/providers/tests/. --config=ci_test_config.yaml` See [gist](https://gist.github.com/sixianyi0721/a421fd3bc450b74354a1c2c7da483fa5) for the output MD file ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
200 lines
7.5 KiB
Python
200 lines
7.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from llama_models.datatypes import CoreModelId
|
|
from llama_models.sku_list import all_registered_models
|
|
from pytest import ExitCode
|
|
|
|
from pytest_html.basereport import _process_outcome
|
|
|
|
|
|
# Inference APIs summarized in the report. A collected test is attributed to
# an API when the API name occurs in the test function's name.
INFERENCE_APIS = ["chat_completion"]

# Functionalities the report is meant to cover.
# NOTE(review): not referenced by the code visible in this file — confirm usage.
FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]

# Provider name -> set of core model descriptors (``CoreModelId.<x>.value``)
# that the provider is reported as supporting.
SUPPORTED_MODELS = {
    "ollama": {
        CoreModelId.llama3_1_8b_instruct.value,
        CoreModelId.llama3_1_70b_instruct.value,
        CoreModelId.llama3_1_405b_instruct.value,
        CoreModelId.llama3_2_1b_instruct.value,
        CoreModelId.llama3_2_3b_instruct.value,
        CoreModelId.llama3_2_11b_vision_instruct.value,
        CoreModelId.llama3_2_90b_vision_instruct.value,
        CoreModelId.llama3_3_70b_instruct.value,
        CoreModelId.llama_guard_3_8b.value,
        CoreModelId.llama_guard_3_1b.value,
    },
    "fireworks": {
        CoreModelId.llama3_1_8b_instruct.value,
        CoreModelId.llama3_1_70b_instruct.value,
        CoreModelId.llama3_1_405b_instruct.value,
        CoreModelId.llama3_2_1b_instruct.value,
        CoreModelId.llama3_2_3b_instruct.value,
        CoreModelId.llama3_2_11b_vision_instruct.value,
        CoreModelId.llama3_2_90b_vision_instruct.value,
        CoreModelId.llama3_3_70b_instruct.value,
        CoreModelId.llama_guard_3_8b.value,
        CoreModelId.llama_guard_3_11b_vision.value,
    },
    "together": {
        CoreModelId.llama3_1_8b_instruct.value,
        CoreModelId.llama3_1_70b_instruct.value,
        CoreModelId.llama3_1_405b_instruct.value,
        CoreModelId.llama3_2_3b_instruct.value,
        CoreModelId.llama3_2_11b_vision_instruct.value,
        CoreModelId.llama3_2_90b_vision_instruct.value,
        CoreModelId.llama3_3_70b_instruct.value,
        CoreModelId.llama_guard_3_8b.value,
        CoreModelId.llama_guard_3_11b_vision.value,
    },
}
|
|
|
|
|
|
class Report:
    """Pytest plugin object that writes a Markdown summary of a test session.

    The generated file contains a table of which registered models each
    provider in ``SUPPORTED_MODELS`` supports, followed by one table per
    provider listing the collected inference tests and whether they passed.
    """

    def __init__(self, output_path):
        """Validate the report destination and create the empty collectors.

        Args:
            output_path: Where the report is written. The file name must end
                in ``.md`` or ``.markdown``.

        Raises:
            ValueError: If ``output_path`` does not name a Markdown file.
        """
        # Look only at the file name so that dots elsewhere in the path
        # (e.g. "./tests/report.md") are not mistaken for the extension.
        _, dot, extension = Path(output_path).name.rpartition(".")
        if not dot or extension not in ("md", "markdown"):
            raise ValueError(
                f"Invalid output file {output_path}. Markdown file is required"
            )
        self.output_path = output_path
        # nodeid -> {"outcome", "longrepr", "name"} for the recorded phase.
        self.test_data = defaultdict(dict)
        # "inference_stack" parameter value -> API name -> set of test nodeids.
        self.inference_tests = defaultdict(dict)

    @pytest.hookimpl
    def pytest_runtest_logreport(self, report):
        """Record a test's outcome, letting a non-passing phase replace a pass.

        This hook is called in several phases, including setup, call and
        teardown. The first phase seen for a test is always stored. A later
        phase replaces the stored entry only when it did not pass and its
        outcome differs from the stored one, so the first failure is kept.
        """
        # NOTE(review): _process_outcome is a private pytest-html helper; the
        # comparison below assumes it labels a passing phase "Passed".
        outcome = _process_outcome(report)
        data = {
            "outcome": report.outcome,
            "longrepr": report.longrepr,
            "name": report.nodeid,
        }
        previous = self.test_data.get(report.nodeid)
        if previous is None:
            self.test_data[report.nodeid] = data
        elif outcome != "Passed" and previous["outcome"] != report.outcome:
            self.test_data[report.nodeid] = data

    @pytest.hookimpl
    def pytest_sessionfinish(self, session, exitstatus):
        """Write the Markdown report at the end of the test session.

        Nothing is written when the exit status is ``ExitCode.INTERRUPTED``
        or higher, i.e. when the session did not end with the ordinary
        "all passed" / "some tests failed" statuses.
        """
        if exitstatus >= ExitCode.INTERRUPTED:
            return

        report = [
            "# Llama Stack Integration Test Results Report",
            "\n## Summary",
            "\n## Supported Models: ",
        ]
        report.extend(self._supported_models_table())

        report.append("\n### Tests:")
        for provider in SUPPORTED_MODELS:
            if provider not in self.inference_tests:
                continue
            report.append(f"\n #### {provider}")
            report.extend(self._provider_tests_table(provider))
            report.append("\n")

        output_file = Path(self.output_path)
        # The report contains non-ASCII check marks, so do not rely on the
        # platform's default encoding.
        output_file.write_text("\n".join(report), encoding="utf-8")
        print(f"\n Report generated: {output_file.absolute()}")

    def _supported_models_table(self):
        """Return the Markdown lines of the model-by-provider support table."""
        providers = list(SUPPORTED_MODELS)
        header = "| Model Descriptor |" + "".join(f"{name} |" for name in providers)
        divider = "|:---|" + ":---:|" * len(providers)
        lines = [header, divider]

        for model in all_registered_models():
            descriptor = model.core_model_id.value
            # Only models whose descriptor contains "Instruct" or "Guard" are listed.
            if "Instruct" not in descriptor and "Guard" not in descriptor:
                continue
            row = f"| {descriptor} |"
            for name in providers:
                row += " ✅ |" if descriptor in SUPPORTED_MODELS[name] else " ❌ |"
            lines.append(row)
        return lines

    def _provider_tests_table(self, provider):
        """Return the Markdown lines of the test-status table for ``provider``."""
        lines = [
            "| Area | Model | API | Functionality Test | Status |",
            "|:-----|:-----|:-----|:-----|:-----|",
        ]
        provider_tests = self.inference_tests[provider]
        for api in INFERENCE_APIS:
            # Sorted so the report is identical across runs; an API without
            # collected tests simply contributes no rows.
            for test_nodeid in sorted(provider_tests.get(api, ())):
                is_text = "text" in test_nodeid
                # A test that was collected but never ran has no recorded
                # outcome and is reported as not passed.
                recorded = self.test_data.get(test_nodeid, {})
                passed = recorded.get("outcome") == "passed"
                lines.append(
                    "|{area} | {model} | {api} | {test} | {result} ".format(
                        area="Text" if is_text else "Vision",
                        model=(
                            "Llama-3.1-8B-Instruct"
                            if is_text
                            else "Llama3.2-11B-Vision-Instruct"
                        ),
                        api=f"/{api}",
                        test=self.get_simple_function_name(test_nodeid),
                        result="✅" if passed else "❌",
                    )
                )
        return lines

    @pytest.hookimpl(trylast=True)
    def pytest_collection_modifyitems(self, session, config, items):
        """Group collected inference tests by "inference_stack" value and API.

        An item is recorded when its nodeid contains "inference", it is
        parametrized with an ``inference_stack`` parameter, and its function
        name contains one of ``INFERENCE_APIS``.
        """
        for item in items:
            if "inference" not in item.nodeid:
                continue
            # Non-parametrized items have no ``callspec`` attribute.
            callspec = getattr(item, "callspec", None)
            if callspec is None:
                continue
            inference = callspec.params.get("inference_stack")
            if inference is None:
                continue
            func_name = getattr(item, "originalname", None) or item.name
            for api in INFERENCE_APIS:
                if api in func_name:
                    api_tests = self.inference_tests[inference].setdefault(api, set())
                    api_tests.add(item.nodeid)

    def get_simple_function_name(self, nodeid):
        """Extract the bare function name from a pytest nodeid.

        The last ``::``-separated component is taken and any parametrization
        suffix (``[...]``) is dropped. A nodeid without ``::`` is returned
        as-is, minus the suffix.

        Examples:
            - 'tests/test_math.py::test_addition' -> 'test_addition'
            - 'tests/test_math.py::TestClass::test_method' -> 'test_method'
            - 'tests/test_math.py::test_addition[case-1]' -> 'test_addition'
        """
        return nodeid.split("::")[-1].split("[")[0]
|