From 39072f8798b02ef4112b617a37713a1943196873 Mon Sep 17 00:00:00 2001 From: Sixian Yi Date: Wed, 15 Jan 2025 13:45:55 -0800 Subject: [PATCH] support non value for test config --- llama_stack/providers/tests/conftest.py | 15 ++-- .../providers/tests/test_config_helper.py | 76 +++++++++++++++++++ 2 files changed, 81 insertions(+), 10 deletions(-) create mode 100644 llama_stack/providers/tests/test_config_helper.py diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 5a04d5eee..6ce5b5ecc 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -13,11 +13,11 @@ import pytest import yaml from dotenv import load_dotenv -from pydantic import BaseModel, Field -from termcolor import colored from llama_stack.distribution.datatypes import Provider from llama_stack.providers.datatypes import RemoteProviderConfig +from pydantic import BaseModel, Field +from termcolor import colored from .env import get_env_or_fail @@ -265,14 +265,9 @@ def pytest_collection_modifyitems(session, config, items): return required_tests = defaultdict(set) - for api_test_config in [ - test_config.inference, - test_config.memory, - test_config.agents, - ]: - if api_test_config is None: - continue - for test in api_test_config.tests: + test_configs = [test_config.inference, test_config.memory, test_config.agent] + for test_config in test_configs: + for test in test_config.tests: arr = test.split("::") if len(arr) != 2: raise ValueError(f"Invalid format for test name {test}") diff --git a/llama_stack/providers/tests/test_config_helper.py b/llama_stack/providers/tests/test_config_helper.py new file mode 100644 index 000000000..c86822f0e --- /dev/null +++ b/llama_stack/providers/tests/test_config_helper.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional + +import pytest +import yaml +from pydantic import BaseModel, Field + + +@dataclass +class APITestConfig(BaseModel): + + class Fixtures(BaseModel): + # provider fixtures can be either a mark or a dictionary of api -> providers + provider_fixtures: List[Dict[str, str]] = Field(default_factory=list) + inference_models: List[str] = Field(default_factory=list) + safety_shield: Optional[str] = Field(default_factory=None) + embedding_model: Optional[str] = Field(default_factory=None) + + fixtures: Fixtures + tests: List[str] = Field(default_factory=list) + + # test name format should be :: + + +class TestConfig(BaseModel): + + inference: APITestConfig + agent: Optional[APITestConfig] = Field(default=None) + memory: Optional[APITestConfig] = Field(default=None) + + +CONFIG_CACHE = None + + +def try_load_config_file_cached(config_file): + if config_file is None: + return None + if CONFIG_CACHE is not None: + return CONFIG_CACHE + + config_file_path = Path(__file__).parent / config_file + if not config_file_path.exists(): + raise ValueError( + f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory." + ) + with open(config_file_path, "r") as config_file: + config = yaml.safe_load(config_file) + return TestConfig(**config) + + +def get_provider_fixtures_from_config( + provider_fixtures_config, default_fixture_combination +): + custom_fixtures = [] + selected_default_param_id = set() + for fixture_config in provider_fixtures_config: + if "default_fixture_param_id" in fixture_config: + selected_default_param_id.add(fixture_config["default_fixture_param_id"]) + else: + custom_fixtures.append( + pytest.param(fixture_config, id=fixture_config.get("inference") or "") + ) + + if len(selected_default_param_id) > 0: + for default_fixture in default_fixture_combination: + if default_fixture.id in selected_default_param_id: + custom_fixtures.append(default_fixture) + + return custom_fixtures