add pytest option to generate a functional report for distribution (#833)

# What does this PR do?

Add a pytest option (`--report`) to support generating a functional report
for a Llama Stack distribution.
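
In essence, the option is wired up through standard pytest hooks; a condensed sketch of the conftest.py changes in this PR (help text paraphrased):

```python
# Condensed sketch: the Report plugin is only registered when --report is
# passed on the command line.
from report import Report


def pytest_addoption(parser):
    parser.addoption(
        "--report",
        default=False,
        action="store_true",
        help="Generate a functional report for the distribution",
    )


def pytest_configure(config):
    if config.getoption("--report"):
        config.pluginmanager.register(Report())
```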

## Test Plan
```
export LLAMA_STACK_CONFIG=./llama_stack/templates/fireworks/run.yaml
/opt/miniconda3/envs/stack/bin/pytest -s -v tests/client-sdk/  --report
```

Verify that a report file is generated under
`./llama_stack/templates/fireworks/report.md`.


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
Commit edf56884a7 (parent e41873f268), authored by Sixian Yi, 2025-01-21 21:18:23 -08:00, committed via GitHub.
6 changed files with 324 additions and 6 deletions


@@ -0,0 +1,45 @@
# Report for fireworks distribution
## Supported Models:
| Model Descriptor | fireworks |
|:---|:---|
| Llama-3-8B-Instruct | ❌ |
| Llama-3-70B-Instruct | ❌ |
| Llama3.1-8B-Instruct | ✅ |
| Llama3.1-70B-Instruct | ✅ |
| Llama3.1-405B-Instruct | ✅ |
| Llama3.2-1B-Instruct | ✅ |
| Llama3.2-3B-Instruct | ✅ |
| Llama3.2-11B-Vision-Instruct | ✅ |
| Llama3.2-90B-Vision-Instruct | ✅ |
| Llama3.3-70B-Instruct | ✅ |
| Llama-Guard-3-11B-Vision | ✅ |
| Llama-Guard-3-1B | ❌ |
| Llama-Guard-3-8B | ✅ |
| Llama-Guard-2-8B | ❌ |
## Inference:
| Model | API | Capability | Test | Status |
|:----- |:-----|:-----|:-----|:-----|
| Text | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ |
| Vision | /chat_completion | streaming | test_image_chat_completion_streaming | ✅ |
| Text | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ |
| Vision | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ✅ |
| Text | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ✅ |
| Text | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ |
| Text | /completion | streaming | test_text_completion_streaming | ✅ |
| Text | /completion | non_streaming | test_text_completion_non_streaming | ✅ |
| Text | /completion | structured_output | test_text_completion_structured_output | ✅ |
## Memory:
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| /insert, /query | inline | test_memory_bank_insert_inline_and_query | ❌ |
| /insert, /query | url | test_memory_bank_insert_from_url_and_query | ❌ |
## Agents:
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| create_agent_turn | rag | test_rag_agent | ❌ |
| create_agent_turn | custom_tool | test_custom_tool | ✅ |
| create_agent_turn | code_execution | test_code_execution | ❌ |


@@ -80,7 +80,7 @@ class TestClientTool(ClientTool):


 @pytest.fixture(scope="session")
-def model_id(llama_stack_client):
+def text_model_id(llama_stack_client):
     available_models = [
         model.identifier
         for model in llama_stack_client.models.list()
@@ -92,14 +92,14 @@ def model_id(llama_stack_client):


 @pytest.fixture(scope="session")
-def agent_config(llama_stack_client, model_id):
+def agent_config(llama_stack_client, text_model_id):
     available_shields = [
         shield.identifier for shield in llama_stack_client.shields.list()
     ]
     available_shields = available_shields[:1]
     print(f"Using shield: {available_shields}")
     agent_config = AgentConfig(
-        model=model_id,
+        model=text_model_id,
         instructions="You are a helpful assistant",
         sampling_params={
             "strategy": {


@@ -10,11 +10,27 @@ import pytest
 from llama_stack import LlamaStackAsLibraryClient
 from llama_stack.providers.tests.env import get_env_or_fail
 from llama_stack_client import LlamaStackClient
+from report import Report


 def pytest_configure(config):
     config.option.tbstyle = "short"
     config.option.disable_warnings = True
+    if config.getoption("--report"):
+        config.pluginmanager.register(Report())
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--report",
+        default=False,
+        action="store_true",
+        help="Knob to determine if we should generate a report, e.g. --report",
+    )
+
+
+TEXT_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
+INFERENCE_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"


 @pytest.fixture(scope="session")


@@ -82,7 +82,7 @@ def base64_image_url():
     return base64_url


-def test_completion_non_streaming(llama_stack_client, text_model_id):
+def test_text_completion_non_streaming(llama_stack_client, text_model_id):
     response = llama_stack_client.inference.completion(
         content="Complete the sentence using one word: Roses are red, violets are ",
         stream=False,
@@ -94,7 +94,7 @@ def test_completion_non_streaming(llama_stack_client, text_model_id):
     assert "blue" in response.content.lower().strip()


-def test_completion_streaming(llama_stack_client, text_model_id):
+def test_text_completion_streaming(llama_stack_client, text_model_id):
     response = llama_stack_client.inference.completion(
         content="Complete the sentence using one word: Roses are red, violets are ",
         stream=True,
@@ -147,7 +147,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
         assert not chunk.logprobs, "Logprobs should be empty"


-def test_completion_structured_output(
+def test_text_completion_structured_output(
     llama_stack_client, text_model_id, inference_provider_type
 ):
     user_input = """


@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
INFERENCE_API_CAPA_TEST_MAP = {
"chat_completion": {
"streaming": [
"test_text_chat_completion_streaming",
"test_image_chat_completion_streaming",
],
"non_streaming": [
"test_image_chat_completion_non_streaming",
"test_text_chat_completion_non_streaming",
],
"tool_calling": [
"test_text_chat_completion_with_tool_calling_and_streaming",
"test_text_chat_completion_with_tool_calling_and_non_streaming",
],
},
"completion": {
"streaming": ["test_text_completion_streaming"],
"non_streaming": ["test_text_completion_non_streaming"],
"structured_output": ["test_text_completion_structured_output"],
},
}
MEMORY_API_TEST_MAP = {
"/insert, /query": {
"inline": ["test_memory_bank_insert_inline_and_query"],
"url": ["test_memory_bank_insert_from_url_and_query"],
}
}
AGENTS_API_TEST_MAP = {
"create_agent_turn": {
"rag": ["test_rag_agent"],
"custom_tool": ["test_custom_tool"],
"code_execution": ["test_code_execution"],
}
}
API_MAPS = {
"inference": INFERENCE_API_CAPA_TEST_MAP,
"memory": MEMORY_API_TEST_MAP,
"agents": AGENTS_API_TEST_MAP,
}
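
For orientation, `API_MAPS` is keyed by API group, then route, then capability, with each leaf holding the test function names that exercise it. A minimal usage sketch (the report generator below walks it the same way):

```python
from metadata import API_MAPS

# Tests covering streaming chat completion:
streaming_tests = API_MAPS["inference"]["chat_completion"]["streaming"]
# ["test_text_chat_completion_streaming", "test_image_chat_completion_streaming"]

# Walk every API group the way report.py does when building its tables.
for api_group, api_map in API_MAPS.items():
    for api, capa_map in api_map.items():
        for capa, tests in capa_map.items():
            print(f"{api_group}: {api} [{capa}] -> {tests}")
```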

tests/client-sdk/report.py (new file)

@@ -0,0 +1,207 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from collections import defaultdict
from pathlib import Path
import pytest
from llama_models.datatypes import CoreModelId
from llama_models.sku_list import all_registered_models
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from metadata import API_MAPS
from pytest import CollectReport
SUPPORTED_MODELS = {
    "ollama": set(
        [
            CoreModelId.llama3_1_8b_instruct.value,
            CoreModelId.llama3_1_70b_instruct.value,
            CoreModelId.llama3_1_405b_instruct.value,
            CoreModelId.llama3_2_1b_instruct.value,
            CoreModelId.llama3_2_3b_instruct.value,
            CoreModelId.llama3_2_11b_vision_instruct.value,
            CoreModelId.llama3_2_90b_vision_instruct.value,
            CoreModelId.llama3_3_70b_instruct.value,
            CoreModelId.llama_guard_3_8b.value,
            CoreModelId.llama_guard_3_1b.value,
        ]
    ),
"fireworks": set(
[
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
]
),
"together": set(
[
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
]
),
}
class Report:
def __init__(self):
config_file = os.environ.get("LLAMA_STACK_CONFIG")
if not config_file:
raise ValueError(
"Currently we only support generating report for LlamaStackClientLibrary distributions"
)
config_path = Path(config_file)
self.output_path = Path(config_path.parent / "report.md")
self.client = LlamaStackAsLibraryClient(
config_file,
provider_data=None,
skip_logger_removal=True,
)
self.image_name = self.client.async_client.config.image_name
self.report_data = defaultdict(dict)
# test function -> test nodeid
self.test_data = dict()
self.test_name_to_nodeid = defaultdict(list)
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_logreport(self, report):
# This hook is called in several phases, including setup, call and teardown
# The test is considered failed / error if any of the outcomes is not "Passed"
outcome = self._process_outcome(report)
if report.nodeid not in self.test_data:
self.test_data[report.nodeid] = outcome
elif self.test_data[report.nodeid] != outcome and outcome != "Passed":
self.test_data[report.nodeid] = outcome
def pytest_sessionfinish(self, session):
report = []
report.append(f"# Report for {self.image_name} distribution")
report.append("\n## Supported Models: ")
header = f"| Model Descriptor | {self.image_name} |"
        divider = "|:---|:---|"
        report.append(header)
        report.append(divider)
rows = []
for model in all_registered_models():
if (
"Instruct" not in model.core_model_id.value
and "Guard" not in model.core_model_id.value
) or (model.variant):
continue
row = f"| {model.core_model_id.value} |"
if model.core_model_id.value in SUPPORTED_MODELS[self.image_name]:
row += " ✅ |"
else:
row += " ❌ |"
rows.append(row)
report.extend(rows)
report.append("\n## Inference: ")
test_table = [
"| Model | API | Capability | Test | Status |",
"|:----- |:-----|:-----|:-----|:-----|",
]
for api, capa_map in API_MAPS["inference"].items():
for capa, tests in capa_map.items():
vision_tests = filter(lambda test_name: "image" in test_name, tests)
text_tests = filter(lambda test_name: "text" in test_name, tests)
for test_name in text_tests:
test_nodeids = self.test_name_to_nodeid[test_name]
assert len(test_nodeids) > 0
                    # There might be more than one parametrization of the same test
                    # function. We take the result of the first one for now. Ideally
                    # we should mark the test as failed if any parametrization failed.
test_table.append(
f"| Text | /{api} | {capa} | {test_name} | {self._print_result_icon(self.test_data[test_nodeids[0]])} |"
)
for test_name in vision_tests:
test_nodeids = self.test_name_to_nodeid[test_name]
assert len(test_nodeids) > 0
                    test_table.append(
                        f"| Vision | /{api} | {capa} | {test_name} | {self._print_result_icon(self.test_data[test_nodeids[0]])} |"
                    )
report.extend(test_table)
for api_group in ["memory", "agents"]:
api_capitalized = api_group.capitalize()
report.append(f"\n## {api_capitalized}: ")
test_table = [
"| API | Capability | Test | Status |",
"|:-----|:-----|:-----|:-----|",
]
for api, capa_map in API_MAPS[api_group].items():
for capa, tests in capa_map.items():
for test_name in tests:
test_nodeids = self.test_name_to_nodeid[test_name]
assert len(test_nodeids) > 0
test_table.append(
f"| {api} | {capa} | {test_name} | {self._print_result_icon(self.test_data[test_nodeids[0]])} |"
)
report.extend(test_table)
output_file = self.output_path
output_file.write_text("\n".join(report))
print(f"\nReport generated: {output_file.absolute()}")
def pytest_runtest_makereport(self, item, call):
func_name = getattr(item, "originalname", item.name)
self.test_name_to_nodeid[func_name].append(item.nodeid)
    def _print_result_icon(self, result):
        if result == "Passed":
            return "✅"
        elif result == "Failed" or result == "Error":
            return "❌"
        else:
            # result == "Skipped"
            return "⏭️"
def _process_outcome(self, report: CollectReport):
if self._is_error(report):
return "Error"
if hasattr(report, "wasxfail"):
if report.outcome in ["passed", "failed"]:
return "XPassed"
if report.outcome == "skipped":
return "XFailed"
return report.outcome.capitalize()
def _is_error(self, report: CollectReport):
return (
report.when in ["setup", "teardown", "collect"]
and report.outcome == "failed"
)
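
As a side note on the outcome merging in `pytest_runtest_logreport`: the first non-"Passed" phase outcome wins for a given node id. A self-contained illustration, using hypothetical stand-in report objects rather than pytest's real `TestReport`:

```python
from types import SimpleNamespace

# Hypothetical stand-ins for pytest report objects, only to illustrate how
# Report.pytest_runtest_logreport merges per-phase outcomes for one node id.
phases = [
    SimpleNamespace(nodeid="test_rag_agent", when="setup", outcome="passed"),
    SimpleNamespace(nodeid="test_rag_agent", when="call", outcome="failed"),
    SimpleNamespace(nodeid="test_rag_agent", when="teardown", outcome="passed"),
]

test_data = {}
for report in phases:
    outcome = report.outcome.capitalize()  # simplified stand-in for _process_outcome
    if report.nodeid not in test_data:
        test_data[report.nodeid] = outcome
    elif test_data[report.nodeid] != outcome and outcome != "Passed":
        test_data[report.nodeid] = outcome

print(test_data)  # {'test_rag_agent': 'Failed'} -> rendered as ❌ in the report table
```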