From d19526ecd75a13c76756deb8fd1962f1d3044642 Mon Sep 17 00:00:00 2001 From: Sixian Yi Date: Wed, 15 Jan 2025 23:50:42 -0800 Subject: [PATCH] address comments --- .../providers/tests/agents/conftest.py | 33 +++----- .../tests/agents/test_persistence.py | 3 +- .../providers/tests/ci_test_config.yaml | 54 ++++++------ llama_stack/providers/tests/conftest.py | 83 ++++++++++++------- .../providers/tests/inference/conftest.py | 50 +++++------ .../providers/tests/memory/conftest.py | 26 ++---- .../tests/post_training/test_post_training.py | 2 +- 7 files changed, 127 insertions(+), 124 deletions(-) diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index 06e98bd5e..4efdfe8b7 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -8,8 +8,8 @@ import pytest from ..conftest import ( get_provider_fixture_overrides, - get_provider_fixtures_from_config, - try_load_config_file_cached, + get_provider_fixture_overrides_from_test_config, + get_test_config_for_api, ) from ..inference.fixtures import INFERENCE_FIXTURES from ..memory.fixtures import MEMORY_FIXTURES @@ -87,25 +87,14 @@ def pytest_configure(config): def pytest_generate_tests(metafunc): - test_config = try_load_config_file_cached(metafunc.config) - ( - config_override_inference_models, - config_override_safety_shield, - custom_provider_fixtures, - ) = (None, None, None) - if test_config is not None and test_config.agent is not None: - config_override_inference_models = test_config.agent.fixtures.inference_models - config_override_safety_shield = test_config.agent.fixtures.safety_shield - custom_provider_fixtures = get_provider_fixtures_from_config( - test_config.agent.fixtures.provider_fixtures, DEFAULT_PROVIDER_COMBINATIONS - ) - - shield_id = config_override_safety_shield or metafunc.config.getoption( - "--safety-shield" - ) - inference_model = config_override_inference_models or [ + test_config = get_test_config_for_api(metafunc.config, "agents") + shield_id = getattr( + test_config, "safety_shield", None + ) or metafunc.config.getoption("--safety-shield") + inference_models = getattr(test_config, "inference_models", None) or [ metafunc.config.getoption("--inference-model") ] + if "safety_shield" in metafunc.fixturenames: metafunc.parametrize( "safety_shield", @@ -113,7 +102,7 @@ def pytest_generate_tests(metafunc): indirect=True, ) if "inference_model" in metafunc.fixturenames: - models = set(inference_model) + models = set(inference_models) if safety_model := safety_model_from_shield(shield_id): models.add(safety_model) @@ -131,7 +120,9 @@ def pytest_generate_tests(metafunc): "tool_runtime": TOOL_RUNTIME_FIXTURES, } combinations = ( - custom_provider_fixtures + get_provider_fixture_overrides_from_test_config( + metafunc.config, "agents", DEFAULT_PROVIDER_COMBINATIONS + ) or get_provider_fixture_overrides(metafunc.config, available_fixtures) or DEFAULT_PROVIDER_COMBINATIONS ) diff --git a/llama_stack/providers/tests/agents/test_persistence.py b/llama_stack/providers/tests/agents/test_persistence.py index bb6e280cb..e6b1470ef 100644 --- a/llama_stack/providers/tests/agents/test_persistence.py +++ b/llama_stack/providers/tests/agents/test_persistence.py @@ -9,7 +9,8 @@ import pytest from llama_stack.apis.agents import AgentConfig, Turn from llama_stack.apis.inference import SamplingParams, UserMessage from llama_stack.providers.datatypes import Api -from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from .fixtures import pick_inference_model diff --git a/llama_stack/providers/tests/ci_test_config.yaml b/llama_stack/providers/tests/ci_test_config.yaml index 22700da08..3edcd38bf 100644 --- a/llama_stack/providers/tests/ci_test_config.yaml +++ b/llama_stack/providers/tests/ci_test_config.yaml @@ -8,52 +8,48 @@ inference: - inference/test_text_inference.py::test_chat_completion_with_tool_calling - inference/test_text_inference.py::test_chat_completion_with_tool_calling_streaming - fixtures: - provider_fixtures: - - inference: ollama - - default_fixture_param_id: fireworks - - inference: together + scenarios: + - provider_fixtures: + inference: ollama + - fixture_combo_id: fireworks + - provider_fixtures: + inference: together # - inference: tgi # - inference: vllm_remote - inference_models: - - meta-llama/Llama-3.1-8B-Instruct - - meta-llama/Llama-3.2-11B-Vision-Instruct - safety_shield: ~ - embedding_model: ~ + inference_models: + - meta-llama/Llama-3.1-8B-Instruct + - meta-llama/Llama-3.2-11B-Vision-Instruct -agent: +agents: tests: - agents/test_agents.py::test_agent_turns_with_safety - agents/test_agents.py::test_rag_agent - fixtures: - provider_fixtures: - - default_fixture_param_id: ollama - - default_fixture_param_id: together - - default_fixture_param_id: fireworks + scenarios: + - fixture_combo_id: ollama + - fixture_combo_id: together + - fixture_combo_id: fireworks - safety_shield: ~ - embedding_model: ~ + inference_models: + - meta-llama/Llama-3.2-1B-Instruct - inference_models: - - meta-llama/Llama-3.2-1B-Instruct + safety_shield: meta-llama/Llama-Guard-3-1B memory: tests: - memory/test_memory.py::test_query_documents - fixtures: - provider_fixtures: - - default_fixture_param_id: ollama - - inference: sentence_transformers + scenarios: + - fixture_combo_id: ollama + - provider_fixtures: + inference: sentence_transformers memory: faiss - - default_fixture_param_id: chroma + - fixture_combo_id: chroma - inference_models: - - meta-llama/Llama-3.2-1B-Instruct + inference_models: + - meta-llama/Llama-3.2-1B-Instruct - safety_shield: ~ - embedding_model: ~ + embedding_model: all-MiniLM-L6-v2 diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 6cc8a6772..281a688cc 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -28,28 +28,35 @@ class ProviderFixture(BaseModel): provider_data: Optional[Dict[str, Any]] = None -class Fixtures(BaseModel): +class TestScenario(BaseModel): # provider fixtures can be either a mark or a dictionary of api -> providers - provider_fixtures: List[Dict[str, str]] = Field(default_factory=list) - inference_models: List[str] = Field(default_factory=list) - safety_shield: Optional[str] = Field(default_factory=None) - embedding_model: Optional[str] = Field(default_factory=None) + provider_fixtures: Dict[str, str] = Field(default_factory=dict) + fixture_combo_id: Optional[str] = None class APITestConfig(BaseModel): - fixtures: Fixtures + scenarios: List[TestScenario] = Field(default_factory=list) + inference_models: List[str] = Field(default_factory=list) # test name format should be :: tests: List[str] = Field(default_factory=list) +class MemoryApiTestConfig(APITestConfig): + embedding_model: Optional[str] = Field(default_factory=None) + + +class AgentsApiTestConfig(APITestConfig): + safety_shield: Optional[str] = Field(default_factory=None) + + class TestConfig(BaseModel): - inference: APITestConfig - agent: Optional[APITestConfig] = Field(default=None) - memory: Optional[APITestConfig] = Field(default=None) + inference: APITestConfig = Field(default=None) + agents: AgentsApiTestConfig = Field(default=None) + memory: MemoryApiTestConfig = Field(default=None) -def try_load_config_file_cached(config): +def get_test_config_from_config_file(config): config_file = config.getoption("--config") if config_file is None: return None @@ -64,25 +71,38 @@ def try_load_config_file_cached(config): return TestConfig(**config) -def get_provider_fixtures_from_config( - provider_fixtures_config, default_fixture_combination +def get_test_config_for_api(config, api): + test_config = get_test_config_from_config_file(config) + if test_config is None: + return None + return getattr(test_config, api) + + +def get_provider_fixture_overrides_from_test_config( + config, api, default_provider_fixture_combination ): - custom_fixtures = [] - selected_default_param_id = set() - for fixture_config in provider_fixtures_config: - if "default_fixture_param_id" in fixture_config: - selected_default_param_id.add(fixture_config["default_fixture_param_id"]) + api_config = get_test_config_for_api(config, api) + if api_config is None: + return None + + fixture_combo_ids = set() + custom_provider_fixture_combos = [] + for scenario in api_config.scenarios: + if scenario.fixture_combo_id: + fixture_combo_ids.add(scenario.fixture_combo_id) else: - custom_fixtures.append( - pytest.param(fixture_config, id=fixture_config.get("inference") or "") + custom_provider_fixture_combos.append( + pytest.param( + scenario.provider_fixtures, + id=scenario.provider_fixtures.get("inference") or "", + ) ) - if len(selected_default_param_id) > 0: - for default_fixture in default_fixture_combination: - if default_fixture.id in selected_default_param_id: - custom_fixtures.append(default_fixture) - - return custom_fixtures + if len(fixture_combo_ids) > 0: + for default_fixture in default_provider_fixture_combination: + if default_fixture.id in fixture_combo_ids: + custom_provider_fixture_combos.append(default_fixture) + return custom_provider_fixture_combos def remote_stack_fixture() -> ProviderFixture: @@ -239,16 +259,19 @@ def pytest_itemcollected(item): def pytest_collection_modifyitems(session, config, items): - test_config = try_load_config_file_cached(config) + test_config = get_test_config_from_config_file(config) if test_config is None: return required_tests = defaultdict(set) - test_configs = [test_config.inference, test_config.memory, test_config.agent] - for test_config in test_configs: - if test_config is None: + for api_test_config in [ + test_config.inference, + test_config.memory, + test_config.agents, + ]: + if api_test_config is None: continue - for test in test_config.tests: + for test in api_test_config.tests: arr = test.split("::") if len(arr) != 2: raise ValueError(f"Invalid format for test name {test}") diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py index 1343459e9..1303a1b35 100644 --- a/llama_stack/providers/tests/inference/conftest.py +++ b/llama_stack/providers/tests/inference/conftest.py @@ -6,7 +6,7 @@ import pytest -from ..conftest import get_provider_fixture_overrides, try_load_config_file_cached +from ..conftest import get_provider_fixture_overrides, get_test_config_for_api from .fixtures import INFERENCE_FIXTURES @@ -42,43 +42,43 @@ VISION_MODEL_PARAMS = [ def pytest_generate_tests(metafunc): - test_config = try_load_config_file_cached(metafunc.config) + test_config = get_test_config_for_api(metafunc.config, "inference") + if "inference_model" in metafunc.fixturenames: cls_name = metafunc.cls.__name__ - if test_config is not None: - params = [] - for model in test_config.inference.fixtures.inference_models: - if ("Vision" in cls_name and "Vision" in model) or ( - "Vision" not in cls_name and "Vision" not in model - ): - params.append(pytest.param(model, id=model)) - else: + params = [] + inference_models = getattr(test_config, "inference_models", []) + for model in inference_models: + if ("Vision" in cls_name and "Vision" in model) or ( + "Vision" not in cls_name and "Vision" not in model + ): + params.append(pytest.param(model, id=model)) + + if not params: model = metafunc.config.getoption("--inference-model") - if model: - params = [pytest.param(model, id="")] - else: - if "Vision" in cls_name: - params = VISION_MODEL_PARAMS - else: - params = MODEL_PARAMS + params = [pytest.param(model, id="")] + metafunc.parametrize( "inference_model", params, indirect=True, ) if "inference_stack" in metafunc.fixturenames: - if test_config is not None: - fixtures = [ - (f.get("inference") or f.get("default_fixture_param_id")) - for f in test_config.inference.fixtures.provider_fixtures - ] - elif filtered_stacks := get_provider_fixture_overrides( + fixtures = INFERENCE_FIXTURES + if filtered_stacks := get_provider_fixture_overrides( metafunc.config, { "inference": INFERENCE_FIXTURES, }, ): fixtures = [stack.values[0]["inference"] for stack in filtered_stacks] - else: - fixtures = INFERENCE_FIXTURES + if test_config: + if custom_fixtures := [ + ( + scenario.fixture_combo_id + or scenario.provider_fixtures.get("inference") + ) + for scenario in test_config.scenarios + ]: + fixtures = custom_fixtures metafunc.parametrize("inference_stack", fixtures, indirect=True) diff --git a/llama_stack/providers/tests/memory/conftest.py b/llama_stack/providers/tests/memory/conftest.py index 99fdb7715..87dec4beb 100644 --- a/llama_stack/providers/tests/memory/conftest.py +++ b/llama_stack/providers/tests/memory/conftest.py @@ -8,8 +8,8 @@ import pytest from ..conftest import ( get_provider_fixture_overrides, - get_provider_fixtures_from_config, - try_load_config_file_cached, + get_provider_fixture_overrides_from_test_config, + get_test_config_for_api, ) from ..inference.fixtures import INFERENCE_FIXTURES @@ -69,21 +69,11 @@ def pytest_configure(config): def pytest_generate_tests(metafunc): - test_config = try_load_config_file_cached(metafunc.config) - provider_fixtures_config = ( - test_config.memory.fixtures.provider_fixtures - if test_config is not None and test_config.memory is not None - else None - ) - custom_fixtures = ( - get_provider_fixtures_from_config( - provider_fixtures_config, DEFAULT_PROVIDER_COMBINATIONS - ) - if provider_fixtures_config is not None - else None - ) + test_config = get_test_config_for_api(metafunc.config, "memory") if "embedding_model" in metafunc.fixturenames: - model = metafunc.config.getoption("--embedding-model") + model = getattr(test_config, "embedding_model", None) + # Fall back to the default if not specified by the config file + model = model or metafunc.config.getoption("--embedding-model") if model: params = [pytest.param(model, id="")] else: @@ -97,7 +87,9 @@ def pytest_generate_tests(metafunc): "memory": MEMORY_FIXTURES, } combinations = ( - custom_fixtures + get_provider_fixture_overrides_from_test_config( + metafunc.config, "memory", DEFAULT_PROVIDER_COMBINATIONS + ) or get_provider_fixture_overrides(metafunc.config, available_fixtures) or DEFAULT_PROVIDER_COMBINATIONS ) diff --git a/llama_stack/providers/tests/post_training/test_post_training.py b/llama_stack/providers/tests/post_training/test_post_training.py index 0645cd555..0c58c1fa0 100644 --- a/llama_stack/providers/tests/post_training/test_post_training.py +++ b/llama_stack/providers/tests/post_training/test_post_training.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import pytest -from llama_stack.apis.common.type_system import JobStatus +from llama_stack.apis.common.job_types import JobStatus from llama_stack.apis.post_training import ( Checkpoint, DataConfig,