diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 4e462c54b..68f5ef8e8 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -126,8 +126,13 @@ class CommonRoutingTableImpl(RoutingTable): if routing_key not in self.registry: apiname, objname = apiname_object() + provider_ids = list(self.impls_by_provider_id.keys()) + if len(provider_ids) > 1: + provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" + else: + provider_ids_str = f"provider: `{provider_ids[0]}`" raise ValueError( - f"`{routing_key}` not registered. Make sure there is an {apiname} provider serving this {objname}." + f"`{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}." ) objs = self.registry[routing_key] diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index 1eb23ef6d..4eaf886ab 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -57,7 +57,34 @@ def pytest_configure(config): ) +def pytest_addoption(parser): + parser.addoption( + "--inference-model", + action="store", + default="Llama3.1-8B-Instruct", + help="Specify the inference model to use for testing", + ) + parser.addoption( + "--safety-model", + action="store", + default="Llama-Guard-3-8B", + help="Specify the safety model to use for testing", + ) + + def pytest_generate_tests(metafunc): + if "inference_model" in metafunc.fixturenames: + metafunc.parametrize( + "inference_model", + [pytest.param(metafunc.config.getoption("--inference-model"), id="")], + indirect=True, + ) + if "safety_model" in metafunc.fixturenames: + metafunc.parametrize( + "safety_model", + [pytest.param(metafunc.config.getoption("--safety-model"), id="")], + indirect=True, + ) if "agents_stack" in metafunc.fixturenames: available_fixtures = { "inference": INFERENCE_FIXTURES, diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 5597f47e9..03bbc475a 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -22,7 +22,7 @@ from ..conftest import ProviderFixture @pytest.fixture(scope="session") -def agents_meta_reference(inference_model, safety_model) -> ProviderFixture: +def agents_meta_reference() -> ProviderFixture: sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") return ProviderFixture( provider=Provider( @@ -42,12 +42,12 @@ AGENTS_FIXTURES = ["meta_reference"] @pytest_asyncio.fixture(scope="session") -async def agents_stack(inference_model, safety_model, request): +async def agents_stack(request): fixture_dict = request.param providers = {} provider_data = {} - for key in ["agents", "inference", "safety", "memory"]: + for key in ["inference", "safety", "memory", "agents"]: fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") providers[key] = [fixture.provider] if fixture.provider_data: diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 2d696e4b8..989328409 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -75,12 +75,6 @@ async def create_agent_session(agents_impl, agent_config): return agent_id, session_id -@pytest.mark.parametrize( - "inference_model", [pytest.param("Llama3.1-8B-Instruct", id="")], indirect=True -) -@pytest.mark.parametrize( - "safety_model", [pytest.param("Llama-Guard-3-8B", id="")], indirect=True -) class TestAgents: @pytest.mark.asyncio async def test_agent_turns_with_safety(self, agents_stack, common_params): diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py index 34b2ce267..71253871d 100644 --- a/llama_stack/providers/tests/inference/conftest.py +++ b/llama_stack/providers/tests/inference/conftest.py @@ -4,9 +4,20 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import pytest + from .fixtures import INFERENCE_FIXTURES +def pytest_addoption(parser): + parser.addoption( + "--inference-model", + action="store", + default=None, + help="Specify the inference model to use for testing", + ) + + def pytest_configure(config): config.addinivalue_line( "markers", "llama_8b: mark test to run only with the given model" @@ -19,3 +30,33 @@ def pytest_configure(config): "markers", f"{fixture_name}: marks tests as {fixture_name} specific", ) + + +MODEL_PARAMS = [ + pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"), + pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"), +] + + +def pytest_generate_tests(metafunc): + if "inference_model" in metafunc.fixturenames: + model = metafunc.config.getoption("--inference-model") + if model: + params = [pytest.param(model, id="")] + else: + params = MODEL_PARAMS + + metafunc.parametrize( + "inference_model", + params, + indirect=True, + ) + if "inference_stack" in metafunc.fixturenames: + metafunc.parametrize( + "inference_stack", + [ + pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name)) + for fixture_name in INFERENCE_FIXTURES + ], + indirect=True, + ) diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index b5a8d1ad0..e53034370 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -22,19 +22,14 @@ from ..conftest import ProviderFixture from ..env import get_env_or_fail -MODEL_PARAMS = [ - pytest.param("Llama3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"), - pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"), -] - - -@pytest.fixture(scope="session", params=MODEL_PARAMS) +@pytest.fixture(scope="session") def inference_model(request): - return request.param + return request.config.getoption("--inference-model", None) or request.param @pytest.fixture(scope="session") def inference_meta_reference(inference_model) -> ProviderFixture: + # TODO: make this work with multiple models return ProviderFixture( provider=Provider( provider_id="meta-reference", @@ -66,7 +61,7 @@ def inference_ollama(inference_model) -> ProviderFixture: @pytest.fixture(scope="session") -def inference_fireworks(inference_model) -> ProviderFixture: +def inference_fireworks() -> ProviderFixture: return ProviderFixture( provider=Provider( provider_id="fireworks", @@ -79,7 +74,7 @@ def inference_fireworks(inference_model) -> ProviderFixture: @pytest.fixture(scope="session") -def inference_together(inference_model) -> ProviderFixture: +def inference_together() -> ProviderFixture: return ProviderFixture( provider=Provider( provider_id="together", @@ -95,7 +90,7 @@ def inference_together(inference_model) -> ProviderFixture: INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together"] -@pytest_asyncio.fixture(scope="session", params=INFERENCE_FIXTURES) +@pytest_asyncio.fixture(scope="session") async def inference_stack(request): fixture_name = request.param inference_fixture = request.getfixturevalue(f"inference_{fixture_name}") diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index 4051cea69..29fdc43a4 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -14,7 +14,6 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 -from .fixtures import INFERENCE_FIXTURES, MODEL_PARAMS # How to run this test: # @@ -70,15 +69,6 @@ def sample_tool_definition(): ) -@pytest.mark.parametrize("inference_model", MODEL_PARAMS, indirect=True) -@pytest.mark.parametrize( - "inference_stack", - [ - pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name)) - for fixture_name in INFERENCE_FIXTURES - ], - indirect=True, -) class TestInference: @pytest.mark.asyncio async def test_model_list(self, inference_model, inference_stack): diff --git a/llama_stack/providers/tests/memory/conftest.py b/llama_stack/providers/tests/memory/conftest.py index c5057ecb4..99ecbe794 100644 --- a/llama_stack/providers/tests/memory/conftest.py +++ b/llama_stack/providers/tests/memory/conftest.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import pytest + from .fixtures import MEMORY_FIXTURES @@ -13,3 +15,15 @@ def pytest_configure(config): "markers", f"{fixture_name}: marks tests as {fixture_name} specific", ) + + +def pytest_generate_tests(metafunc): + if "memory_stack" in metafunc.fixturenames: + metafunc.parametrize( + "memory_stack", + [ + pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name)) + for fixture_name in MEMORY_FIXTURES + ], + indirect=True, + ) diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index 1f1050df4..11ed121c7 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -64,16 +64,8 @@ def memory_weaviate() -> ProviderFixture: MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"] -PROVIDER_PARAMS = [ - pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name)) - for fixture_name in MEMORY_FIXTURES -] - -@pytest_asyncio.fixture( - scope="session", - params=PROVIDER_PARAMS, -) +@pytest_asyncio.fixture(scope="session") async def memory_stack(request): fixture_name = request.param fixture = request.getfixturevalue(f"memory_{fixture_name}") diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index a948fa17e..ee3110dea 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -8,7 +8,6 @@ import pytest from llama_stack.apis.memory import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 -from .fixtures import PROVIDER_PARAMS # How to run this test: # @@ -54,11 +53,6 @@ async def register_memory_bank(banks_impl: MemoryBanks): await banks_impl.register_memory_bank(bank) -@pytest.mark.parametrize( - "memory_stack", - PROVIDER_PARAMS, - indirect=True, -) class TestMemory: @pytest.mark.asyncio async def test_banks_list(self, memory_stack): diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index 25f13b1a4..d64a2a1ee 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -48,11 +48,38 @@ def pytest_configure(config): ) +def pytest_addoption(parser): + parser.addoption( + "--safety-model", + action="store", + default=None, + help="Specify the safety model to use for testing", + ) + + +SAFETY_MODEL_PARAMS = [ + pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"), +] + + def pytest_generate_tests(metafunc): # We use this method to make sure we have built-in simple combos for safety tests # But a user can also pass in a custom combination via the CLI by doing # `--providers inference=together,safety=meta_reference` + if "safety_model" in metafunc.fixturenames: + model = metafunc.config.getoption("--safety-model") + if model: + params = [pytest.param(model, id="")] + else: + params = SAFETY_MODEL_PARAMS + for fixture in ["inference_model", "safety_model"]: + metafunc.parametrize( + fixture, + params, + indirect=True, + ) + if "safety_stack" in metafunc.fixturenames: # print(f"metafunc.fixturenames: {metafunc.fixturenames}, {metafunc}") available_fixtures = { diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index cf23e032b..17c5262de 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -20,14 +20,9 @@ from ..conftest import ProviderFixture from ..env import get_env_or_fail -SAFETY_MODEL_PARAMS = [ - pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"), -] - - -@pytest.fixture(scope="session", params=SAFETY_MODEL_PARAMS) +@pytest.fixture(scope="session") def safety_model(request): - return request.param + return request.config.getoption("--safety-model", None) or request.param @pytest.fixture(scope="session") diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index e355d5908..ddf472737 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -17,14 +17,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403 # -m "ollama" -@pytest.mark.parametrize( - "inference_model", [pytest.param("Llama-Guard-3-1B", id="")], indirect=True -) -@pytest.mark.parametrize( - "safety_model", - [pytest.param("Llama-Guard-3-1B", id="guard_3_1b")], - indirect=True, -) class TestSafety: @pytest.mark.asyncio async def test_shield_list(self, safety_stack):