Improved report generation for providers (#844)

# What does this PR do?

Automates the supported-models check by querying the distro for the models it actually serves, instead of maintaining hard-coded per-provider lists. Adds support for both remote-hosted distros and templates.
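
In essence, the report now asks the running distribution which models it serves and marks each featured Llama model as available or not. A minimal sketch of that check, assuming a reachable distro and using the same `llama_models.sku_list` helpers and `LlamaStackClient` calls as the `report.py` diff below (the real logic lives in `Report.pytest_sessionfinish`):

```python
# Illustrative sketch only -- the real check is in the report.py diff below.
from llama_models.sku_list import llama3_1_instruct_models, safety_models
from llama_stack_client import LlamaStackClient


def model_support_rows(base_url: str) -> list[str]:
    # Featured models to report on; the PR uses every Llama 3.x instruct family
    # plus the safety models, this sketch takes just two groups for brevity.
    featured = [
        m.huggingface_repo
        for m in [*llama3_1_instruct_models(), *safety_models()]
        if not m.variant
    ]
    # Ask the running distribution which models it actually serves.
    client = LlamaStackClient(base_url=base_url)
    supported = {m.identifier for m in client.models.list()}
    return [f"| {name} | {'✅' if name in supported else '❌'} |" for name in featured]
```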

## Test Plan
Run on a remote-hosted distro via
`LLAMA_STACK_BASE_URL="https://llamastack-preview.fireworks.ai" pytest -s -v tests/client-sdk --report`

Run on a template via
`LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk --report`
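
The `--report` flag comes from the client-sdk suite's pytest wiring; the snippet below is a hedged sketch of how such a flag is typically hooked up to the `Report` plugin, not the repo's actual `conftest.py`, whose contents are outside this diff:

```python
# conftest.py -- illustrative wiring only; the real conftest in tests/client-sdk
# may differ. It registers the Report plugin when pytest is run with --report.
from report import Report


def pytest_addoption(parser):
    parser.addoption(
        "--report",
        action="store_true",
        default=False,
        help="Generate a provider capability report at the end of the session.",
    )


def pytest_configure(config):
    if config.getoption("--report"):
        # Register Report so its pytest hooks (pytest_runtest_makereport,
        # pytest_sessionfinish, ...) run and write the markdown report.
        config.pluginmanager.register(Report())
```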
Hardik Shah, 2025-01-22 15:27:09 -08:00 (committed via GitHub)
commit deab4f57dd · parent 8738c3e5a7
3 changed files with 142 additions and 94 deletions


@@ -0,0 +1,45 @@
# Report for fireworks distribution
## Supported Models:
| Model Descriptor | fireworks |
|:---|:---|
| meta-llama/Llama-3-8B-Instruct | ❌ |
| meta-llama/Llama-3-70B-Instruct | ❌ |
| meta-llama/Llama-3.1-8B-Instruct | ❌ |
| meta-llama/Llama-3.1-70B-Instruct | ❌ |
| meta-llama/Llama-3.1-405B-Instruct-FP8 | ❌ |
| meta-llama/Llama-3.2-1B-Instruct | ❌ |
| meta-llama/Llama-3.2-3B-Instruct | ❌ |
| meta-llama/Llama-3.2-11B-Vision-Instruct | ❌ |
| meta-llama/Llama-3.2-90B-Vision-Instruct | ❌ |
| meta-llama/Llama-3.3-70B-Instruct | ❌ |
| meta-llama/Llama-Guard-3-11B-Vision | ❌ |
| meta-llama/Llama-Guard-3-1B | ❌ |
| meta-llama/Llama-Guard-3-8B | ❌ |
| meta-llama/Llama-Guard-2-8B | ❌ |
## Inference:
| Model | API | Capability | Test | Status |
|:----- |:-----|:-----|:-----|:-----|
| Text | /chat_completion | streaming | test_text_chat_completion_streaming | ❌ |
| Vision | /chat_completion | streaming | test_image_chat_completion_streaming | ❌ |
| Vision | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ❌ |
| Text | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ❌ |
| Text | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ❌ |
| Text | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ❌ |
| Text | /completion | streaming | test_text_completion_streaming | ❌ |
| Text | /completion | non_streaming | test_text_completion_non_streaming | ❌ |
| Text | /completion | structured_output | test_text_completion_structured_output | ❌ |
## Memory:
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| /insert, /query | inline | test_memory_bank_insert_inline_and_query | ❌ |
| /insert, /query | url | test_memory_bank_insert_from_url_and_query | ❌ |
## Agents:
| API | Capability | Test | Status |
|:-----|:-----|:-----|:-----|
| create_agent_turn | rag | test_rag_agent | ❌ |
| create_agent_turn | custom_tool | test_custom_tool | ❌ |
| create_agent_turn | code_execution | test_code_execution | ❌ |


@@ -3,20 +3,20 @@
 ## Supported Models:
 | Model Descriptor | fireworks |
 |:---|:---|
-| Llama-3-8B-Instruct | ❌ |
-| Llama-3-70B-Instruct | ❌ |
-| Llama3.1-8B-Instruct | ✅ |
-| Llama3.1-70B-Instruct | ✅ |
-| Llama3.1-405B-Instruct | ✅ |
-| Llama3.2-1B-Instruct | ✅ |
-| Llama3.2-3B-Instruct | ✅ |
-| Llama3.2-11B-Vision-Instruct | ✅ |
-| Llama3.2-90B-Vision-Instruct | ✅ |
-| Llama3.3-70B-Instruct | ✅ |
-| Llama-Guard-3-11B-Vision | ✅ |
-| Llama-Guard-3-1B | ❌ |
-| Llama-Guard-3-8B | ✅ |
-| Llama-Guard-2-8B | ❌ |
+| meta-llama/Llama-3-8B-Instruct | ❌ |
+| meta-llama/Llama-3-70B-Instruct | ❌ |
+| meta-llama/Llama-3.1-8B-Instruct | ✅ |
+| meta-llama/Llama-3.1-70B-Instruct | ✅ |
+| meta-llama/Llama-3.1-405B-Instruct-FP8 | ✅ |
+| meta-llama/Llama-3.2-1B-Instruct | ✅ |
+| meta-llama/Llama-3.2-3B-Instruct | ✅ |
+| meta-llama/Llama-3.2-11B-Vision-Instruct | ✅ |
+| meta-llama/Llama-3.2-90B-Vision-Instruct | ✅ |
+| meta-llama/Llama-3.3-70B-Instruct | ✅ |
+| meta-llama/Llama-Guard-3-11B-Vision | ✅ |
+| meta-llama/Llama-Guard-3-1B | ❌ |
+| meta-llama/Llama-Guard-3-8B | ✅ |
+| meta-llama/Llama-Guard-2-8B | ❌ |
 ## Inference:
 | Model | API | Capability | Test | Status |
@@ -34,12 +34,12 @@
 ## Memory:
 | API | Capability | Test | Status |
 |:-----|:-----|:-----|:-----|
 | /insert, /query | inline | test_memory_bank_insert_inline_and_query | |
 | /insert, /query | url | test_memory_bank_insert_from_url_and_query | |
 ## Agents:
 | API | Capability | Test | Status |
 |:-----|:-----|:-----|:-----|
 | create_agent_turn | rag | test_rag_agent | |
 | create_agent_turn | custom_tool | test_custom_tool | ✅ |
 | create_agent_turn | code_execution | test_code_execution | ❌ |


@@ -5,88 +5,87 @@
 # the root directory of this source tree.
+import importlib
 import os
 from collections import defaultdict
 from pathlib import Path
+from urllib.parse import urlparse
 
 import pytest
-from llama_models.datatypes import CoreModelId
-from llama_models.sku_list import all_registered_models
+from llama_models.sku_list import (
+    llama3_1_instruct_models,
+    llama3_2_instruct_models,
+    llama3_3_instruct_models,
+    llama3_instruct_models,
+    safety_models,
+)
 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from llama_stack.providers.tests.env import get_env_or_fail
+from llama_stack_client import LlamaStackClient
 from metadata import API_MAPS
 from pytest import CollectReport
+from termcolor import cprint
 
-SUPPORTED_MODELS = {
-    "ollama": set(
-        [
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_2_1b_instruct.value,
-            CoreModelId.llama3_2_1b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_3_70b_instruct.value,
-            CoreModelId.llama_guard_3_8b.value,
-            CoreModelId.llama_guard_3_1b.value,
-        ]
-    ),
-    "fireworks": set(
-        [
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_2_1b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_3_70b_instruct.value,
-            CoreModelId.llama_guard_3_8b.value,
-            CoreModelId.llama_guard_3_11b_vision.value,
-        ]
-    ),
-    "together": set(
-        [
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_3_70b_instruct.value,
-            CoreModelId.llama_guard_3_8b.value,
-            CoreModelId.llama_guard_3_11b_vision.value,
-        ]
-    ),
-}
+
+def featured_models_repo_names():
+    models = [
+        *llama3_instruct_models(),
+        *llama3_1_instruct_models(),
+        *llama3_2_instruct_models(),
+        *llama3_3_instruct_models(),
+        *safety_models(),
+    ]
+    return [model.huggingface_repo for model in models if not model.variant]
 
 
 class Report:
     def __init__(self):
-        config_file = os.environ.get("LLAMA_STACK_CONFIG")
-        if not config_file:
-            raise ValueError(
-                "Currently we only support generating report for LlamaStackClientLibrary distributions"
-            )
-        config_path = Path(config_file)
-        self.output_path = Path(config_path.parent / "report.md")
-        self.client = LlamaStackAsLibraryClient(
-            config_file,
-            provider_data=None,
-            skip_logger_removal=True,
-        )
-        self.image_name = self.client.async_client.config.image_name
+        if os.environ.get("LLAMA_STACK_CONFIG"):
+            config_path_or_template_name = get_env_or_fail("LLAMA_STACK_CONFIG")
+            if config_path_or_template_name.endswith(".yaml"):
+                config_path = Path(config_path_or_template_name)
+            else:
+                config_path = Path(
+                    importlib.resources.files("llama_stack")
+                    / f"templates/{config_path_or_template_name}/run.yaml"
+                )
+            if not config_path.exists():
+                raise ValueError(f"Config file {config_path} does not exist")
+            self.output_path = Path(config_path.parent / "report.md")
+            self.client = LlamaStackAsLibraryClient(
+                config_path_or_template_name,
+                provider_data=None,
+                skip_logger_removal=True,
+            )
+            self.client.initialize()
+            self.image_name = self.client.async_client.config.image_name
+        elif os.environ.get("LLAMA_STACK_BASE_URL"):
+            url = get_env_or_fail("LLAMA_STACK_BASE_URL")
+            hostname = urlparse(url).netloc
+            domain = hostname.split(".")[-2]
+            self.image_name = domain
+            self.client = LlamaStackClient(
+                base_url=url,
+                provider_data=None,
+            )
+            # We assume that the domain maps to a template
+            # i.e. https://llamastack-preview.fireworks.ai --> "fireworks" template
+            # and add report in that directory
+            output_dir = Path(
+                importlib.resources.files("llama_stack") / f"templates/{domain}/"
+            )
+            if not output_dir.exists():
+                raise ValueError(f"Output dir {output_dir} does not exist")
+            self.output_path = Path(output_dir / "remote-hosted-report.md")
+        else:
+            raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
+
         self.report_data = defaultdict(dict)
         # test function -> test nodeid
         self.test_data = dict()
@@ -105,7 +104,7 @@ class Report:
     def pytest_sessionfinish(self, session):
         report = []
         report.append(f"# Report for {self.image_name} distribution")
-        report.append("\n## Supported Models: ")
+        report.append("\n## Supported Models:")
         header = f"| Model Descriptor | {self.image_name} |"
         dividor = "|:---|:---|"
@@ -114,21 +113,23 @@ class Report:
         report.append(dividor)
 
         rows = []
-        for model in all_registered_models():
-            if (
-                "Instruct" not in model.core_model_id.value
-                and "Guard" not in model.core_model_id.value
-            ) or (model.variant):
-                continue
-            row = f"| {model.core_model_id.value} |"
-            if model.core_model_id.value in SUPPORTED_MODELS[self.image_name]:
+        try:
+            supported_models = {m.identifier for m in self.client.models.list()}
+        except Exception as e:
+            cprint(f"Error getting models: {e}", "red")
+            supported_models = set()
+
+        for m_name in featured_models_repo_names():
+            row = f"| {m_name} |"
+            if m_name in supported_models:
                 row += " ✅ |"
             else:
                 row += " ❌ |"
             rows.append(row)
         report.extend(rows)
 
-        report.append("\n## Inference: ")
+        report.append("\n## Inference:")
         test_table = [
             "| Model | API | Capability | Test | Status |",
             "|:----- |:-----|:-----|:-----|:-----|",
@@ -150,7 +151,7 @@ class Report:
         for api_group in ["memory", "agents"]:
             api_capitalized = api_group.capitalize()
-            report.append(f"\n## {api_capitalized}: ")
+            report.append(f"\n## {api_capitalized}:")
             test_table = [
                 "| API | Capability | Test | Status |",
                 "|:-----|:-----|:-----|:-----|",
@@ -164,9 +165,11 @@ class Report:
                         f"| {api} | {capa} | {test_name} | {self._print_result_icon(self.test_data[test_nodeids[0]])} |"
                     )
             report.extend(test_table)
+
         output_file = self.output_path
-        output_file.write_text("\n".join(report))
-        print(f"\nReport generated: {output_file.absolute()}")
+        text = "\n".join(report) + "\n"
+        output_file.write_text(text)
+        cprint(f"\nReport generated: {output_file.absolute()}", "green")
 
     def pytest_runtest_makereport(self, item, call):
         func_name = getattr(item, "originalname", item.name)
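
The diff cuts off inside `pytest_runtest_makereport`. As a rough illustration of the hook pattern the report relies on (a standalone sketch, not the body of this file's method), a plugin records per-test outcomes like this, and `pytest_sessionfinish` later renders them as the ✅/❌ columns:

```python
# Standalone illustration of the outcome-recording hook; names here are
# hypothetical and only mirror the general shape used by the Report plugin.
from collections import defaultdict


class OutcomeRecorder:
    def __init__(self):
        # nodeid -> True/False, later rendered as ✅ / ❌ in the report tables
        self.test_data = dict()
        # original test function name -> nodeids (covers parametrized tests)
        self.test_name_to_nodeids = defaultdict(list)

    def pytest_runtest_makereport(self, item, call):
        func_name = getattr(item, "originalname", item.name)
        if call.when == "call":
            self.test_data[item.nodeid] = call.excinfo is None
            self.test_name_to_nodeids[func_name].append(item.nodeid)
```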