From 62dd3b376ca24c6e0e592d6ba01a11232c64dd98 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sun, 3 Nov 2024 16:11:43 -0800 Subject: [PATCH] Fix memory to use the newer fixture organization --- llama_stack/providers/tests/conftest.py | 1 + .../providers/tests/memory/conftest.py | 82 +---------------- .../providers/tests/memory/fixtures.py | 87 +++++++++++++++++++ .../tests/memory/provider_config_example.yaml | 29 ------- .../providers/tests/memory/test_memory.py | 16 ++-- .../providers/tests/safety/conftest.py | 4 + .../providers/tests/safety/fixtures.py | 1 + .../tests/safety/provider_config_example.yaml | 19 ---- 8 files changed, 102 insertions(+), 137 deletions(-) create mode 100644 llama_stack/providers/tests/memory/fixtures.py delete mode 100644 llama_stack/providers/tests/memory/provider_config_example.yaml delete mode 100644 llama_stack/providers/tests/safety/provider_config_example.yaml diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index abc784a01..8647baddd 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -129,4 +129,5 @@ def pytest_itemcollected(item): pytest_plugins = [ "llama_stack.providers.tests.inference.fixtures", "llama_stack.providers.tests.safety.fixtures", + "llama_stack.providers.tests.memory.fixtures", ] diff --git a/llama_stack/providers/tests/memory/conftest.py b/llama_stack/providers/tests/memory/conftest.py index 1a85fe17b..c5057ecb4 100644 --- a/llama_stack/providers/tests/memory/conftest.py +++ b/llama_stack/providers/tests/memory/conftest.py @@ -4,87 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os - -import pytest -import pytest_asyncio - -from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.adapters.memory.pgvector import PGVectorConfig -from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig -from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig - -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 -from ..conftest import ProviderFixture -from ..env import get_env_or_fail - - -@pytest.fixture(scope="session") -def meta_reference() -> ProviderFixture: - return ProviderFixture( - provider=Provider( - provider_id="meta-reference", - provider_type="meta-reference", - config=FaissImplConfig().model_dump(), - ), - ) - - -@pytest.fixture(scope="session") -def pgvector() -> ProviderFixture: - return ProviderFixture( - provider=Provider( - provider_id="pgvector", - provider_type="remote::pgvector", - config=PGVectorConfig( - host=os.getenv("PGVECTOR_HOST", "localhost"), - port=os.getenv("PGVECTOR_PORT", 5432), - db=get_env_or_fail("PGVECTOR_DB"), - user=get_env_or_fail("PGVECTOR_USER"), - password=get_env_or_fail("PGVECTOR_PASSWORD"), - ).model_dump(), - ), - ) - - -@pytest.fixture(scope="session") -def weaviate() -> ProviderFixture: - return ProviderFixture( - provider=Provider( - provider_id="weaviate", - provider_type="remote::weaviate", - config=WeaviateConfig().model_dump(), - ), - provider_data=dict( - weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"), - weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"), - ), - ) - - -MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"] - -PROVIDER_PARAMS = [ - pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name)) - for fixture_name in MEMORY_FIXTURES -] - - -@pytest_asyncio.fixture( - scope="session", - params=PROVIDER_PARAMS, -) -async def stack_impls(request): - fixture_name = request.param - fixture = request.getfixturevalue(fixture_name) - - impls = await resolve_impls_for_test_v2( - [Api.memory], - {"memory": [fixture.provider.model_dump()]}, - fixture.provider_data, - ) - - return impls[Api.memory], impls[Api.memory_banks] +from .fixtures import MEMORY_FIXTURES def pytest_configure(config): diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py new file mode 100644 index 000000000..79ab2ffff --- /dev/null +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os + +import pytest +import pytest_asyncio + +from llama_stack.distribution.datatypes import Api, Provider +from llama_stack.providers.adapters.memory.pgvector import PGVectorConfig +from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig +from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig + +from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from ..conftest import ProviderFixture +from ..env import get_env_or_fail + + +@pytest.fixture(scope="session") +def meta_reference() -> ProviderFixture: + return ProviderFixture( + provider=Provider( + provider_id="meta-reference", + provider_type="meta-reference", + config=FaissImplConfig().model_dump(), + ), + ) + + +@pytest.fixture(scope="session") +def pgvector() -> ProviderFixture: + return ProviderFixture( + provider=Provider( + provider_id="pgvector", + provider_type="remote::pgvector", + config=PGVectorConfig( + host=os.getenv("PGVECTOR_HOST", "localhost"), + port=os.getenv("PGVECTOR_PORT", 5432), + db=get_env_or_fail("PGVECTOR_DB"), + user=get_env_or_fail("PGVECTOR_USER"), + password=get_env_or_fail("PGVECTOR_PASSWORD"), + ).model_dump(), + ), + ) + + +@pytest.fixture(scope="session") +def weaviate() -> ProviderFixture: + return ProviderFixture( + provider=Provider( + provider_id="weaviate", + provider_type="remote::weaviate", + config=WeaviateConfig().model_dump(), + ), + provider_data=dict( + weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"), + weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"), + ), + ) + + +MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"] + +PROVIDER_PARAMS = [ + pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name)) + for fixture_name in MEMORY_FIXTURES +] + + +@pytest_asyncio.fixture( + scope="session", + params=PROVIDER_PARAMS, +) +async def memory_stack(request): + fixture_name = request.param + fixture = request.getfixturevalue(fixture_name) + + impls = await resolve_impls_for_test_v2( + [Api.memory], + {"memory": [fixture.provider.model_dump()]}, + fixture.provider_data, + ) + + return impls[Api.memory], impls[Api.memory_banks] diff --git a/llama_stack/providers/tests/memory/provider_config_example.yaml b/llama_stack/providers/tests/memory/provider_config_example.yaml deleted file mode 100644 index 13575a598..000000000 --- a/llama_stack/providers/tests/memory/provider_config_example.yaml +++ /dev/null @@ -1,29 +0,0 @@ -providers: - - provider_id: test-faiss - provider_type: meta-reference - config: {} - - provider_id: test-chromadb - provider_type: remote::chromadb - config: - host: localhost - port: 6001 - - provider_id: test-remote - provider_type: remote - config: - host: localhost - port: 7002 - - provider_id: test-weaviate - provider_type: remote::weaviate - config: {} - - provider_id: test-qdrant - provider_type: remote::qdrant - config: - host: localhost - port: 6333 -# if a provider needs private keys from the client, they use the -# "get_request_provider_data" function (see distribution/request_headers.py) -# this is a place to provide such data. -provider_data: - "test-weaviate": - weaviate_api_key: 0xdeadbeefputrealapikeyhere - weaviate_cluster_url: http://foobarbaz diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index aa8594a36..a948fa17e 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -8,7 +8,7 @@ import pytest from llama_stack.apis.memory import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 -from .conftest import PROVIDER_PARAMS +from .fixtures import PROVIDER_PARAMS # How to run this test: # @@ -55,25 +55,25 @@ async def register_memory_bank(banks_impl: MemoryBanks): @pytest.mark.parametrize( - "stack_impls", + "memory_stack", PROVIDER_PARAMS, indirect=True, ) class TestMemory: @pytest.mark.asyncio - async def test_banks_list(self, stack_impls): + async def test_banks_list(self, memory_stack): # NOTE: this needs you to ensure that you are starting from a clean state # but so far we don't have an unregister API unfortunately, so be careful - _, banks_impl = stack_impls + _, banks_impl = memory_stack response = await banks_impl.list_memory_banks() assert isinstance(response, list) assert len(response) == 0 @pytest.mark.asyncio - async def test_banks_register(self, stack_impls): + async def test_banks_register(self, memory_stack): # NOTE: this needs you to ensure that you are starting from a clean state # but so far we don't have an unregister API unfortunately, so be careful - _, banks_impl = stack_impls + _, banks_impl = memory_stack bank = VectorMemoryBankDef( identifier="test_bank_no_provider", embedding_model="all-MiniLM-L6-v2", @@ -93,8 +93,8 @@ class TestMemory: assert len(response) == 1 @pytest.mark.asyncio - async def test_query_documents(self, stack_impls, sample_documents): - memory_impl, banks_impl = stack_impls + async def test_query_documents(self, memory_stack, sample_documents): + memory_impl, banks_impl = memory_stack with pytest.raises(ValueError): await memory_impl.insert_documents("test_bank", sample_documents) diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index c3a120c0b..25f13b1a4 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -49,6 +49,10 @@ def pytest_configure(config): def pytest_generate_tests(metafunc): + # We use this method to make sure we have built-in simple combos for safety tests + # But a user can also pass in a custom combination via the CLI by doing + # `--providers inference=together,safety=meta_reference` + if "safety_stack" in metafunc.fixturenames: # print(f"metafunc.fixturenames: {metafunc.fixturenames}, {metafunc}") available_fixtures = { diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index cf5aa9589..cf23e032b 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -64,6 +64,7 @@ SAFETY_FIXTURES = ["meta_reference", "together"] @pytest_asyncio.fixture(scope="session") async def safety_stack(inference_model, safety_model, request): + # We need an inference + safety fixture to test safety fixture_dict = request.param inference_fixture = request.getfixturevalue( f"inference_{fixture_dict['inference']}" diff --git a/llama_stack/providers/tests/safety/provider_config_example.yaml b/llama_stack/providers/tests/safety/provider_config_example.yaml deleted file mode 100644 index 088dc2cf2..000000000 --- a/llama_stack/providers/tests/safety/provider_config_example.yaml +++ /dev/null @@ -1,19 +0,0 @@ -providers: - inference: - - provider_id: together - provider_type: remote::together - config: {} - - provider_id: tgi - provider_type: remote::tgi - config: - url: http://127.0.0.1:7002 - - provider_id: meta-reference - provider_type: meta-reference - config: - model: Llama-Guard-3-1B - safety: - - provider_id: meta-reference - provider_type: meta-reference - config: - llama_guard_shield: - model: Llama-Guard-3-1B