Improved report generation for providers (#844)

# What does this PR do?

Automates the supported-model check in the generated report by querying the distribution for its registered models, instead of relying on the hard-coded `SUPPORTED_MODELS` list.
Also adds support for generating reports against both remote-hosted distributions (via `LLAMA_STACK_BASE_URL`) and templates run through the library client (via `LLAMA_STACK_CONFIG`).
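
In essence, the report now asks the distro which models it serves and compares that against the featured Llama models from `llama_models.sku_list`. A rough standalone sketch of that check (the client calls and `sku_list` helpers are the ones introduced in the diff below; the preview URL is the one from the test plan):

```python
from llama_models.sku_list import (
    llama3_1_instruct_models,
    llama3_2_instruct_models,
    llama3_3_instruct_models,
    llama3_instruct_models,
    safety_models,
)
from llama_stack_client import LlamaStackClient

# Ask the running distro which models it actually serves.
client = LlamaStackClient(base_url="https://llamastack-preview.fireworks.ai")
supported = {m.identifier for m in client.models.list()}

# Featured Llama models (base variants only), keyed by their Hugging Face repo names.
featured = [
    m.huggingface_repo
    for m in [
        *llama3_instruct_models(),
        *llama3_1_instruct_models(),
        *llama3_2_instruct_models(),
        *llama3_3_instruct_models(),
        *safety_models(),
    ]
    if not m.variant
]

# One ✅/❌ row per featured model, as rendered in the report tables.
for repo in featured:
    print(f"| {repo} | {'✅' if repo in supported else '❌'} |")
```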

## Test Plan
Run against a remote-hosted distro:
`LLAMA_STACK_BASE_URL="https://llamastack-preview.fireworks.ai" pytest -s -v tests/client-sdk --report`

Run against a template:
`LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk --report`

Commit deab4f57dd (parent 8738c3e5a7) by Hardik Shah, 2025-01-22 15:27:09 -08:00, committed via GitHub.
3 changed files with 142 additions and 94 deletions

File: `remote-hosted-report.md` (new, generated into the fireworks template directory)

@@ -0,0 +1,45 @@
# Report for fireworks distribution
## Supported Models:
| Model Descriptor | fireworks |
|:---|:---|
| meta-llama/Llama-3-8B-Instruct | ❌ |
| meta-llama/Llama-3-70B-Instruct | ❌ |
| meta-llama/Llama-3.1-8B-Instruct | ❌ |
| meta-llama/Llama-3.1-70B-Instruct | ❌ |
| meta-llama/Llama-3.1-405B-Instruct-FP8 | ❌ |
| meta-llama/Llama-3.2-1B-Instruct | ❌ |
| meta-llama/Llama-3.2-3B-Instruct | ❌ |
| meta-llama/Llama-3.2-11B-Vision-Instruct | ❌ |
| meta-llama/Llama-3.2-90B-Vision-Instruct | ❌ |
| meta-llama/Llama-3.3-70B-Instruct | ❌ |
| meta-llama/Llama-Guard-3-11B-Vision | ❌ |
| meta-llama/Llama-Guard-3-1B | ❌ |
| meta-llama/Llama-Guard-3-8B | ❌ |
| meta-llama/Llama-Guard-2-8B | ❌ |
## Inference:
| Model | API | Capability | Test | Status |
|:----- |:-----|:-----|:-----|:-----|
| Text | /chat_completion | streaming | test_text_chat_completion_streaming | ❌ |
| Vision | /chat_completion | streaming | test_image_chat_completion_streaming | ❌ |
| Vision | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ❌ |
| Text | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ❌ |
| Text | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ❌ |
| Text | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ❌ |
| Text | /completion | streaming | test_text_completion_streaming | ❌ |
| Text | /completion | non_streaming | test_text_completion_non_streaming | ❌ |
| Text | /completion | structured_output | test_text_completion_structured_output | ❌ |
## Memory:
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| /insert, /query | inline | test_memory_bank_insert_inline_and_query | ❌ |
| /insert, /query | url | test_memory_bank_insert_from_url_and_query | ❌ |
## Agents:
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| create_agent_turn | rag | test_rag_agent | ❌ |
| create_agent_turn | custom_tool | test_custom_tool | ❌ |
| create_agent_turn | code_execution | test_code_execution | ❌ |

File: `report.md` for the fireworks template (updated)

@@ -3,20 +3,20 @@
## Supported Models:
| Model Descriptor | fireworks |
|:---|:---|
-| Llama-3-8B-Instruct | ❌ |
-| Llama-3-70B-Instruct | ❌ |
-| Llama3.1-8B-Instruct | ✅ |
-| Llama3.1-70B-Instruct | ✅ |
-| Llama3.1-405B-Instruct | ✅ |
-| Llama3.2-1B-Instruct | ✅ |
-| Llama3.2-3B-Instruct | ✅ |
-| Llama3.2-11B-Vision-Instruct | ✅ |
-| Llama3.2-90B-Vision-Instruct | ✅ |
-| Llama3.3-70B-Instruct | ✅ |
-| Llama-Guard-3-11B-Vision | ✅ |
-| Llama-Guard-3-1B | ❌ |
-| Llama-Guard-3-8B | ✅ |
-| Llama-Guard-2-8B | ❌ |
+| meta-llama/Llama-3-8B-Instruct | ❌ |
+| meta-llama/Llama-3-70B-Instruct | ❌ |
+| meta-llama/Llama-3.1-8B-Instruct | ✅ |
+| meta-llama/Llama-3.1-70B-Instruct | ✅ |
+| meta-llama/Llama-3.1-405B-Instruct-FP8 | ✅ |
+| meta-llama/Llama-3.2-1B-Instruct | ✅ |
+| meta-llama/Llama-3.2-3B-Instruct | ✅ |
+| meta-llama/Llama-3.2-11B-Vision-Instruct | ✅ |
+| meta-llama/Llama-3.2-90B-Vision-Instruct | ✅ |
+| meta-llama/Llama-3.3-70B-Instruct | ✅ |
+| meta-llama/Llama-Guard-3-11B-Vision | ✅ |
+| meta-llama/Llama-Guard-3-1B | ❌ |
+| meta-llama/Llama-Guard-3-8B | ✅ |
+| meta-llama/Llama-Guard-2-8B | ❌ |
## Inference:
| Model | API | Capability | Test | Status |
@@ -34,12 +34,12 @@
## Memory:
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
-| /insert, /query | inline | test_memory_bank_insert_inline_and_query | |
-| /insert, /query | url | test_memory_bank_insert_from_url_and_query | |
+| /insert, /query | inline | test_memory_bank_insert_inline_and_query | |
+| /insert, /query | url | test_memory_bank_insert_from_url_and_query | |
## Agents:
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
-| create_agent_turn | rag | test_rag_agent | |
+| create_agent_turn | rag | test_rag_agent | |
| create_agent_turn | custom_tool | test_custom_tool | ✅ |
| create_agent_turn | code_execution | test_code_execution | ❌ |

File: the report-generating pytest plugin in `tests/client-sdk` (updated)

@@ -5,88 +5,87 @@
# the root directory of this source tree.
+import importlib
import os
from collections import defaultdict
from pathlib import Path
+from urllib.parse import urlparse

import pytest
-from llama_models.datatypes import CoreModelId
-from llama_models.sku_list import all_registered_models
+from llama_models.sku_list import (
+    llama3_1_instruct_models,
+    llama3_2_instruct_models,
+    llama3_3_instruct_models,
+    llama3_instruct_models,
+    safety_models,
+)

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from llama_stack.providers.tests.env import get_env_or_fail
+from llama_stack_client import LlamaStackClient

from metadata import API_MAPS
from pytest import CollectReport
from termcolor import cprint

-SUPPORTED_MODELS = {
-    "ollama": set(
-        [
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_2_1b_instruct.value,
-            CoreModelId.llama3_2_1b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_3_70b_instruct.value,
-            CoreModelId.llama_guard_3_8b.value,
-            CoreModelId.llama_guard_3_1b.value,
-        ]
-    ),
-    "fireworks": set(
-        [
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_2_1b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_3_70b_instruct.value,
-            CoreModelId.llama_guard_3_8b.value,
-            CoreModelId.llama_guard_3_11b_vision.value,
-        ]
-    ),
-    "together": set(
-        [
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_3_70b_instruct.value,
-            CoreModelId.llama_guard_3_8b.value,
-            CoreModelId.llama_guard_3_11b_vision.value,
-        ]
-    ),
-}
+def featured_models_repo_names():
+    models = [
+        *llama3_instruct_models(),
+        *llama3_1_instruct_models(),
+        *llama3_2_instruct_models(),
+        *llama3_3_instruct_models(),
+        *safety_models(),
+    ]
+    return [model.huggingface_repo for model in models if not model.variant]
class Report:
    def __init__(self):
-        config_file = os.environ.get("LLAMA_STACK_CONFIG")
-        if not config_file:
-            raise ValueError(
-                "Currently we only support generating report for LlamaStackClientLibrary distributions"
-            )
-        config_path = Path(config_file)
+        if os.environ.get("LLAMA_STACK_CONFIG"):
+            config_path_or_template_name = get_env_or_fail("LLAMA_STACK_CONFIG")
+            if config_path_or_template_name.endswith(".yaml"):
+                config_path = Path(config_path_or_template_name)
+            else:
+                config_path = Path(
+                    importlib.resources.files("llama_stack")
+                    / f"templates/{config_path_or_template_name}/run.yaml"
+                )
            if not config_path.exists():
                raise ValueError(f"Config file {config_path} does not exist")
            self.output_path = Path(config_path.parent / "report.md")
            self.client = LlamaStackAsLibraryClient(
-            config_file,
+                config_path_or_template_name,
                provider_data=None,
                skip_logger_removal=True,
            )
            self.client.initialize()
+            self.image_name = self.client.async_client.config.image_name
+        elif os.environ.get("LLAMA_STACK_BASE_URL"):
+            url = get_env_or_fail("LLAMA_STACK_BASE_URL")
+            hostname = urlparse(url).netloc
+            domain = hostname.split(".")[-2]
+            self.image_name = domain
+            self.client = LlamaStackClient(
+                base_url=url,
+                provider_data=None,
+            )
+            # We assume that the domain maps to a template
+            # i.e. https://llamastack-preview.fireworks.ai --> "fireworks" template
+            # and add report in that directory
+            output_dir = Path(
+                importlib.resources.files("llama_stack") / f"templates/{domain}/"
+            )
+            if not output_dir.exists():
+                raise ValueError(f"Output dir {output_dir} does not exist")
+            self.output_path = Path(output_dir / "remote-hosted-report.md")
+        else:
+            raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")

        self.report_data = defaultdict(dict)
        # test function -> test nodeid
        self.test_data = dict()
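
A quick worked example of the URL-to-template mapping above, using the preview endpoint from the test plan (the resulting output location is an assumption based on the `templates/{domain}/` lookup):

```python
from urllib.parse import urlparse

url = "https://llamastack-preview.fireworks.ai"
hostname = urlparse(url).netloc    # "llamastack-preview.fireworks.ai"
domain = hostname.split(".")[-2]   # "fireworks"
# The report is then written under <llama_stack package dir>/templates/fireworks/
# as remote-hosted-report.md.
```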
@@ -105,7 +104,7 @@ class Report:
    def pytest_sessionfinish(self, session):
        report = []
        report.append(f"# Report for {self.image_name} distribution")
-        report.append("\n## Supported Models: ")
+        report.append("\n## Supported Models:")
        header = f"| Model Descriptor | {self.image_name} |"
        dividor = "|:---|:---|"
@@ -114,21 +113,23 @@ class Report:
        report.append(dividor)
        rows = []
-        for model in all_registered_models():
-            if (
-                "Instruct" not in model.core_model_id.value
-                and "Guard" not in model.core_model_id.value
-            ) or (model.variant):
-                continue
-            row = f"| {model.core_model_id.value} |"
-            if model.core_model_id.value in SUPPORTED_MODELS[self.image_name]:
+        try:
+            supported_models = {m.identifier for m in self.client.models.list()}
+        except Exception as e:
+            cprint(f"Error getting models: {e}", "red")
+            supported_models = set()
+        for m_name in featured_models_repo_names():
+            row = f"| {m_name} |"
+            if m_name in supported_models:
                row += " ✅ |"
            else:
                row += " ❌ |"
            rows.append(row)
        report.extend(rows)
-        report.append("\n## Inference: ")
+        report.append("\n## Inference:")
        test_table = [
            "| Model | API | Capability | Test | Status |",
            "|:----- |:-----|:-----|:-----|:-----|",
@@ -150,7 +151,7 @@
        for api_group in ["memory", "agents"]:
            api_capitalized = api_group.capitalize()
-            report.append(f"\n## {api_capitalized}: ")
+            report.append(f"\n## {api_capitalized}:")
            test_table = [
                "| API | Capability | Test | Status |",
                "|:-----|:-----|:-----|:-----|",
@@ -164,9 +165,11 @@
                    f"| {api} | {capa} | {test_name} | {self._print_result_icon(self.test_data[test_nodeids[0]])} |"
                )
            report.extend(test_table)

        output_file = self.output_path
-        output_file.write_text("\n".join(report))
-        print(f"\nReport generated: {output_file.absolute()}")
+        text = "\n".join(report) + "\n"
+        output_file.write_text(text)
+        cprint(f"\nReport generated: {output_file.absolute()}", "green")

    def pytest_runtest_makereport(self, item, call):
        func_name = getattr(item, "originalname", item.name)