From e2d1b712e28c5c58d37a359b37549e91243bcfd2 Mon Sep 17 00:00:00 2001 From: Vladimir Ivic Date: Mon, 25 Nov 2024 10:24:52 -0800 Subject: [PATCH] Testing - Memory provider fakes Summary: Implementing Memory provider fakes as discussed in this draft https://github.com/meta-llama/llama-stack/pull/490#issuecomment-2492877393. High-level changes: * Fake provider is specified via the "fake" mark * Test config will set up a fake fixture for the run of the test * Test resolver checks fixtures and upon finding a fake provider it injects InlineProviderSpec for fake provider * Fake provider gets resolved through path/naming convention * Fake provider implementation is confined to the tests/ directory and implements stubs and method fakes with minimal functionality to simulate a real provider Instructions for creating a fake * Create the "fakes" module inside the provider test directory * Inside the module implement `get_provider_impl` that will return the fake implementation object * Name the fake implementation class to match the fake provider id (e.g. memory_fake -> MemoryFakeImpl) * Same rule for the config (e.g. 
memory_fake -> MemoryFakeConfig) * Add fake fixture (in the fixtures.py) and setup methods stubs there Test Plan: Run memory tests ``` pytest llama_stack/providers/tests/memory/test_memory.py -m "fake" -v -s --tb=short ====================================================================================================== test session starts ====================================================================================================== platform darwin -- Python 3.11.10, pytest-8.3.3, pluggy-1.5.0 -- /opt/homebrew/Caskroom/miniconda/base/envs/llama-stack/bin/python cachedir: .pytest_cache rootdir: /Users/vivic/Code/llama-stack configfile: pyproject.toml plugins: asyncio-0.24.0, anyio-4.6.2.post1 asyncio: mode=Mode.STRICT, default_loop_scope=None collected 18 items / 15 deselected / 3 selected llama_stack/providers/tests/memory/test_memory.py::TestMemory::test_banks_list[fake] PASSED llama_stack/providers/tests/memory/test_memory.py::TestMemory::test_banks_register[fake] PASSED llama_stack/providers/tests/memory/test_memory.py::TestMemory::test_query_documents[fake] The scores are: [0.5] PASSED ========================================================================================= 3 passed, 15 deselected, 10 warnings in 0.46s ========================================================================================= ``` --- llama_stack/providers/tests/memory/fakes.py | 84 +++++++++++++++++++ .../providers/tests/memory/fixtures.py | 45 +++++++++- llama_stack/providers/tests/resolver.py | 63 +++++++++++++- 3 files changed, 189 insertions(+), 3 deletions(-) create mode 100644 llama_stack/providers/tests/memory/fakes.py diff --git a/llama_stack/providers/tests/memory/fakes.py b/llama_stack/providers/tests/memory/fakes.py new file mode 100644 index 000000000..71983c165 --- /dev/null +++ b/llama_stack/providers/tests/memory/fakes.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List, Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel + +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.memory_banks import MemoryBank +from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate + + +@json_schema_type +class InlineMemoryFakeConfig(BaseModel): + pass + + +class InlineMemoryFakeImpl(Memory, MemoryBanksProtocolPrivate): + method_stubs: Dict[str, Any] = {} + memory_banks: Dict[str, MemoryBank] = {} + + @staticmethod + def stub_method(method_name: str, return_value_matchers: Dict[str, Any]) -> None: + if method_name in InlineMemoryFakeImpl.method_stubs: + InlineMemoryFakeImpl.method_stubs[method_name].update(return_value_matchers) + return + InlineMemoryFakeImpl.method_stubs[method_name] = return_value_matchers + + def __init__(self, config: InlineMemoryFakeConfig) -> None: + pass + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def register_memory_bank( + self, + memory_bank: MemoryBank, + ) -> None: + InlineMemoryFakeImpl.memory_banks[memory_bank.memory_bank_id] = memory_bank + + async def list_memory_banks(self) -> List[MemoryBank]: + return list(InlineMemoryFakeImpl.memory_banks.values()) + + async def unregister_memory_bank(self, memory_bank_id: str) -> None: + if memory_bank_id not in InlineMemoryFakeImpl.memory_banks: + raise ValueError(f"Bank {memory_bank_id} not found.") + del InlineMemoryFakeImpl.memory_banks[memory_bank_id] + + async def insert_documents( + self, + bank_id: str, + documents: List[MemoryBankDocument], + ttl_seconds: Optional[int] = None, + ) -> None: + pass + + async def query_documents( + self, + bank_id: str, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: + if query in 
InlineMemoryFakeImpl.method_stubs["query_documents"]: + return InlineMemoryFakeImpl.method_stubs["query_documents"][query] + raise ValueError( + f"Stub for query '{query}' not found, please set up expected result" + ) + + +async def get_provider_impl(config: InlineMemoryFakeConfig, _deps: Any): + assert isinstance( + config, InlineMemoryFakeConfig + ), f"Unexpected config type: {type(config)}" + + impl = InlineMemoryFakeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index c9559b61c..b996cba52 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -10,6 +10,7 @@ import tempfile import pytest import pytest_asyncio +from llama_stack.apis.memory import * # noqa: F403 from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig @@ -18,6 +19,48 @@ from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail +from .fakes import InlineMemoryFakeImpl + + +@pytest.fixture(scope="session") +def memory_fake() -> ProviderFixture: + InlineMemoryFakeImpl.stub_method( + method_name="query_documents", + return_value_matchers={ + "programming language": QueryDocumentsResponse( + chunks=[Chunk(content="Python", token_count=1, document_id="")], + scores=[0.1], + ), + "AI and brain-inspired computing": QueryDocumentsResponse( + chunks=[ + Chunk(content="neural networks", token_count=2, document_id="") + ], + scores=[0.1], + ), + "computer": QueryDocumentsResponse( + chunks=[ + Chunk(content="test-chunk-1", token_count=1, document_id=""), + Chunk(content="test-chunk-2", 
token_count=1, document_id=""), + ], + scores=[0.1, 0.5], + ), + "quantum computing": QueryDocumentsResponse( + chunks=[Chunk(content="Python", token_count=1, document_id="")], + scores=[0.5], + ), + }, + ) + + fixture = ProviderFixture( + providers=[ + Provider( + provider_id="inline_memory_fake", + provider_type="test::fake", + config={}, + ) + ], + ) + return fixture @pytest.fixture(scope="session") @@ -93,7 +136,7 @@ def memory_chroma() -> ProviderFixture: ) -MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"] +MEMORY_FIXTURES = ["fake", "faiss", "pgvector", "weaviate", "remote", "chroma"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 8bbb902cd..a6cf1d8bc 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -4,8 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import importlib.util import json import tempfile + from typing import Any, Dict, List, Optional from llama_stack.distribution.datatypes import * # noqa: F403 @@ -51,8 +53,11 @@ async def construct_stack_for_test( try: remote_config = remote_provider_config(run_config) if not remote_config: - # TODO: add to provider registry by creating interesting mocks or fakes - impls = await construct_stack(run_config, get_provider_registry()) + # Here we create instance of registry with optional fake provider + provider_registry = setup_provider_registry_for_test( + run_config, get_provider_registry() + ) + impls = await construct_stack(run_config, provider_registry) else: # we don't register resources for a remote stack as part of the fixture setup # because the stack is already "up". 
if a test needs to register resources, it @@ -73,6 +78,60 @@ async def construct_stack_for_test( return test_stack +def setup_provider_registry_for_test( + run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]] +) -> Dict[Api, Dict[str, ProviderSpec]]: + provider_registry = get_provider_registry() + for api_name, providers in run_config.providers.items(): + for provider in providers: + if provider.provider_type == "test::fake": + # Check if the fake provider module exists for the API that is trying + # to use test::fake provider_type + provider_fake_module_name = ( + f"llama_stack.providers.tests.{api_name}.fakes" + ) + provider_fake_module_spec = importlib.util.find_spec( + provider_fake_module_name + ) + if provider_fake_module_spec is None: + raise ValueError( + f"Fake provider module {provider_fake_module_name} does not exist. " + f"The module must be defined inside the providers/tests/{api_name}/fakes.py file." + ) + + # Import the module so we can validate that the config class exists + provider_fake_module = importlib.import_module( + provider_fake_module_name + ) + + # Check if the fake provider config class exists + # The class name is derived from the provider type e.g. + # provider_id: "example_provider" -> class_name: "ExampleProviderConfig" + provider_fake_config_class_name = ( + f"{provider.provider_id}_config".title().replace("_", "") + ) + if not hasattr(provider_fake_module, provider_fake_config_class_name): + raise ValueError( + f"Fake provider config class {provider_fake_config_class_name} " + f"does not exist in module {provider_fake_module_name}. " + f"The config class must be defined inside the providers/tests/{api_name}/fakes.py file." 
+ ) + + provider_fake_config_class_path = f"llama_stack.providers.tests.{api_name}.fakes.{provider_fake_config_class_name}" + + api = getattr(Api, api_name) + fake_api_spec = InlineProviderSpec( + api=api, + provider_type="test::fake", + pip_packages=[], + module=provider_fake_module_name, + config_class=provider_fake_config_class_path, + ) + provider_registry[api].update({"test::fake": fake_api_spec}) + + return provider_registry + + def remote_provider_config( run_config: StackRunConfig, ) -> Optional[RemoteProviderConfig]: