diff --git a/llama_stack/providers/tests/agents/test_persistence.py b/llama_stack/providers/tests/agents/test_persistence.py index 38eb7de55..e6b1470ef 100644 --- a/llama_stack/providers/tests/agents/test_persistence.py +++ b/llama_stack/providers/tests/agents/test_persistence.py @@ -9,7 +9,9 @@ import pytest from llama_stack.apis.agents import AgentConfig, Turn from llama_stack.apis.inference import SamplingParams, UserMessage from llama_stack.providers.datatypes import Api -from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig + from .fixtures import pick_inference_model from .utils import create_agent_session diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 011048af4..4aa53a687 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -6,6 +6,7 @@ import os from collections import defaultdict + from pathlib import Path from typing import Any, Dict, List, Optional @@ -13,15 +14,13 @@ import pytest import yaml from dotenv import load_dotenv - -from llama_stack.distribution.datatypes import Provider -from llama_stack.providers.datatypes import RemoteProviderConfig from pydantic import BaseModel, Field from termcolor import colored -from .env import get_env_or_fail +from llama_stack.distribution.datatypes import Provider +from llama_stack.providers.datatypes import RemoteProviderConfig -from .test_config_helper import try_load_config_file_cached +from .env import get_env_or_fail from .report import Report @@ -142,8 +141,8 @@ def pytest_configure(config): key, value = env_var.split("=", 1) os.environ[key] = value - if config.getoption("--config") is not None: - config.pluginmanager.register(Report(config)) + if config.getoption("--output") is not None: + config.pluginmanager.register(Report(config.getoption("--output"))) def pytest_addoption(parser): @@ -160,6 +159,11 @@ def pytest_addoption(parser): action="store", help="Set test config file (supported format: YAML), e.g. --config=test_config.yml", ) + parser.addoption( + "--output", + action="store", + help="Set output file for test report, e.g. --output=pytest_report.md", + ) """Add custom command line options""" parser.addoption( "--env", action="append", help="Set environment variables, e.g. --env KEY=value" @@ -269,9 +273,14 @@ def pytest_collection_modifyitems(session, config, items): return required_tests = defaultdict(set) - test_configs = [test_config.inference, test_config.memory, test_config.agent] - for test_config in test_configs: - for test in test_config.tests: + for api_test_config in [ + test_config.inference, + test_config.memory, + test_config.agents, + ]: + if api_test_config is None: + continue + for test in api_test_config.tests: arr = test.split("::") if len(arr) != 2: raise ValueError(f"Invalid format for test name {test}") diff --git a/llama_stack/providers/tests/post_training/test_post_training.py b/llama_stack/providers/tests/post_training/test_post_training.py index 0645cd555..0c58c1fa0 100644 --- a/llama_stack/providers/tests/post_training/test_post_training.py +++ b/llama_stack/providers/tests/post_training/test_post_training.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import pytest -from llama_stack.apis.common.type_system import JobStatus +from llama_stack.apis.common.job_types import JobStatus from llama_stack.apis.post_training import ( Checkpoint, DataConfig, diff --git a/llama_stack/providers/tests/report.py b/llama_stack/providers/tests/report.py index b60fc099d..c07d7278a 100644 --- a/llama_stack/providers/tests/report.py +++ b/llama_stack/providers/tests/report.py @@ -11,6 +11,7 @@ from pathlib import Path import pytest from llama_models.datatypes import CoreModelId from llama_models.sku_list import all_registered_models +from pytest import ExitCode from pytest_html.basereport import _process_outcome @@ -71,11 +72,22 @@ SUPPORTED_MODELS = { class Report: - def __init__(self, _config): + def __init__(self, output_path): + + valid_file_format = ( + output_path.split(".")[1] in ["md", "markdown"] + if len(output_path.split(".")) == 2 + else False + ) + if not valid_file_format: + raise ValueError( + f"Invalid output file {output_path}. Markdown file is required" + ) + self.output_path = output_path self.test_data = defaultdict(dict) self.inference_tests = defaultdict(dict) - @pytest.hookimpl(tryfirst=True) + @pytest.hookimpl def pytest_runtest_logreport(self, report): # This hook is called in several phases, including setup, call and teardown # The test is considered failed / error if any of the outcomes is not "Passed" @@ -91,7 +103,9 @@ class Report: self.test_data[report.nodeid] = data @pytest.hookimpl - def pytest_sessionfinish(self, session): + def pytest_sessionfinish(self, session, exitstatus): + if exitstatus <= ExitCode.INTERRUPTED: + return report = [] report.append("# Llama Stack Integration Test Results Report") report.append("\n## Summary") @@ -108,6 +122,11 @@ class Report: rows = [] for model in all_registered_models(): + if ( + "Instruct" not in model.core_model_id.value + and "Guard" not in model.core_model_id.value + ): + continue row = f"| {model.core_model_id.value} |" for k in SUPPORTED_MODELS.keys(): if model.core_model_id.value in SUPPORTED_MODELS[k]: @@ -149,7 +168,7 @@ class Report: report.extend(test_table) report.append("\n") - output_file = Path("pytest_report.md") + output_file = Path(self.output_path) output_file.write_text("\n".join(report)) print(f"\n Report generated: {output_file.absolute()}") diff --git a/llama_stack/providers/tests/test_config_helper.py b/llama_stack/providers/tests/test_config_helper.py deleted file mode 100644 index c86822f0e..000000000 --- a/llama_stack/providers/tests/test_config_helper.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from dataclasses import dataclass -from pathlib import Path -from typing import Dict, List, Optional - -import pytest -import yaml -from pydantic import BaseModel, Field - - -@dataclass -class APITestConfig(BaseModel): - - class Fixtures(BaseModel): - # provider fixtures can be either a mark or a dictionary of api -> providers - provider_fixtures: List[Dict[str, str]] = Field(default_factory=list) - inference_models: List[str] = Field(default_factory=list) - safety_shield: Optional[str] = Field(default_factory=None) - embedding_model: Optional[str] = Field(default_factory=None) - - fixtures: Fixtures - tests: List[str] = Field(default_factory=list) - - # test name format should be :: - - -class TestConfig(BaseModel): - - inference: APITestConfig - agent: Optional[APITestConfig] = Field(default=None) - memory: Optional[APITestConfig] = Field(default=None) - - -CONFIG_CACHE = None - - -def try_load_config_file_cached(config_file): - if config_file is None: - return None - if CONFIG_CACHE is not None: - return CONFIG_CACHE - - config_file_path = Path(__file__).parent / config_file - if not config_file_path.exists(): - raise ValueError( - f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory." - ) - with open(config_file_path, "r") as config_file: - config = yaml.safe_load(config_file) - return TestConfig(**config) - - -def get_provider_fixtures_from_config( - provider_fixtures_config, default_fixture_combination -): - custom_fixtures = [] - selected_default_param_id = set() - for fixture_config in provider_fixtures_config: - if "default_fixture_param_id" in fixture_config: - selected_default_param_id.add(fixture_config["default_fixture_param_id"]) - else: - custom_fixtures.append( - pytest.param(fixture_config, id=fixture_config.get("inference") or "") - ) - - if len(selected_default_param_id) > 0: - for default_fixture in default_fixture_combination: - if default_fixture.id in selected_default_param_id: - custom_fixtures.append(default_fixture) - - return custom_fixtures