diff --git a/llama_stack/providers/tests/memory/fakes.py b/llama_stack/providers/tests/memory/fakes.py new file mode 100644 index 000000000..71983c165 --- /dev/null +++ b/llama_stack/providers/tests/memory/fakes.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List, Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel + +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.memory_banks import MemoryBank +from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate + + +@json_schema_type +class InlineMemoryFakeConfig(BaseModel): + pass + + +class InlineMemoryFakeImpl(Memory, MemoryBanksProtocolPrivate): + method_stubs: Dict[str, Any] = {} + memory_banks: Dict[str, MemoryBank] = {} + + @staticmethod + def stub_method(method_name: str, return_value_matchers: Dict[str, Any]) -> None: + if method_name in InlineMemoryFakeImpl.method_stubs: + InlineMemoryFakeImpl.method_stubs[method_name].update(return_value_matchers) + return + InlineMemoryFakeImpl.method_stubs[method_name] = return_value_matchers + + def __init__(self, config: InlineMemoryFakeConfig) -> None: + pass + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def register_memory_bank( + self, + memory_bank: MemoryBank, + ) -> None: + InlineMemoryFakeImpl.memory_banks[memory_bank.memory_bank_id] = memory_bank + + async def list_memory_banks(self) -> List[MemoryBank]: + return list(InlineMemoryFakeImpl.memory_banks.values()) + + async def unregister_memory_bank(self, memory_bank_id: str) -> None: + if memory_bank_id not in InlineMemoryFakeImpl.memory_banks: + raise ValueError(f"Bank {memory_bank_id} not found.") + del InlineMemoryFakeImpl.memory_banks[memory_bank_id] + + async def insert_documents( + self, + bank_id: str, + documents: List[MemoryBankDocument], + ttl_seconds: Optional[int] = None, + ) -> None: + pass + + async def query_documents( + self, + bank_id: str, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: + if query in InlineMemoryFakeImpl.method_stubs["query_documents"]: + return InlineMemoryFakeImpl.method_stubs["query_documents"][query] + raise ValueError( + f"Stub for query '{query}' not found, please set up expected result" + ) + + +async def get_provider_impl(config: InlineMemoryFakeConfig, _deps: Any): + assert isinstance( + config, InlineMemoryFakeConfig + ), f"Unexpected config type: {type(config)}" + + impl = InlineMemoryFakeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index c9559b61c..b996cba52 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -10,6 +10,7 @@ import tempfile import pytest import pytest_asyncio +from llama_stack.apis.memory import * # noqa: F403 from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig @@ -18,6 +19,48 @@ from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail +from .fakes import InlineMemoryFakeImpl + + +@pytest.fixture(scope="session") +def memory_fake() -> ProviderFixture: + InlineMemoryFakeImpl.stub_method( + method_name="query_documents", + return_value_matchers={ + "programming language": QueryDocumentsResponse( + chunks=[Chunk(content="Python", token_count=1, document_id="")], + scores=[0.1], + ), + "AI and brain-inspired computing": QueryDocumentsResponse( + chunks=[ + Chunk(content="neural networks", token_count=2, document_id="") + ], + scores=[0.1], + ), + "computer": QueryDocumentsResponse( + chunks=[ + Chunk(content="test-chunk-1", token_count=1, document_id=""), + Chunk(content="test-chunk-2", token_count=1, document_id=""), + ], + scores=[0.1, 0.5], + ), + "quantum computing": QueryDocumentsResponse( + chunks=[Chunk(content="Python", token_count=1, document_id="")], + scores=[0.5], + ), + }, + ) + + fixture = ProviderFixture( + providers=[ + Provider( + provider_id="inline_memory_fake", + provider_type="test::fake", + config={}, + ) + ], + ) + return fixture @pytest.fixture(scope="session") @@ -93,7 +136,7 @@ def memory_chroma() -> ProviderFixture: ) -MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"] +MEMORY_FIXTURES = ["fake", "faiss", "pgvector", "weaviate", "remote", "chroma"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 8bbb902cd..a6cf1d8bc 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -4,8 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import importlib.util import json import tempfile + from typing import Any, Dict, List, Optional from llama_stack.distribution.datatypes import * # noqa: F403 @@ -51,8 +53,11 @@ async def construct_stack_for_test( try: remote_config = remote_provider_config(run_config) if not remote_config: - # TODO: add to provider registry by creating interesting mocks or fakes - impls = await construct_stack(run_config, get_provider_registry()) + # Here we create instance of registry with optional fake provider + provider_registry = setup_provider_registry_for_test( + run_config, get_provider_registry() + ) + impls = await construct_stack(run_config, provider_registry) else: # we don't register resources for a remote stack as part of the fixture setup # because the stack is already "up". if a test needs to register resources, it @@ -73,6 +78,60 @@ async def construct_stack_for_test( return test_stack +def setup_provider_registry_for_test( + run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]] +) -> Dict[Api, Dict[str, ProviderSpec]]: + provider_registry = get_provider_registry() + for api_name, providers in run_config.providers.items(): + for provider in providers: + if provider.provider_type == "test::fake": + # Check if the fake provider module exists for the API that is trying + # to use test::fake provider_type + provider_fake_module_name = ( + f"llama_stack.providers.tests.{api_name}.fakes" + ) + provider_fake_module_spec = importlib.util.find_spec( + provider_fake_module_name + ) + if provider_fake_module_spec is None: + raise ValueError( + f"Fake provider module {provider_fake_module_name} does not exist. " + f"The module must be defined inside the providers/tests/{api_name}/fakes.py file." + ) + + # Import the module so we can validate that the config class exists + provider_fake_module = importlib.import_module( + provider_fake_module_name + ) + + # Check if the fake provider config class exists + # The class name is derived from the provider type e.g. + # provider_id: "example_provider" -> class_name: "ExampleProviderConfig" + provider_fake_config_class_name = ( + f"{provider.provider_id}_config".title().replace("_", "") + ) + if not hasattr(provider_fake_module, provider_fake_config_class_name): + raise ValueError( + f"Fake provider config class {provider_fake_config_class_name} " + f"does not exist in module {provider_fake_module_name}. " + f"The config class must be defined inside the providers/tests/{api_name}/fakes.py file." + ) + + provider_fake_config_class_path = f"llama_stack.providers.tests.{api_name}.fakes.{provider_fake_config_class_name}" + + api = getattr(Api, api_name) + fake_api_spec = InlineProviderSpec( + api=api, + provider_type="test::fake", + pip_packages=[], + module=provider_fake_module_name, + config_class=provider_fake_config_class_path, + ) + provider_registry[api].update({"test::fake": fake_api_spec}) + + return provider_registry + + def remote_provider_config( run_config: StackRunConfig, ) -> Optional[RemoteProviderConfig]: