llama-stack-mirror/llama_stack/providers/tests/report.py
2025-01-16 14:23:06 -08:00

200 lines
7.5 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections import defaultdict
from pathlib import Path
import pytest
from llama_models.datatypes import CoreModelId
from llama_models.sku_list import all_registered_models
from pytest_html.basereport import _process_outcome
# API names searched for (as substrings) in inference test node ids when
# grouping collected tests.
# NOTE(review): "INFERNECE" looks like a misspelling of "INFERENCE"; the name
# is left unchanged because other code in this file refers to it.
INFERNECE_APIS = ["chat_completion"]
# Functionality keywords searched for (as substrings) in inference test node ids.
FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
# Maps each provider name to the set of core model ids that the report marks
# as supported in its "Supported Models" table. Every model id is listed once
# per provider; set literals are used instead of ``set([...])``.
SUPPORTED_MODELS = {
    "ollama": {
        CoreModelId.llama3_1_8b_instruct.value,
        CoreModelId.llama3_1_70b_instruct.value,
        CoreModelId.llama3_1_405b_instruct.value,
        CoreModelId.llama3_2_1b_instruct.value,
        CoreModelId.llama3_2_3b_instruct.value,
        CoreModelId.llama3_2_11b_vision_instruct.value,
        CoreModelId.llama3_2_90b_vision_instruct.value,
        CoreModelId.llama3_3_70b_instruct.value,
        CoreModelId.llama_guard_3_8b.value,
        CoreModelId.llama_guard_3_1b.value,
    },
    "fireworks": {
        CoreModelId.llama3_1_8b_instruct.value,
        CoreModelId.llama3_1_70b_instruct.value,
        CoreModelId.llama3_1_405b_instruct.value,
        CoreModelId.llama3_2_1b_instruct.value,
        CoreModelId.llama3_2_3b_instruct.value,
        CoreModelId.llama3_2_11b_vision_instruct.value,
        CoreModelId.llama3_2_90b_vision_instruct.value,
        CoreModelId.llama3_3_70b_instruct.value,
        CoreModelId.llama_guard_3_8b.value,
        CoreModelId.llama_guard_3_11b_vision.value,
    },
    "together": {
        CoreModelId.llama3_1_8b_instruct.value,
        CoreModelId.llama3_1_70b_instruct.value,
        CoreModelId.llama3_1_405b_instruct.value,
        CoreModelId.llama3_2_3b_instruct.value,
        CoreModelId.llama3_2_11b_vision_instruct.value,
        CoreModelId.llama3_2_90b_vision_instruct.value,
        CoreModelId.llama3_3_70b_instruct.value,
        CoreModelId.llama_guard_3_8b.value,
        CoreModelId.llama_guard_3_11b_vision.value,
    },
}
class Report:
    """Pytest plugin that writes a Markdown summary of the test session.

    The plugin takes part in three pytest phases:

    * collection: inference test node ids are grouped by provider and API;
    * test reporting: one result record is kept per test node id;
    * session finish: the model support matrix and the collected inference
      test results are rendered and written to ``pytest_report.md``.
    """

    def __init__(self, _config):
        """Create the empty result stores; ``_config`` is accepted but unused."""
        # nodeid -> {"outcome": ..., "longrepr": ..., "name": ...}
        self.test_data = defaultdict(dict)
        # provider (value of the "inference_stack" test parameter)
        #   -> API name -> set of test nodeids
        self.inference_tests = defaultdict(dict)

    @pytest.hookimpl(tryfirst=True)
    def pytest_runtest_logreport(self, report):
        """Record the result of one phase of a test.

        This hook is called in several phases, including setup, call and
        teardown. The test is considered failed / error if any of the
        outcomes is not "Passed", so the first report for a test is always
        stored and later reports replace it only when they did not pass.

        The stored ``"outcome"`` is the raw ``report.outcome`` value, not the
        label computed by ``_process_outcome``.
        """
        outcome = _process_outcome(report)
        # A passing phase must never hide an earlier stored result.
        # (The previous code also compared the stored record, a dict, with
        # the outcome string; that comparison was always unequal and is gone.)
        if report.nodeid in self.test_data and outcome == "Passed":
            return
        self.test_data[report.nodeid] = {
            "outcome": report.outcome,
            "longrepr": report.longrepr,
            "name": report.nodeid,
        }

    @pytest.hookimpl
    def pytest_sessionfinish(self, session):
        """Render the Markdown report and write it to ``pytest_report.md``."""
        providers = list(SUPPORTED_MODELS)

        report = [
            "# Llama Stack Integration Test Results Report",
            "\n## Summary",
            "\n## Supported Models: ",
        ]

        # Support matrix: one row per registered model, one column per provider.
        report.append("| Model Descriptor |" + "".join(f"{p} |" for p in providers))
        report.append("|:---|" + ":---:|" * len(providers))
        for model in all_registered_models():
            model_id = model.core_model_id.value
            cells = "".join(
                " ✅ |" if model_id in SUPPORTED_MODELS[p] else " ❌ |"
                for p in providers
            )
            report.append(f"| {model_id} |{cells}")

        report.append("\n### Tests:")
        # Iterate provider *names*. Iterating ``.items()`` yields
        # (name, models) tuples, which never match the keys of
        # ``self.inference_tests``, so every provider used to be skipped.
        for provider in providers:
            if provider not in self.inference_tests:
                continue
            report.extend(self._generate_provider_test_table(provider))

        output_file = Path("pytest_report.md")
        # The report contains emoji; do not depend on the locale encoding.
        output_file.write_text("\n".join(report), encoding="utf-8")
        print(f"\n Report generated: {output_file.absolute()}")

    @pytest.hookimpl(trylast=True)
    def pytest_collection_modifyitems(self, session, config, items):
        """Group collected inference tests by provider and API.

        A test counts as an inference test when its node id contains
        ``"inference"``. The provider is the value of the test's
        ``inference_stack`` parameter, or ``None`` when the test has no such
        parameter or is not parametrized at all.
        """
        for item in items:
            if "inference" not in item.nodeid:
                continue
            # Non-parametrized items have no ``callspec`` attribute.
            callspec = getattr(item, "callspec", None)
            inference = (
                callspec.params.get("inference_stack")
                if callspec is not None
                else None
            )
            api, _functionality = self._process_function_name(item.nodeid)
            self.inference_tests[inference].setdefault(api, set()).add(item.nodeid)

    def _process_function_name(self, function_name):
        """Return the ``(api, functionality)`` keywords found in a name.

        Each element is the last entry of ``INFERNECE_APIS`` /
        ``FUNCTIONALITIES`` that occurs as a substring of ``function_name``,
        or ``None`` when no entry occurs.
        """
        api, functionality = None, None
        for candidate in INFERNECE_APIS:
            if candidate in function_name:
                api = candidate
        for candidate in FUNCTIONALITIES:
            if candidate in function_name:
                functionality = candidate
        return api, functionality

    def _print_result_icon(self, result):
        """Return "✅" for a passed result and "❌" for anything else.

        The comparison ignores case so that both the raw pytest outcome
        (``"passed"``, which is what ``test_data`` stores) and the capitalised
        label (``"Passed"``) are recognised.
        """
        if isinstance(result, str) and result.lower() == "passed":
            return "✅"
        # Failed, error, skipped, unknown or missing results.
        return "❌"

    def get_function_name(self, nodeid):
        """Extract function name from nodeid.

        Examples:
        - 'tests/test_math.py::test_addition' -> 'test_addition'
        - 'tests/test_math.py::TestClass::test_method' -> 'TestClass.test_method'
        """
        parts = nodeid.split("::")
        if len(parts) == 2:  # Simple function
            return parts[1]
        if len(parts) == 3:  # Class method
            return f"{parts[1]}.{parts[2]}"
        return nodeid  # Fallback to full nodeid if pattern doesn't match

    def _generate_test_result_short(self, test_nodeids):
        """Return one ``" - <name>. Result: <outcome>"`` line per node id.

        Node ids without a stored record are reported with the result
        ``"not run"`` instead of raising ``KeyError``.
        """
        report = []
        for nodeid in test_nodeids:
            # ``.get`` avoids inserting an empty record into the defaultdict.
            record = self.test_data.get(nodeid, {})
            name = self.get_function_name(record.get("name", nodeid))
            result = record.get("outcome", "not run")
            report.append(f"   - {name}. Result: {result}")
        return report

    def _generate_provider_test_table(self, provider):
        """Return the Markdown lines of the result table for one provider.

        Rows are emitted per API in ``INFERNECE_APIS`` order and sorted by
        node id within an API so the output is deterministic. Tests without a
        stored record are shown with the "❌" icon.
        """
        lines = [
            f"\n#### {provider}",
            "| API | Test name | Test Result |",
            "|:-----|:-----|:-----|",
        ]
        provider_tests = self.inference_tests[provider]
        for api in INFERNECE_APIS:
            for nodeid in sorted(provider_tests.get(api, ())):
                outcome = self.test_data.get(nodeid, {}).get("outcome")
                lines.append(
                    f"| /{api} | {self.get_function_name(nodeid)} "
                    f"| {self._print_result_icon(outcome)} |"
                )
        return lines