commit 9b709005fa (parent b22cc3e1fa)
Author: Sixian Yi
Date: 2025-01-14 16:33:29 -08:00

5 changed files with 46 additions and 92 deletions

View file

@@ -9,7 +9,9 @@ import pytest
 from llama_stack.apis.agents import AgentConfig, Turn
 from llama_stack.apis.inference import SamplingParams, UserMessage
 from llama_stack.providers.datatypes import Api
-from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 from .fixtures import pick_inference_model
 from .utils import create_agent_session
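
The kvstore helper itself is unchanged; only the config class moved to the .config submodule. A minimal sketch of the pair in use (the db_path value is illustrative, and kvstore_impl is assumed to be an async factory, matching how it is awaited elsewhere in the tree):

    from llama_stack.providers.utils.kvstore import kvstore_impl
    from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

    async def make_agent_store():
        # Build a sqlite-backed KVStore from its config object.
        config = SqliteKVStoreConfig(db_path="/tmp/agents_store.db")  # illustrative path
        return await kvstore_impl(config)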

View file

@@ -6,6 +6,7 @@
 import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Optional
@@ -13,15 +14,13 @@ import pytest
 import yaml
 from dotenv import load_dotenv
-from llama_stack.distribution.datatypes import Provider
-from llama_stack.providers.datatypes import RemoteProviderConfig
 from pydantic import BaseModel, Field
 from termcolor import colored
+from llama_stack.distribution.datatypes import Provider
+from llama_stack.providers.datatypes import RemoteProviderConfig
 from .env import get_env_or_fail
-from .test_config_helper import try_load_config_file_cached
 from .report import Report
@@ -142,8 +141,8 @@ def pytest_configure(config):
         key, value = env_var.split("=", 1)
         os.environ[key] = value
-    if config.getoption("--config") is not None:
-        config.pluginmanager.register(Report(config))
+    if config.getoption("--output") is not None:
+        config.pluginmanager.register(Report(config.getoption("--output")))


 def pytest_addoption(parser):
@@ -160,6 +159,11 @@ def pytest_addoption(parser):
         action="store",
         help="Set test config file (supported format: YAML), e.g. --config=test_config.yml",
     )
+    parser.addoption(
+        "--output",
+        action="store",
+        help="Set output file for test report, e.g. --output=pytest_report.md",
+    )
     """Add custom command line options"""
     parser.addoption(
         "--env", action="append", help="Set environment variables, e.g. --env KEY=value"
@@ -269,9 +273,14 @@ def pytest_collection_modifyitems(session, config, items):
         return
     required_tests = defaultdict(set)
-    test_configs = [test_config.inference, test_config.memory, test_config.agent]
-    for test_config in test_configs:
-        for test in test_config.tests:
+    for api_test_config in [
+        test_config.inference,
+        test_config.memory,
+        test_config.agents,
+    ]:
+        if api_test_config is None:
+            continue
+        for test in api_test_config.tests:
             arr = test.split("::")
             if len(arr) != 2:
                 raise ValueError(f"Invalid format for test name {test}")

View file

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import pytest
-from llama_stack.apis.common.type_system import JobStatus
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.post_training import (
     Checkpoint,
     DataConfig,

View file

@@ -11,6 +11,7 @@ from pathlib import Path
 import pytest
 from llama_models.datatypes import CoreModelId
 from llama_models.sku_list import all_registered_models
+from pytest import ExitCode
 from pytest_html.basereport import _process_outcome
@@ -71,11 +72,22 @@ SUPPORTED_MODELS = {
 class Report:
-    def __init__(self, _config):
+    def __init__(self, output_path):
+        valid_file_format = (
+            output_path.split(".")[1] in ["md", "markdown"]
+            if len(output_path.split(".")) == 2
+            else False
+        )
+        if not valid_file_format:
+            raise ValueError(
+                f"Invalid output file {output_path}. Markdown file is required"
+            )
+        self.output_path = output_path
         self.test_data = defaultdict(dict)
         self.inference_tests = defaultdict(dict)

-    @pytest.hookimpl(tryfirst=True)
+    @pytest.hookimpl
     def pytest_runtest_logreport(self, report):
         # This hook is called in several phases, including setup, call and teardown
         # The test is considered failed / error if any of the outcomes is not "Passed"
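
One caveat worth noting: the split(".")-based check accepts only names containing exactly one dot, so a relative path such as ./reports/run.md (three dot-separated pieces) would be rejected. A suffix-based check via pathlib, shown here as a sketch rather than as part of the commit, would accept such paths:

    from pathlib import Path

    def is_markdown_path(output_path: str) -> bool:
        # Path.suffix inspects only the final extension, so dots in
        # directory names or in the file stem do not affect the result.
        return Path(output_path).suffix in (".md", ".markdown")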
@@ -91,7 +103,9 @@ class Report:
         self.test_data[report.nodeid] = data

     @pytest.hookimpl
-    def pytest_sessionfinish(self, session):
+    def pytest_sessionfinish(self, session, exitstatus):
+        if exitstatus <= ExitCode.INTERRUPTED:
+            return
         report = []
         report.append("# Llama Stack Integration Test Results Report")
         report.append("\n## Summary")
@@ -108,6 +122,11 @@ class Report:
         rows = []
         for model in all_registered_models():
+            if (
+                "Instruct" not in model.core_model_id.value
+                and "Guard" not in model.core_model_id.value
+            ):
+                continue
             row = f"| {model.core_model_id.value} |"
             for k in SUPPORTED_MODELS.keys():
                 if model.core_model_id.value in SUPPORTED_MODELS[k]:
@@ -149,7 +168,7 @@ class Report:
         report.extend(test_table)
         report.append("\n")

-        output_file = Path("pytest_report.md")
+        output_file = Path(self.output_path)
         output_file.write_text("\n".join(report))
         print(f"\n Report generated: {output_file.absolute()}")

View file

@@ -1,76 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Dict, List, Optional
-
-import pytest
-import yaml
-from pydantic import BaseModel, Field
-
-
-@dataclass
-class APITestConfig(BaseModel):
-    class Fixtures(BaseModel):
-        # provider fixtures can be either a mark or a dictionary of api -> providers
-        provider_fixtures: List[Dict[str, str]] = Field(default_factory=list)
-        inference_models: List[str] = Field(default_factory=list)
-        safety_shield: Optional[str] = Field(default_factory=None)
-        embedding_model: Optional[str] = Field(default_factory=None)
-
-    fixtures: Fixtures
-    tests: List[str] = Field(default_factory=list)
-    # test name format should be <relative_path.py>::<test_name>
-
-
-class TestConfig(BaseModel):
-    inference: APITestConfig
-    agent: Optional[APITestConfig] = Field(default=None)
-    memory: Optional[APITestConfig] = Field(default=None)
-
-
-CONFIG_CACHE = None
-
-
-def try_load_config_file_cached(config_file):
-    if config_file is None:
-        return None
-    if CONFIG_CACHE is not None:
-        return CONFIG_CACHE
-
-    config_file_path = Path(__file__).parent / config_file
-    if not config_file_path.exists():
-        raise ValueError(
-            f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
-        )
-    with open(config_file_path, "r") as config_file:
-        config = yaml.safe_load(config_file)
-    return TestConfig(**config)
-
-
-def get_provider_fixtures_from_config(
-    provider_fixtures_config, default_fixture_combination
-):
-    custom_fixtures = []
-    selected_default_param_id = set()
-    for fixture_config in provider_fixtures_config:
-        if "default_fixture_param_id" in fixture_config:
-            selected_default_param_id.add(fixture_config["default_fixture_param_id"])
-        else:
-            custom_fixtures.append(
-                pytest.param(fixture_config, id=fixture_config.get("inference") or "")
-            )
-
-    if len(selected_default_param_id) > 0:
-        for default_fixture in default_fixture_combination:
-            if default_fixture.id in selected_default_param_id:
-                custom_fixtures.append(default_fixture)
-
-    return custom_fixtures
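
For reference, a minimal instantiation of the deleted models (field values are illustrative) shows the config shape this helper consumed before its removal:

    config = TestConfig(
        inference=APITestConfig(
            fixtures=APITestConfig.Fixtures(
                provider_fixtures=[{"inference": "ollama"}],  # illustrative
                inference_models=["Llama3.1-8B-Instruct"],    # illustrative
            ),
            # test name format: <relative_path.py>::<test_name>
            tests=["inference/test_text_inference.py::test_completion"],
        ),
    )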