From c62187c4a2cf6e9275f4fb5d79428ea868c1ae3f Mon Sep 17 00:00:00 2001 From: Vladimir Ivic Date: Mon, 18 Nov 2024 13:37:01 -0800 Subject: [PATCH] Adding memory provider test fakes Summary: Part of * https://github.com/meta-llama/llama-stack/issues/436 This change adds a minimalistic support to creating memory provider test fake. For more details about the approach, follow the issue link from above. Test Plan: Run tests using the "test_fake" mark: ``` pytest llama_stack/providers/tests/memory/test_memory.py -m "test_fake" /opt/homebrew/Caskroom/miniconda/base/envs/llama-stack/lib/python3.11/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ====================================================================================================== test session starts ====================================================================================================== platform darwin -- Python 3.11.10, pytest-8.3.3, pluggy-1.5.0 rootdir: /llama-stack configfile: pyproject.toml plugins: asyncio-0.24.0, anyio-4.6.2.post1 asyncio: mode=Mode.STRICT, default_loop_scope=None collected 18 items / 15 deselected / 3 selected llama_stack/providers/tests/memory/test_memory.py ... [100%] ========================================================================================= 3 passed, 15 deselected, 6 warnings in 0.03s ========================================================================================== ``` --- llama_stack/distribution/resolver.py | 13 +++ llama_stack/providers/datatypes.py | 12 ++- llama_stack/providers/tests/conftest.py | 16 +++- llama_stack/providers/tests/memory/fakes.py | 92 +++++++++++++++++++ .../providers/tests/memory/fixtures.py | 52 ++++++++++- llama_stack/providers/tests/resolver.py | 55 +++++++++-- 6 files changed, 226 insertions(+), 14 deletions(-) create mode 100644 llama_stack/providers/tests/memory/fakes.py diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 4c74b0d1f..07158aded 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -360,6 +360,19 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: ) +# Here we simply want to apply mapping from the config. +# All test fakes before this point must be initialized +# and properly setup (stubs, mocks, etc). +async def resolve_test_fake_stack_impls( + config: TestFakeProviderConfig, +) -> Dict[Api, Any]: + impls = {} + for api, impl in config.impls.items(): + impls[api] = impl + + return impls + + async def resolve_remote_stack_impls( config: RemoteProviderConfig, apis: List[str], diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 080204e45..1c26fd4a2 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Any, List, Optional, Protocol +from typing import Any, Dict, List, Optional, Protocol from urllib.parse import urlparse from llama_models.schema_utils import json_schema_type @@ -157,6 +157,16 @@ Fully-qualified name of the module to import. The module is expected to have: ) +# Here we need this config to be as simple as possible +# to ensure we can set things up in a generic way. +# +# `impls``contains a simple Api to TestFake object mapping. +# +# Test fake objects must be initialized and properly setup beforehand. +class TestFakeProviderConfig(BaseModel): + impls: Dict[Api, Any] + + class RemoteProviderConfig(BaseModel): host: str = "localhost" port: Optional[int] = None diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 8b73500d0..ff0e68df0 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -14,7 +14,7 @@ from pydantic import BaseModel from termcolor import colored from llama_stack.distribution.datatypes import Provider -from llama_stack.providers.datatypes import RemoteProviderConfig +from llama_stack.providers.datatypes import RemoteProviderConfig, TestFakeProviderConfig from .env import get_env_or_fail @@ -24,6 +24,20 @@ class ProviderFixture(BaseModel): provider_data: Optional[Dict[str, Any]] = None +# Generic test fake fixture. Use TestFakeProviderConfig to set test fakes +# that will be mapped to their corresponding APIs. +def test_fake_stack_fixture(config: TestFakeProviderConfig) -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="test::test-fake", + provider_type="test::test-fake", + config=config.model_dump(), + ) + ], + ) + + def remote_stack_fixture() -> ProviderFixture: if url := os.getenv("REMOTE_STACK_URL", None): config = RemoteProviderConfig.from_url(url) diff --git a/llama_stack/providers/tests/memory/fakes.py b/llama_stack/providers/tests/memory/fakes.py new file mode 100644 index 000000000..9fccb7d89 --- /dev/null +++ b/llama_stack/providers/tests/memory/fakes.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List, Optional + +from llama_stack.apis.memory.memory import ( + Memory, + MemoryBankDocument, + MemoryBankStore, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks.memory_banks import KeyValueMemoryBank +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.distribution.datatypes import * # noqa: F403 + + +# MemoryBanks test fake implementation to +# support behaviors tested in test_memory.py +class MemoryBanksTestFakeImpl(MemoryBanks): + def __init__(self): + self.memory_banks: Dict[str, MemoryBank] = dict() + + async def list_memory_banks(self) -> List[MemoryBank]: + return list(self.memory_banks.values()) + + async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: + if memory_bank_id in self.memory_banks: + return self.memory_banks[memory_bank_id] + + async def register_memory_bank( + self, + memory_bank_id: str, + params: BankParams, + provider_id: Optional[str] = None, + provider_memory_bank_id: Optional[str] = None, + ) -> MemoryBank: + memory_bank = KeyValueMemoryBank( + identifier=memory_bank_id, + provider_id="test::test-fake", + ) + self.memory_banks[memory_bank_id] = memory_bank + return self.memory_banks[memory_bank_id] + + async def unregister_memory_bank(self, memory_bank_id: str) -> None: + if memory_bank_id in self.memory_banks: + del self.memory_banks[memory_bank_id] + + +# Memory test fake implementation to +# support behaviors tested in test_memory.py +class MemoryTestFakeImpl(Memory): + memory_bank_store: MemoryBankStore + + def __init__(self): + self.memory_banks = None + self.stubs: Dict[str, Any] = {} + + def set_memory_banks(self, memory_banks: MemoryBanks) -> None: + self.memory_banks = memory_banks + + def set_stubs(self, method: str, stubs: Dict[str, Any]): + self.stubs[method] = stubs + + async def insert_documents( + self, + bank_id: str, + documents: List[MemoryBankDocument], + ttl_seconds: Optional[int] = None, + ) -> None: + if not await self.memory_banks.get_memory_bank(bank_id): + raise ValueError(f"Bank {bank_id} not found") + # No-op + # We will just ignore documents here since we will init this + # test fake with stubs to match expecting query-response pairs + + async def query_documents( + self, + bank_id: str, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: + if not await self.memory_banks.get_memory_bank(bank_id): + raise ValueError(f"Bank {bank_id} not found") + if query not in self.stubs["query_documents"]: + raise ValueError( + f"Stub not created for query {query}, please check your test setup." + ) + + return self.stubs["query_documents"][query] diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index c9559b61c..3780b162e 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -10,14 +10,62 @@ import tempfile import pytest import pytest_asyncio +from llama_stack.apis.memory.memory import Chunk, QueryDocumentsResponse from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig +from llama_stack.providers.datatypes import TestFakeProviderConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig -from ..conftest import ProviderFixture, remote_stack_fixture +from ..conftest import ProviderFixture, remote_stack_fixture, test_fake_stack_fixture from ..env import get_env_or_fail +from .fakes import MemoryBanksTestFakeImpl, MemoryTestFakeImpl + + +@pytest.fixture(scope="session") +def query_documents_stubs(): + # These are stubs for the method calls against MemoryTestFakeImpl fake + # so the tests inside test_memory will as with a real provider + return { + "programming language": QueryDocumentsResponse( + chunks=[Chunk(content="Python", token_count=1, document_id="")], + scores=[0.1], + ), + "AI and brain-inspired computing": QueryDocumentsResponse( + chunks=[Chunk(content="neural networks", token_count=2, document_id="")], + scores=[0.1], + ), + "computer": QueryDocumentsResponse( + chunks=[ + Chunk(content="test-chunk-1", token_count=1, document_id=""), + Chunk(content="test-chunk-2", token_count=1, document_id=""), + ], + scores=[0.1, 0.5], + ), + "quantum computing": QueryDocumentsResponse( + chunks=[Chunk(content="Python", token_count=1, document_id="")], + scores=[0.5], + ), + } + + +@pytest.fixture(scope="session") +def memory_test_fake(query_documents_stubs) -> ProviderFixture: + # Prepare impl instances here, initiate fake objects and set up stubs + memory_banks_impl = MemoryBanksTestFakeImpl() + memory_impl = MemoryTestFakeImpl() + memory_impl.set_memory_banks(memory_banks_impl) + memory_impl.set_stubs("query_documents", query_documents_stubs) + + config = TestFakeProviderConfig( + impls={ + Api.memory: memory_impl, + Api.memory_banks: memory_banks_impl, + } + ) + + return test_fake_stack_fixture(config) @pytest.fixture(scope="session") @@ -93,7 +141,7 @@ def memory_chroma() -> ProviderFixture: ) -MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"] +MEMORY_FIXTURES = ["test_fake", "faiss", "pgvector", "weaviate", "remote", "chroma"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index df927926e..9c005d707 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -14,7 +14,10 @@ from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.resolver import resolve_remote_stack_impls +from llama_stack.distribution.resolver import ( + resolve_remote_stack_impls, + resolve_test_fake_stack_impls, +) from llama_stack.distribution.stack import construct_stack from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig @@ -51,16 +54,26 @@ async def construct_stack_for_test( ) run_config = parse_and_maybe_upgrade_config(run_config) try: - remote_config = remote_provider_config(run_config) - if not remote_config: - # TODO: add to provider registry by creating interesting mocks or fakes - impls = await construct_stack(run_config, get_provider_registry()) - else: - # we don't register resources for a remote stack as part of the fixture setup - # because the stack is already "up". if a test needs to register resources, it - # can do so manually always. + impls = None - impls = await resolve_remote_stack_impls(remote_config, run_config.apis) + # "Resolve" implementations when using test::test-fake providers. + # This is actually injecting test fakes as resolved API implementations. + test_fake_config = test_fake_provider_config(run_config) + if test_fake_config: + impls = await resolve_test_fake_stack_impls(test_fake_config) + + # Resolve implementations when using test::remove providers + if not impls: + remote_config = remote_provider_config(run_config) + if remote_config: + # we don't register resources for a remote stack as part of the fixture setup + # because the stack is already "up". if a test needs to register resources, it + # can do so manually always. + impls = await resolve_remote_stack_impls(remote_config, run_config.apis) + + # In case none of the above happened, resolve implementations as normal providers + if not impls: + impls = await construct_stack(run_config, get_provider_registry()) test_stack = TestStack(impls=impls, run_config=run_config) except ModuleNotFoundError as e: @@ -75,6 +88,28 @@ async def construct_stack_for_test( return test_stack +# In case when we want to run multiple test fakes, then we need to +# make sure the stack contains only test fake providers as after this +# we will be calling resolve_test_fake_stack_impls() which cannot resolve +# any other provider type. This could be refactored in case we need to change this. +def test_fake_provider_config( + run_config: StackRunConfig, +) -> Optional[RemoteProviderConfig]: + test_fake_config = None + has_non_test_fake = False + for api_providers in run_config.providers.values(): + for provider in api_providers: + if provider.provider_type == "test::test-fake": + test_fake_config = TestFakeProviderConfig(**provider.config) + else: + has_non_test_fake = True + + if test_fake_config: + assert not has_non_test_fake, "Test fake stack cannot have non-fake providers" + + return test_fake_config + + def remote_provider_config( run_config: StackRunConfig, ) -> Optional[RemoteProviderConfig]: