From 6c7ea6e904cc2e3a655c62ace081e497b3f65d88 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 4 Nov 2024 16:33:42 -0800
Subject: [PATCH] get multiple providers working for meta-reference (for
 inference + safety)

---
 .../distribution/routers/routing_tables.py    |  2 +-
 .../providers/tests/agents/conftest.py        | 18 ++--
 .../providers/tests/agents/fixtures.py        | 24 +++---
 .../providers/tests/agents/test_agents.py     |  7 ++
 llama_stack/providers/tests/conftest.py       |  2 +-
 .../providers/tests/inference/fixtures.py     | 85 +++++++++++--------
 .../providers/tests/memory/fixtures.py        | 50 ++++++-----
 llama_stack/providers/tests/resolver.py       |  2 +-
 .../providers/tests/safety/conftest.py        |  1 -
 .../providers/tests/safety/fixtures.py        | 40 +++++----
 10 files changed, 136 insertions(+), 95 deletions(-)

diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 68f5ef8e8..ba3814123 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -132,7 +132,7 @@ class CommonRoutingTableImpl(RoutingTable):
             else:
                 provider_ids_str = f"provider: `{provider_ids[0]}`"
             raise ValueError(
-                f"`{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
+                f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
             )
 
         objs = self.registry[routing_key]
diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py
index 4eaf886ab..332efeed8 100644
--- a/llama_stack/providers/tests/agents/conftest.py
+++ b/llama_stack/providers/tests/agents/conftest.py
@@ -73,16 +73,20 @@ def pytest_addoption(parser):
 
 
 def pytest_generate_tests(metafunc):
-    if "inference_model" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "inference_model",
-            [pytest.param(metafunc.config.getoption("--inference-model"), id="")],
-            indirect=True,
-        )
+    safety_model = metafunc.config.getoption("--safety-model")
     if "safety_model" in metafunc.fixturenames:
         metafunc.parametrize(
             "safety_model",
-            [pytest.param(metafunc.config.getoption("--safety-model"), id="")],
+            [pytest.param(safety_model, id="")],
+            indirect=True,
+        )
+    if "inference_model" in metafunc.fixturenames:
+        inference_model = metafunc.config.getoption("--inference-model")
+        models = list(set({inference_model, safety_model}))
+
+        metafunc.parametrize(
+            "inference_model",
+            [pytest.param(models, id="")],
             indirect=True,
         )
     if "agents_stack" in metafunc.fixturenames:
diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py
index 03bbc475a..c667712a7 100644
--- a/llama_stack/providers/tests/agents/fixtures.py
+++ b/llama_stack/providers/tests/agents/fixtures.py
@@ -25,16 +25,18 @@ from ..conftest import ProviderFixture
 def agents_meta_reference() -> ProviderFixture:
     sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=MetaReferenceAgentsImplConfig(
-                # TODO: make this an in-memory store
-                persistence_store=SqliteKVStoreConfig(
-                    db_path=sqlite_file.name,
-                ),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=MetaReferenceAgentsImplConfig(
+                    # TODO: make this an in-memory store
+                    persistence_store=SqliteKVStoreConfig(
+                        db_path=sqlite_file.name,
+                    ),
+                ).model_dump(),
+            )
+        ],
     )
 
 
@@ -49,7 +51,7 @@ async def agents_stack(request):
     provider_data = {}
     for key in ["inference", "safety", "memory", "agents"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
-        providers[key] = [fixture.provider]
+        providers[key] = fixture.providers
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)
 
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index 989328409..54c10a42d 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -19,6 +19,13 @@ from llama_stack.providers.datatypes import *  # noqa: F403
 
 @pytest.fixture
 def common_params(inference_model):
+    # This is not entirely satisfactory. The fixture `inference_model` can correspond
+    # to multiple models when you need to run a safety model in addition to the normal
+    # agent inference model. We filter out the safety model by looking for "Llama-Guard".
+    if isinstance(inference_model, list):
+        inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
+        assert inference_model is not None
+
     return dict(
         model=inference_model,
         instructions="You are a helpful assistant.",
diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py
index 6861a29fd..9fdf94582 100644
--- a/llama_stack/providers/tests/conftest.py
+++ b/llama_stack/providers/tests/conftest.py
@@ -17,7 +17,7 @@ from llama_stack.distribution.datatypes import Provider
 
 
 class ProviderFixture(BaseModel):
-    provider: Provider
+    providers: List[Provider]
     provider_data: Optional[Dict[str, Any]] = None
 
 
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index e53034370..860eea4b2 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -24,63 +24,80 @@ from ..env import get_env_or_fail
 
 @pytest.fixture(scope="session")
 def inference_model(request):
-    return request.config.getoption("--inference-model", None) or request.param
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--inference-model", None)
 
 
 @pytest.fixture(scope="session")
 def inference_meta_reference(inference_model) -> ProviderFixture:
-    # TODO: make this work with multiple models
+    inference_model = (
+        [inference_model] if isinstance(inference_model, str) else inference_model
+    )
+
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=MetaReferenceInferenceConfig(
-                model=inference_model,
-                max_seq_len=512,
-                create_distributed_process_group=False,
-                checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id=f"meta-reference-{i}",
+                provider_type="meta-reference",
+                config=MetaReferenceInferenceConfig(
+                    model=m,
+                    max_seq_len=4096,
+                    create_distributed_process_group=False,
+                    checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
+                ).model_dump(),
+            )
+            for i, m in enumerate(inference_model)
+        ]
     )
 
 
 @pytest.fixture(scope="session")
 def inference_ollama(inference_model) -> ProviderFixture:
-    if inference_model == "Llama3.1-8B-Instruct":
-        pytest.skip("Ollama only support Llama3.2-3B-Instruct for testing")
+    inference_model = (
+        [inference_model] if isinstance(inference_model, str) else inference_model
+    )
+    if "Llama3.1-8B-Instruct" in inference_model:
+        pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
 
     return ProviderFixture(
-        provider=Provider(
-            provider_id="ollama",
-            provider_type="remote::ollama",
-            config=OllamaImplConfig(
-                host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="ollama",
+                provider_type="remote::ollama",
+                config=OllamaImplConfig(
+                    host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
+                ).model_dump(),
+            )
+        ],
     )
 
 
 @pytest.fixture(scope="session")
 def inference_fireworks() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="fireworks",
-            provider_type="remote::fireworks",
-            config=FireworksImplConfig(
-                api_key=get_env_or_fail("FIREWORKS_API_KEY"),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="fireworks",
+                provider_type="remote::fireworks",
+                config=FireworksImplConfig(
+                    api_key=get_env_or_fail("FIREWORKS_API_KEY"),
+                ).model_dump(),
+            )
+        ],
     )
 
 
 @pytest.fixture(scope="session")
 def inference_together() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="together",
-            provider_type="remote::together",
-            config=TogetherImplConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="together",
+                provider_type="remote::together",
+                config=TogetherImplConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
         ),
@@ -96,7 +113,7 @@ async def inference_stack(request):
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     impls = await resolve_impls_for_test_v2(
         [Api.inference],
-        {"inference": [inference_fixture.provider.model_dump()]},
+        {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
     )
 
diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py
index 11ed121c7..4a6642e85 100644
--- a/llama_stack/providers/tests/memory/fixtures.py
+++ b/llama_stack/providers/tests/memory/fixtures.py
@@ -22,39 +22,45 @@ from ..env import get_env_or_fail
 
 @pytest.fixture(scope="session")
 def memory_meta_reference() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=FaissImplConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=FaissImplConfig().model_dump(),
+            )
+        ],
     )
 
 
 @pytest.fixture(scope="session")
 def memory_pgvector() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="pgvector",
-            provider_type="remote::pgvector",
-            config=PGVectorConfig(
-                host=os.getenv("PGVECTOR_HOST", "localhost"),
-                port=os.getenv("PGVECTOR_PORT", 5432),
-                db=get_env_or_fail("PGVECTOR_DB"),
-                user=get_env_or_fail("PGVECTOR_USER"),
-                password=get_env_or_fail("PGVECTOR_PASSWORD"),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="pgvector",
+                provider_type="remote::pgvector",
+                config=PGVectorConfig(
+                    host=os.getenv("PGVECTOR_HOST", "localhost"),
+                    port=os.getenv("PGVECTOR_PORT", 5432),
+                    db=get_env_or_fail("PGVECTOR_DB"),
+                    user=get_env_or_fail("PGVECTOR_USER"),
+                    password=get_env_or_fail("PGVECTOR_PASSWORD"),
+                ).model_dump(),
+            )
+        ],
     )
 
 
 @pytest.fixture(scope="session")
 def memory_weaviate() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="weaviate",
-            provider_type="remote::weaviate",
-            config=WeaviateConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="weaviate",
+                provider_type="remote::weaviate",
+                config=WeaviateConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
             weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),
@@ -72,7 +78,7 @@ async def memory_stack(request):
 
     impls = await resolve_impls_for_test_v2(
         [Api.memory],
-        {"memory": [fixture.provider.model_dump()]},
+        {"memory": fixture.providers},
         fixture.provider_data,
     )
 
diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py
index a03b25aba..2d6805b35 100644
--- a/llama_stack/providers/tests/resolver.py
+++ b/llama_stack/providers/tests/resolver.py
@@ -20,7 +20,7 @@ from llama_stack.distribution.resolver import resolve_impls
 
 async def resolve_impls_for_test_v2(
     apis: List[Api],
-    providers: Dict[str, Provider],
+    providers: Dict[str, List[Provider]],
     provider_data: Optional[Dict[str, Any]] = None,
 ):
     run_config = dict(
diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py
index d64a2a1ee..c5424f8db 100644
--- a/llama_stack/providers/tests/safety/conftest.py
+++ b/llama_stack/providers/tests/safety/conftest.py
@@ -81,7 +81,6 @@ def pytest_generate_tests(metafunc):
         )
 
     if "safety_stack" in metafunc.fixturenames:
-        # print(f"metafunc.fixturenames: {metafunc.fixturenames}, {metafunc}")
         available_fixtures = {
             "inference": INFERENCE_FIXTURES,
             "safety": SAFETY_FIXTURES,
diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py
index 17c5262de..463c53d2c 100644
--- a/llama_stack/providers/tests/safety/fixtures.py
+++ b/llama_stack/providers/tests/safety/fixtures.py
@@ -22,32 +22,38 @@ from ..env import get_env_or_fail
 
 @pytest.fixture(scope="session")
 def safety_model(request):
-    return request.config.getoption("--safety-model", None) or request.param
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--safety-model", None)
 
 
 @pytest.fixture(scope="session")
 def safety_meta_reference(safety_model) -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="meta-reference",
-            provider_type="meta-reference",
-            config=SafetyConfig(
-                llama_guard_shield=LlamaGuardShieldConfig(
-                    model=safety_model,
-                ),
-            ).model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config=SafetyConfig(
+                    llama_guard_shield=LlamaGuardShieldConfig(
+                        model=safety_model,
+                    ),
+                ).model_dump(),
+            )
+        ],
     )
 
 
 @pytest.fixture(scope="session")
 def safety_together() -> ProviderFixture:
     return ProviderFixture(
-        provider=Provider(
-            provider_id="together",
-            provider_type="remote::together",
-            config=TogetherSafetyConfig().model_dump(),
-        ),
+        providers=[
+            Provider(
+                provider_id="together",
+                provider_type="remote::together",
+                config=TogetherSafetyConfig().model_dump(),
+            )
+        ],
         provider_data=dict(
             together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
         ),
@@ -67,8 +73,8 @@ async def safety_stack(inference_model, safety_model, request):
     safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}")
 
     providers = {
-        "inference": [inference_fixture.provider],
-        "safety": [safety_fixture.provider],
+        "inference": inference_fixture.providers,
+        "safety": safety_fixture.providers,
    }
     provider_data = {}
     if inference_fixture.provider_data:
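
Reviewer note, not part of the patch: a minimal standalone sketch of the model-list
plumbing this change introduces, using two placeholder model IDs. The agents conftest
now parametrizes `inference_model` with the deduplicated union of the agent model and
the safety model, and `common_params` later filters the guard model back out by
matching on "Llama-Guard":

    # Illustrative sketch only; both model IDs below are hypothetical CLI values.
    inference_model = "Llama3.1-8B-Instruct"  # --inference-model
    safety_model = "Llama-Guard-3-1B"         # --safety-model

    # pytest_generate_tests: fold both models into one deduplicated list so a
    # single test stack can serve agent inference and the safety shield.
    models = list({inference_model, safety_model})

    # common_params: recover the agent model by filtering out the guard model.
    agent_model = next(m for m in models if "Llama-Guard" not in m)
    assert agent_model == "Llama3.1-8B-Instruct"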