From a42fbea1b88f49e5d98548f08139b7e50546f420 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 2 Nov 2024 13:32:41 -0700 Subject: [PATCH] convert memory tests --- .../providers/tests/memory/conftest.py | 96 +++++++++++ .../providers/tests/memory/test_memory.py | 154 ++++++++---------- 2 files changed, 164 insertions(+), 86 deletions(-) create mode 100644 llama_stack/providers/tests/memory/conftest.py diff --git a/llama_stack/providers/tests/memory/conftest.py b/llama_stack/providers/tests/memory/conftest.py new file mode 100644 index 000000000..f1aea99c2 --- /dev/null +++ b/llama_stack/providers/tests/memory/conftest.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from typing import Any, Dict, Tuple + +import pytest +import pytest_asyncio + +from llama_stack.distribution.datatypes import Api, Provider +from llama_stack.providers.adapters.memory.pgvector import PGVectorConfig +from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig +from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig + +from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from ..env import get_env_or_fail + + +@pytest.fixture(scope="session") +def meta_reference() -> Provider: + return Provider( + provider_id="meta-reference", + provider_type="meta-reference", + config=FaissImplConfig().model_dump(), + ) + + +@pytest.fixture(scope="session") +def pgvector() -> Provider: + return Provider( + provider_id="pgvector", + provider_type="remote::pgvector", + config=PGVectorConfig( + host=os.getenv("PGVECTOR_HOST", "localhost"), + port=os.getenv("PGVECTOR_PORT", 5432), + db=get_env_or_fail("PGVECTOR_DB"), + user=get_env_or_fail("PGVECTOR_USER"), + password=get_env_or_fail("PGVECTOR_PASSWORD"), + ).model_dump(), + ) + + +@pytest.fixture(scope="session") +def weaviate() -> Tuple[Provider, Dict[str, Any]]: + provider = Provider( + provider_id="weaviate", + provider_type="remote::weaviate", + config=WeaviateConfig().model_dump(), + ) + return provider, dict( + weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"), + weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"), + ) + + +PROVIDER_PARAMS = [ + pytest.param("meta_reference", marks=pytest.mark.meta_reference), + pytest.param("pgvector", marks=pytest.mark.pgvector), + pytest.param("weaviate", marks=pytest.mark.weaviate), +] + + +@pytest_asyncio.fixture( + scope="session", + params=PROVIDER_PARAMS, +) +async def stack_impls(request): + provider_fixture = request.param + provider = request.getfixturevalue(provider_fixture) + if isinstance(provider, tuple): + provider, provider_data = provider + else: + provider_data = None + + impls = await resolve_impls_for_test_v2( + [Api.memory], + {"memory": [provider.model_dump()]}, + provider_data, + ) + + return impls[Api.memory], impls[Api.memory_banks] + + +def pytest_configure(config): + config.addinivalue_line("markers", "pgvector: marks tests as pgvector specific") + config.addinivalue_line( + "markers", + "meta_reference: marks tests as metaref specific", + ) + config.addinivalue_line( + "markers", + "weaviate: marks tests as weaviate specific", + ) diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index d83601de1..aa8594a36 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -5,39 +5,16 @@ # the root directory of this source tree. import pytest -import pytest_asyncio from llama_stack.apis.memory import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 -from llama_stack.providers.tests.resolver import resolve_impls_for_test +from .conftest import PROVIDER_PARAMS # How to run this test: # -# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky -# since it depends on the provider you are testing. On top of that you need -# `pytest` and `pytest-asyncio` installed. -# -# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. -# -# 3. Run: -# -# ```bash -# PROVIDER_ID= \ -# PROVIDER_CONFIG=provider_config.yaml \ -# pytest -s llama_stack/providers/tests/memory/test_memory.py \ -# --tb=short --disable-warnings -# ``` - - -@pytest_asyncio.fixture(scope="session") -async def memory_settings(): - impls = await resolve_impls_for_test( - Api.memory, - ) - return { - "memory_impl": impls[Api.memory], - "memory_banks_impl": impls[Api.memory_banks], - } +# pytest llama_stack/providers/tests/memory/test_memory.py +# -m "meta_reference" +# -v -s --tb=short --disable-warnings @pytest.fixture @@ -77,76 +54,81 @@ async def register_memory_bank(banks_impl: MemoryBanks): await banks_impl.register_memory_bank(bank) -@pytest.mark.asyncio -async def test_banks_list(memory_settings): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful - banks_impl = memory_settings["memory_banks_impl"] - response = await banks_impl.list_memory_banks() - assert isinstance(response, list) - assert len(response) == 0 +@pytest.mark.parametrize( + "stack_impls", + PROVIDER_PARAMS, + indirect=True, +) +class TestMemory: + @pytest.mark.asyncio + async def test_banks_list(self, stack_impls): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + _, banks_impl = stack_impls + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert len(response) == 0 + @pytest.mark.asyncio + async def test_banks_register(self, stack_impls): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + _, banks_impl = stack_impls + bank = VectorMemoryBankDef( + identifier="test_bank_no_provider", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ) -@pytest.mark.asyncio -async def test_banks_register(memory_settings): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful - banks_impl = memory_settings["memory_banks_impl"] - bank = VectorMemoryBankDef( - identifier="test_bank_no_provider", - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, - ) + await banks_impl.register_memory_bank(bank) + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert len(response) == 1 - await banks_impl.register_memory_bank(bank) - response = await banks_impl.list_memory_banks() - assert isinstance(response, list) - assert len(response) == 1 + # register same memory bank with same id again will fail + await banks_impl.register_memory_bank(bank) + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert len(response) == 1 - # register same memory bank with same id again will fail - await banks_impl.register_memory_bank(bank) - response = await banks_impl.list_memory_banks() - assert isinstance(response, list) - assert len(response) == 1 + @pytest.mark.asyncio + async def test_query_documents(self, stack_impls, sample_documents): + memory_impl, banks_impl = stack_impls + with pytest.raises(ValueError): + await memory_impl.insert_documents("test_bank", sample_documents) -@pytest.mark.asyncio -async def test_query_documents(memory_settings, sample_documents): - memory_impl = memory_settings["memory_impl"] - banks_impl = memory_settings["memory_banks_impl"] - - with pytest.raises(ValueError): + await register_memory_bank(banks_impl) await memory_impl.insert_documents("test_bank", sample_documents) - await register_memory_bank(banks_impl) - await memory_impl.insert_documents("test_bank", sample_documents) + query1 = "programming language" + response1 = await memory_impl.query_documents("test_bank", query1) + assert_valid_response(response1) + assert any("Python" in chunk.content for chunk in response1.chunks) - query1 = "programming language" - response1 = await memory_impl.query_documents("test_bank", query1) - assert_valid_response(response1) - assert any("Python" in chunk.content for chunk in response1.chunks) + # Test case 3: Query with semantic similarity + query3 = "AI and brain-inspired computing" + response3 = await memory_impl.query_documents("test_bank", query3) + assert_valid_response(response3) + assert any( + "neural networks" in chunk.content.lower() for chunk in response3.chunks + ) - # Test case 3: Query with semantic similarity - query3 = "AI and brain-inspired computing" - response3 = await memory_impl.query_documents("test_bank", query3) - assert_valid_response(response3) - assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks) + # Test case 4: Query with limit on number of results + query4 = "computer" + params4 = {"max_chunks": 2} + response4 = await memory_impl.query_documents("test_bank", query4, params4) + assert_valid_response(response4) + assert len(response4.chunks) <= 2 - # Test case 4: Query with limit on number of results - query4 = "computer" - params4 = {"max_chunks": 2} - response4 = await memory_impl.query_documents("test_bank", query4, params4) - assert_valid_response(response4) - assert len(response4.chunks) <= 2 - - # Test case 5: Query with threshold on similarity score - query5 = "quantum computing" # Not directly related to any document - params5 = {"score_threshold": 0.2} - response5 = await memory_impl.query_documents("test_bank", query5, params5) - assert_valid_response(response5) - print("The scores are:", response5.scores) - assert all(score >= 0.2 for score in response5.scores) + # Test case 5: Query with threshold on similarity score + query5 = "quantum computing" # Not directly related to any document + params5 = {"score_threshold": 0.2} + response5 = await memory_impl.query_documents("test_bank", query5, params5) + assert_valid_response(response5) + print("The scores are:", response5.scores) + assert all(score >= 0.2 for score in response5.scores) def assert_valid_response(response: QueryDocumentsResponse):