llama-stack-mirror/llama_stack/providers/tests/memory/fixtures.py
Vladimir Ivic c62187c4a2 Adding memory provider test fakes
Summary:
Part of
* https://github.com/meta-llama/llama-stack/issues/436

This change adds minimal support for creating memory provider test fakes. For more details about the approach, see the issue linked above.

Test Plan:
Run tests using the "test_fake" mark:
```
pytest llama_stack/providers/tests/memory/test_memory.py -m "test_fake"
/opt/homebrew/Caskroom/miniconda/base/envs/llama-stack/lib/python3.11/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
====================================================================================================== test session starts ======================================================================================================
platform darwin -- Python 3.11.10, pytest-8.3.3, pluggy-1.5.0
rootdir: /llama-stack
configfile: pyproject.toml
plugins: asyncio-0.24.0, anyio-4.6.2.post1
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 18 items / 15 deselected / 3 selected

llama_stack/providers/tests/memory/test_memory.py ...                                                                                                                                                                     [100%]

========================================================================================= 3 passed, 15 deselected, 6 warnings in 0.03s ==========================================================================================
```
2024-11-18 13:37:02 -08:00

158 lines
5.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import tempfile
from typing import Iterator

import pytest
import pytest_asyncio

from llama_stack.apis.memory.memory import Chunk, QueryDocumentsResponse
from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
from llama_stack.providers.datatypes import TestFakeProviderConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig

from ..conftest import ProviderFixture, remote_stack_fixture, test_fake_stack_fixture
from ..env import get_env_or_fail
from .fakes import MemoryBanksTestFakeImpl, MemoryTestFakeImpl
@pytest.fixture(scope="session")
def query_documents_stubs():
    """Canned ``query_documents`` results for the ``MemoryTestFakeImpl`` fake.

    These stub out the fake's ``query_documents`` method so the tests in
    ``test_memory`` work as they would against a real provider. The mapping is
    keyed by query string (presumably matched against the query text the
    tests issue; confirm in ``MemoryTestFakeImpl``). Each value is the
    ``QueryDocumentsResponse`` returned for that query.
    """

    def _response(contents_and_tokens, scores):
        # Every stubbed chunk uses an empty document_id.
        return QueryDocumentsResponse(
            chunks=[
                Chunk(content=content, token_count=tokens, document_id="")
                for content, tokens in contents_and_tokens
            ],
            scores=scores,
        )

    return {
        "programming language": _response([("Python", 1)], [0.1]),
        "AI and brain-inspired computing": _response(
            [("neural networks", 2)], [0.1]
        ),
        "computer": _response(
            [("test-chunk-1", 1), ("test-chunk-2", 1)], [0.1, 0.5]
        ),
        "quantum computing": _response([("Python", 1)], [0.5]),
    }
@pytest.fixture(scope="session")
def memory_test_fake(query_documents_stubs) -> ProviderFixture:
    """Provider fixture backed by in-process memory fakes.

    Connects a ``MemoryTestFakeImpl`` to a ``MemoryBanksTestFakeImpl`` and
    installs the canned ``query_documents`` stubs on the memory fake. Both
    impls are then handed to ``test_fake_stack_fixture`` through a
    ``TestFakeProviderConfig``, keyed by their ``Api``.
    """
    banks = MemoryBanksTestFakeImpl()
    memory = MemoryTestFakeImpl()
    memory.set_memory_banks(banks)
    memory.set_stubs("query_documents", query_documents_stubs)

    impls = {Api.memory: memory, Api.memory_banks: banks}
    return test_fake_stack_fixture(TestFakeProviderConfig(impls=impls))
@pytest.fixture(scope="session")
def memory_remote() -> ProviderFixture:
    """Provider fixture that targets a remote stack.

    Delegates entirely to ``remote_stack_fixture`` from the shared conftest.
    """
    remote_fixture = remote_stack_fixture()
    return remote_fixture
@pytest.fixture(scope="session")
def memory_faiss() -> Iterator[ProviderFixture]:
    """Provider fixture for the inline FAISS memory provider.

    The provider's kvstore is a SQLite database at a fresh temporary path.
    The temp file is created with ``delete=False`` so it survives closing our
    handle, and that handle is closed right away:
    - Only the path is needed, and holding the file open would leak a
      descriptor for the whole session.
    - On Windows, a still-open named temporary file cannot be reopened by
      name, which SQLite needs to do.

    The database file is removed at session teardown on a best-effort basis.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as temp_file:
        db_path = temp_file.name

    yield ProviderFixture(
        providers=[
            Provider(
                provider_id="faiss",
                provider_type="inline::faiss",
                config=FaissImplConfig(
                    kvstore=SqliteKVStoreConfig(db_path=db_path).model_dump(),
                ).model_dump(),
            )
        ],
    )

    # Best-effort cleanup: the file may already be gone or still locked on
    # some platforms, and a leftover temp file should not fail the session.
    try:
        os.remove(db_path)
    except OSError:
        pass
@pytest.fixture(scope="session")
def memory_pgvector() -> ProviderFixture:
    """Provider fixture for the remote pgvector memory provider.

    Connection settings come from environment variables:
    - ``PGVECTOR_HOST`` and ``PGVECTOR_PORT`` are optional and default to
      ``localhost`` and ``5432``.
    - ``PGVECTOR_DB``, ``PGVECTOR_USER`` and ``PGVECTOR_PASSWORD`` are read
      through ``get_env_or_fail``.
    """
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="pgvector",
                provider_type="remote::pgvector",
                config=PGVectorConfig(
                    host=os.getenv("PGVECTOR_HOST", "localhost"),
                    # os.getenv returns a str when the variable is set. Parse
                    # it explicitly so the port has the same type on both
                    # paths, instead of an int default vs. a str override.
                    port=int(os.getenv("PGVECTOR_PORT", "5432")),
                    db=get_env_or_fail("PGVECTOR_DB"),
                    user=get_env_or_fail("PGVECTOR_USER"),
                    password=get_env_or_fail("PGVECTOR_PASSWORD"),
                ).model_dump(),
            )
        ],
    )
@pytest.fixture(scope="session")
def memory_weaviate() -> ProviderFixture:
    """Provider fixture for the remote Weaviate memory provider.

    The provider uses a default ``WeaviateConfig``. The API key and cluster URL
    are supplied as provider data, read from ``WEAVIATE_API_KEY`` and
    ``WEAVIATE_CLUSTER_URL`` via ``get_env_or_fail``.
    """
    weaviate_provider = Provider(
        provider_id="weaviate",
        provider_type="remote::weaviate",
        config=WeaviateConfig().model_dump(),
    )
    credentials = {
        "weaviate_api_key": get_env_or_fail("WEAVIATE_API_KEY"),
        "weaviate_cluster_url": get_env_or_fail("WEAVIATE_CLUSTER_URL"),
    }
    return ProviderFixture(providers=[weaviate_provider], provider_data=credentials)
@pytest.fixture(scope="session")
def memory_chroma() -> ProviderFixture:
    """Provider fixture for the remote ChromaDB memory provider.

    The host and port come from ``CHROMA_HOST`` and ``CHROMA_PORT`` via
    ``get_env_or_fail``, wrapped in a generic ``RemoteProviderConfig``.
    """
    chroma_config = RemoteProviderConfig(
        host=get_env_or_fail("CHROMA_HOST"),
        port=get_env_or_fail("CHROMA_PORT"),
    )
    chroma_provider = Provider(
        provider_id="chroma",
        provider_type="remote::chromadb",
        config=chroma_config.model_dump(),
    )
    return ProviderFixture(providers=[chroma_provider])
# Name suffixes of the memory_<name> fixtures defined in this module.
# memory_stack looks up f"memory_{request.param}", so every entry here must
# have a matching fixture.
MEMORY_FIXTURES = [
    "test_fake",
    "faiss",
    "pgvector",
    "weaviate",
    "remote",
    "chroma",
]
@pytest_asyncio.fixture(scope="session")
async def memory_stack(request):
    """Build a memory stack for the provider chosen by ``request.param``.

    ``request.param`` is a fixture-name suffix (presumably one of
    ``MEMORY_FIXTURES``). The matching ``memory_<name>`` fixture is resolved,
    and its providers and provider data are passed to
    ``construct_stack_for_test`` for the memory API.

    Returns:
        A ``(memory_impl, memory_banks_impl)`` tuple taken from the
        constructed stack's ``impls``.
    """
    provider_fixture = request.getfixturevalue(f"memory_{request.param}")
    stack = await construct_stack_for_test(
        [Api.memory],
        {"memory": provider_fixture.providers},
        provider_fixture.provider_data,
    )
    impls = stack.impls
    return impls[Api.memory], impls[Api.memory_banks]