diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index e02606936..889bd4624 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -246,7 +246,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_id = list(self.impls_by_provider_id.keys())[0] else: raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" + f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" ) if metadata is None: metadata = {} diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index da0b93557..347c317dc 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -10,6 +10,10 @@ from ..conftest import get_provider_fixture_overrides from ..inference.fixtures import INFERENCE_FIXTURES from ..memory.fixtures import MEMORY_FIXTURES from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield +from ..test_config_helper import ( + get_provider_fixtures_from_config, + try_load_config_file_cached, +) from ..tools.fixtures import TOOL_RUNTIME_FIXTURES from .fixtures import AGENTS_FIXTURES @@ -82,7 +86,25 @@ def pytest_configure(config): def pytest_generate_tests(metafunc): - shield_id = metafunc.config.getoption("--safety-shield") + test_config = try_load_config_file_cached(metafunc.config.getoption("config")) + ( + config_override_inference_models, + config_override_safety_shield, + custom_provider_fixtures, + ) = (None, None, None) + if test_config is not None: + config_override_inference_models = test_config.agent.fixtures.inference_models + config_override_safety_shield = test_config.agent.fixtures.safety_shield + custom_provider_fixtures = get_provider_fixtures_from_config( + test_config.agent.fixtures.provider_fixtures, DEFAULT_PROVIDER_COMBINATIONS + ) + + shield_id = config_override_safety_shield or metafunc.config.getoption( + "--safety-shield" + ) + inference_model = config_override_inference_models or [ + metafunc.config.getoption("--inference-model") + ] if "safety_shield" in metafunc.fixturenames: metafunc.parametrize( "safety_shield", @@ -90,8 +112,7 @@ def pytest_generate_tests(metafunc): indirect=True, ) if "inference_model" in metafunc.fixturenames: - inference_model = metafunc.config.getoption("--inference-model") - models = set({inference_model}) + models = set(inference_model) if safety_model := safety_model_from_shield(shield_id): models.add(safety_model) @@ -109,7 +130,8 @@ def pytest_generate_tests(metafunc): "tool_runtime": TOOL_RUNTIME_FIXTURES, } combinations = ( - get_provider_fixture_overrides(metafunc.config, available_fixtures) + custom_provider_fixtures + or get_provider_fixture_overrides(metafunc.config, available_fixtures) or DEFAULT_PROVIDER_COMBINATIONS ) metafunc.parametrize("agents_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/ci_test_config.yaml b/llama_stack/providers/tests/ci_test_config.yaml index c4c6b2319..22700da08 100644 --- a/llama_stack/providers/tests/ci_test_config.yaml +++ b/llama_stack/providers/tests/ci_test_config.yaml @@ -1,24 +1,59 @@ -tests: -- path: inference/test_vision_inference.py - functions: - - test_vision_chat_completion_streaming - - test_vision_chat_completion_non_streaming +inference: + tests: + - inference/test_vision_inference.py::test_vision_chat_completion_streaming + - inference/test_vision_inference.py::test_vision_chat_completion_non_streaming + - inference/test_text_inference.py::test_structured_output + - inference/test_text_inference.py::test_chat_completion_streaming + - inference/test_text_inference.py::test_chat_completion_non_streaming + - inference/test_text_inference.py::test_chat_completion_with_tool_calling + - inference/test_text_inference.py::test_chat_completion_with_tool_calling_streaming -- path: inference/test_text_inference.py - functions: - - test_structured_output - - test_chat_completion_streaming - - test_chat_completion_non_streaming - - test_chat_completion_with_tool_calling - - test_chat_completion_with_tool_calling_streaming + fixtures: + provider_fixtures: + - inference: ollama + - default_fixture_param_id: fireworks + - inference: together + # - inference: tgi + # - inference: vllm_remote + inference_models: + - meta-llama/Llama-3.1-8B-Instruct + - meta-llama/Llama-3.2-11B-Vision-Instruct -inference_fixtures: - - ollama - - fireworks - - together - - tgi - - vllm_remote + safety_shield: ~ + embedding_model: ~ -test_models: - text: meta-llama/Llama-3.1-8B-Instruct - vision: meta-llama/Llama-3.2-11B-Vision-Instruct + +agent: + tests: + - agents/test_agents.py::test_agent_turns_with_safety + - agents/test_agents.py::test_rag_agent + + fixtures: + provider_fixtures: + - default_fixture_param_id: ollama + - default_fixture_param_id: together + - default_fixture_param_id: fireworks + + safety_shield: ~ + embedding_model: ~ + + inference_models: + - meta-llama/Llama-3.2-1B-Instruct + + +memory: + tests: + - memory/test_memory.py::test_query_documents + + fixtures: + provider_fixtures: + - default_fixture_param_id: ollama + - inference: sentence_transformers + memory: faiss + - default_fixture_param_id: chroma + + inference_models: + - meta-llama/Llama-3.2-1B-Instruct + + safety_shield: ~ + embedding_model: ~ diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index dbdbad033..14c34d14e 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -5,12 +5,12 @@ # the root directory of this source tree. import os +from collections import defaultdict from pathlib import Path from typing import Any, Dict, List, Optional import pytest -import yaml from dotenv import load_dotenv from pydantic import BaseModel from termcolor import colored @@ -20,6 +20,8 @@ from llama_stack.providers.datatypes import RemoteProviderConfig from .env import get_env_or_fail +from .test_config_helper import try_load_config_file_cached + class ProviderFixture(BaseModel): providers: List[Provider] @@ -180,34 +182,26 @@ def pytest_itemcollected(item): def pytest_collection_modifyitems(session, config, items): - if config.getoption("--config") is None: + test_config = try_load_config_file_cached(config.getoption("--config")) + if test_config is None: return - file_name = config.getoption("--config") - config_file_path = Path(__file__).parent / file_name - if not config_file_path.exists(): - raise ValueError( - f"Test config {file_name} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory." - ) - required_tests = dict() - inference_providers = set() - with open(config_file_path, "r") as config_file: - test_config = yaml.safe_load(config_file) - for test in test_config["tests"]: - required_tests[Path(__file__).parent / test["path"]] = set( - test["functions"] - ) - inference_providers = set(test_config["inference_fixtures"]) + required_tests = defaultdict(set) + test_configs = [test_config.inference, test_config.memory, test_config.agent] + for test_config in test_configs: + for test in test_config.tests: + arr = test.split("::") + if len(arr) != 2: + raise ValueError(f"Invalid format for test name {test}") + test_path, func_name = arr + required_tests[Path(__file__).parent / test_path].add(func_name) new_items, deselected_items = [], [] for item in items: - if item.fspath in required_tests: - func_name = getattr(item, "originalname", item.name) - if func_name in required_tests[item.fspath]: - inference = item.callspec.params.get("inference_stack") - if inference in inference_providers: - new_items.append(item) - continue + func_name = getattr(item, "originalname", item.name) + if func_name in required_tests[item.fspath]: + new_items.append(item) + continue deselected_items.append(item) items[:] = new_items diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py index 88568dce9..fca4f7544 100644 --- a/llama_stack/providers/tests/inference/conftest.py +++ b/llama_stack/providers/tests/inference/conftest.py @@ -7,7 +7,7 @@ import pytest from ..conftest import get_provider_fixture_overrides - +from ..test_config_helper import try_load_config_file_cached from .fixtures import INFERENCE_FIXTURES @@ -43,29 +43,43 @@ VISION_MODEL_PARAMS = [ def pytest_generate_tests(metafunc): + test_config = try_load_config_file_cached(metafunc.config.getoption("config")) if "inference_model" in metafunc.fixturenames: - model = metafunc.config.getoption("--inference-model") - if model: - params = [pytest.param(model, id="")] + cls_name = metafunc.cls.__name__ + if test_config is not None: + params = [] + for model in test_config.inference.fixtures.inference_models: + if ("Vision" in cls_name and "Vision" in model) or ( + "Vision" not in cls_name and "Vision" not in model + ): + params.append(pytest.param(model, id=model)) else: - cls_name = metafunc.cls.__name__ - if "Vision" in cls_name: - params = VISION_MODEL_PARAMS + model = metafunc.config.getoption("--inference-model") + if model: + params = [pytest.param(model, id="")] else: - params = MODEL_PARAMS - + if "Vision" in cls_name: + params = VISION_MODEL_PARAMS + else: + params = MODEL_PARAMS metafunc.parametrize( "inference_model", params, indirect=True, ) if "inference_stack" in metafunc.fixturenames: - fixtures = INFERENCE_FIXTURES - if filtered_stacks := get_provider_fixture_overrides( + if test_config is not None: + fixtures = [ + (f.get("inference") or f.get("default_fixture_param_id")) + for f in test_config.inference.fixtures.provider_fixtures + ] + elif filtered_stacks := get_provider_fixture_overrides( metafunc.config, { "inference": INFERENCE_FIXTURES, }, ): fixtures = [stack.values[0]["inference"] for stack in filtered_stacks] + else: + fixtures = INFERENCE_FIXTURES metafunc.parametrize("inference_stack", fixtures, indirect=True) diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index b6653b65d..0767e940f 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -301,6 +301,7 @@ async def inference_stack(request, inference_model): inference_fixture.provider_data, models=[ ModelInput( + provider_id=inference_fixture.providers[0].provider_id, model_id=inference_model, model_type=model_type, metadata=metadata, diff --git a/llama_stack/providers/tests/memory/conftest.py b/llama_stack/providers/tests/memory/conftest.py index afe694b70..fef09319a 100644 --- a/llama_stack/providers/tests/memory/conftest.py +++ b/llama_stack/providers/tests/memory/conftest.py @@ -9,6 +9,10 @@ import pytest from ..conftest import get_provider_fixture_overrides from ..inference.fixtures import INFERENCE_FIXTURES +from ..test_config_helper import ( + get_provider_fixtures_from_config, + try_load_config_file_cached, +) from .fixtures import MEMORY_FIXTURES @@ -65,6 +69,15 @@ def pytest_configure(config): def pytest_generate_tests(metafunc): + test_config = try_load_config_file_cached(metafunc.config.getoption("config")) + provider_fixtures_config = ( + test_config.memory.fixtures.provider_fixtures + if test_config is not None + else None + ) + custom_fixtures = get_provider_fixtures_from_config( + provider_fixtures_config, DEFAULT_PROVIDER_COMBINATIONS + ) if "embedding_model" in metafunc.fixturenames: model = metafunc.config.getoption("--embedding-model") if model: @@ -80,7 +93,8 @@ def pytest_generate_tests(metafunc): "memory": MEMORY_FIXTURES, } combinations = ( - get_provider_fixture_overrides(metafunc.config, available_fixtures) + custom_fixtures + or get_provider_fixture_overrides(metafunc.config, available_fixtures) or DEFAULT_PROVIDER_COMBINATIONS ) metafunc.parametrize("memory_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/test_config_helper.py b/llama_stack/providers/tests/test_config_helper.py new file mode 100644 index 000000000..fd51953ba --- /dev/null +++ b/llama_stack/providers/tests/test_config_helper.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional + +import pytest +import yaml +from pydantic import BaseModel, Field + + +@dataclass +class APITestConfig(BaseModel): + + class Fixtures(BaseModel): + # provider fixtures can be either a mark or a dictionary of api -> providers + provider_fixtures: List[Dict[str, str]] + inference_models: List[str] = Field(default_factory=list) + safety_shield: Optional[str] + embedding_model: Optional[str] + + fixtures: Fixtures + tests: List[str] = Field(default_factory=list) + + # test name format should be :: + + +class TestConfig(BaseModel): + + inference: APITestConfig + agent: APITestConfig + memory: APITestConfig + + +CONFIG_CACHE = None + + +def try_load_config_file_cached(config_file): + if config_file is None: + return None + if CONFIG_CACHE is not None: + return CONFIG_CACHE + + config_file_path = Path(__file__).parent / config_file + if not config_file_path.exists(): + raise ValueError( + f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory." + ) + with open(config_file_path, "r") as config_file: + config = yaml.safe_load(config_file) + return TestConfig(**config) + + +def get_provider_fixtures_from_config( + provider_fixtures_config, default_fixture_combination +): + custom_fixtures = [] + selected_default_param_id = set() + for fixture_config in provider_fixtures_config: + if "default_fixture_param_id" in fixture_config: + selected_default_param_id.add(fixture_config["default_fixture_param_id"]) + else: + custom_fixtures.append( + pytest.param(fixture_config, id=fixture_config.get("inference") or "") + ) + + if len(selected_default_param_id) > 0: + for default_fixture in default_fixture_combination: + if default_fixture.id in selected_default_param_id: + custom_fixtures.append(default_fixture) + + return custom_fixtures