diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 4c74b0d1f..07158aded 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -360,6 +360,19 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
         )
 
 
+# Here we simply want to apply the mapping from the config.
+# All test fakes must already be initialized and properly
+# set up (stubs, mocks, etc.) before this point.
+async def resolve_test_fake_stack_impls(
+    config: TestFakeProviderConfig,
+) -> Dict[Api, Any]:
+    impls = {}
+    for api, impl in config.impls.items():
+        impls[api] = impl
+
+    return impls
+
+
 async def resolve_remote_stack_impls(
     config: RemoteProviderConfig,
     apis: List[str],
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 080204e45..1c26fd4a2 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Any, List, Optional, Protocol
+from typing import Any, Dict, List, Optional, Protocol
 from urllib.parse import urlparse
 
 from llama_models.schema_utils import json_schema_type
@@ -157,6 +157,16 @@ Fully-qualified name of the module to import. The module is expected to have:
     )
 
 
+# This config is kept as simple as possible so that
+# test fakes can be set up in a generic way.
+#
+# `impls` contains a simple mapping from Api to test fake objects.
+#
+# Test fake objects must be initialized and properly set up beforehand.
+class TestFakeProviderConfig(BaseModel):
+    impls: Dict[Api, Any]
+
+
 class RemoteProviderConfig(BaseModel):
     host: str = "localhost"
     port: Optional[int] = None
diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py
index 8b73500d0..ff0e68df0 100644
--- a/llama_stack/providers/tests/conftest.py
+++ b/llama_stack/providers/tests/conftest.py
@@ -14,7 +14,7 @@ from pydantic import BaseModel
 from termcolor import colored
 
 from llama_stack.distribution.datatypes import Provider
-from llama_stack.providers.datatypes import RemoteProviderConfig
+from llama_stack.providers.datatypes import RemoteProviderConfig, TestFakeProviderConfig
 
 from .env import get_env_or_fail
 
@@ -24,6 +24,20 @@ class ProviderFixture(BaseModel):
     provider_data: Optional[Dict[str, Any]] = None
 
 
+# Generic test fake fixture. Use TestFakeProviderConfig to set the test fakes
+# that will be mapped to their corresponding APIs.
+def test_fake_stack_fixture(config: TestFakeProviderConfig) -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="test::test-fake",
+                provider_type="test::test-fake",
+                config=config.model_dump(),
+            )
+        ],
+    )
+
+
 def remote_stack_fixture() -> ProviderFixture:
     if url := os.getenv("REMOTE_STACK_URL", None):
         config = RemoteProviderConfig.from_url(url)
diff --git a/llama_stack/providers/tests/memory/fakes.py b/llama_stack/providers/tests/memory/fakes.py
new file mode 100644
index 000000000..9fccb7d89
--- /dev/null
+++ b/llama_stack/providers/tests/memory/fakes.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List, Optional
+
+from llama_stack.apis.memory.memory import (
+    Memory,
+    MemoryBankDocument,
+    MemoryBankStore,
+    QueryDocumentsResponse,
+)
+from llama_stack.apis.memory_banks.memory_banks import KeyValueMemoryBank
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+# MemoryBanks test fake implementation to
+# support the behaviors tested in test_memory.py
+class MemoryBanksTestFakeImpl(MemoryBanks):
+    def __init__(self):
+        self.memory_banks: Dict[str, MemoryBank] = dict()
+
+    async def list_memory_banks(self) -> List[MemoryBank]:
+        return list(self.memory_banks.values())
+
+    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
+        if memory_bank_id in self.memory_banks:
+            return self.memory_banks[memory_bank_id]
+
+    async def register_memory_bank(
+        self,
+        memory_bank_id: str,
+        params: BankParams,
+        provider_id: Optional[str] = None,
+        provider_memory_bank_id: Optional[str] = None,
+    ) -> MemoryBank:
+        memory_bank = KeyValueMemoryBank(
+            identifier=memory_bank_id,
+            provider_id="test::test-fake",
+        )
+        self.memory_banks[memory_bank_id] = memory_bank
+        return self.memory_banks[memory_bank_id]
+
+    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
+        if memory_bank_id in self.memory_banks:
+            del self.memory_banks[memory_bank_id]
+
+
+# Memory test fake implementation to
+# support the behaviors tested in test_memory.py
+class MemoryTestFakeImpl(Memory):
+    memory_bank_store: MemoryBankStore
+
+    def __init__(self):
+        self.memory_banks = None
+        self.stubs: Dict[str, Any] = {}
+
+    def set_memory_banks(self, memory_banks: MemoryBanks) -> None:
+        self.memory_banks = memory_banks
+
+    def set_stubs(self, method: str, stubs: Dict[str, Any]):
+        self.stubs[method] = stubs
+
+    async def insert_documents(
+        self,
+        bank_id: str,
+        documents: List[MemoryBankDocument],
+        ttl_seconds: Optional[int] = None,
+    ) -> None:
+        if not await self.memory_banks.get_memory_bank(bank_id):
+            raise ValueError(f"Bank {bank_id} not found")
+        # No-op
+        # Documents are ignored here because this test fake is initialized
+        # with stubs that match the expected query-response pairs.
+
+    async def query_documents(
+        self,
+        bank_id: str,
+        query: InterleavedTextMedia,
+        params: Optional[Dict[str, Any]] = None,
+    ) -> QueryDocumentsResponse:
+        if not await self.memory_banks.get_memory_bank(bank_id):
+            raise ValueError(f"Bank {bank_id} not found")
+        if query not in self.stubs["query_documents"]:
+            raise ValueError(
+                f"Stub not created for query {query}, please check your test setup."
+            )
+
+        return self.stubs["query_documents"][query]
diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py
index c9559b61c..3780b162e 100644
--- a/llama_stack/providers/tests/memory/fixtures.py
+++ b/llama_stack/providers/tests/memory/fixtures.py
@@ -10,14 +10,62 @@ import tempfile
 import pytest
 import pytest_asyncio
 
+from llama_stack.apis.memory.memory import Chunk, QueryDocumentsResponse
 from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
+from llama_stack.providers.datatypes import TestFakeProviderConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
 
-from ..conftest import ProviderFixture, remote_stack_fixture
+from ..conftest import ProviderFixture, remote_stack_fixture, test_fake_stack_fixture
 from ..env import get_env_or_fail
+from .fakes import MemoryBanksTestFakeImpl, MemoryTestFakeImpl
+
+
+@pytest.fixture(scope="session")
+def query_documents_stubs():
+    # Stubs for method calls against the MemoryTestFakeImpl fake, so that
+    # the tests in test_memory behave as they would with a real provider.
+    return {
+        "programming language": QueryDocumentsResponse(
+            chunks=[Chunk(content="Python", token_count=1, document_id="")],
+            scores=[0.1],
+        ),
+        "AI and brain-inspired computing": QueryDocumentsResponse(
+            chunks=[Chunk(content="neural networks", token_count=2, document_id="")],
+            scores=[0.1],
+        ),
+        "computer": QueryDocumentsResponse(
+            chunks=[
+                Chunk(content="test-chunk-1", token_count=1, document_id=""),
+                Chunk(content="test-chunk-2", token_count=1, document_id=""),
+            ],
+            scores=[0.1, 0.5],
+        ),
+        "quantum computing": QueryDocumentsResponse(
+            chunks=[Chunk(content="Python", token_count=1, document_id="")],
+            scores=[0.5],
+        ),
+    }
+
+
+@pytest.fixture(scope="session")
+def memory_test_fake(query_documents_stubs) -> ProviderFixture:
+    # Prepare the impl instances here: instantiate the fake objects and set up stubs
+    memory_banks_impl = MemoryBanksTestFakeImpl()
+    memory_impl = MemoryTestFakeImpl()
+    memory_impl.set_memory_banks(memory_banks_impl)
+    memory_impl.set_stubs("query_documents", query_documents_stubs)
+
+    config = TestFakeProviderConfig(
+        impls={
+            Api.memory: memory_impl,
+            Api.memory_banks: memory_banks_impl,
+        }
+    )
+
+    return test_fake_stack_fixture(config)
 
 
 @pytest.fixture(scope="session")
@@ -93,7 +141,7 @@ def memory_chroma() -> ProviderFixture:
     )
 
 
-MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
+MEMORY_FIXTURES = ["test_fake", "faiss", "pgvector", "weaviate", "remote", "chroma"]
 
 
 @pytest_asyncio.fixture(scope="session")
diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py
index df927926e..9c005d707 100644
--- a/llama_stack/providers/tests/resolver.py
+++ b/llama_stack/providers/tests/resolver.py
@@ -14,7 +14,10 @@ from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.request_headers import set_request_provider_data
-from llama_stack.distribution.resolver import resolve_remote_stack_impls
+from llama_stack.distribution.resolver import (
+    resolve_remote_stack_impls,
+    resolve_test_fake_stack_impls,
+)
 from llama_stack.distribution.stack import construct_stack
 from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
 
@@ -51,16 +54,26 @@ async def construct_stack_for_test(
     )
     run_config = parse_and_maybe_upgrade_config(run_config)
    try:
-        remote_config = remote_provider_config(run_config)
-        if not remote_config:
-            # TODO: add to provider registry by creating interesting mocks or fakes
-            impls = await construct_stack(run_config, get_provider_registry())
-        else:
-            # we don't register resources for a remote stack as part of the fixture setup
-            # because the stack is already "up". if a test needs to register resources, it
-            # can do so manually always.
+        impls = None
 
-            impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
+        # "Resolve" implementations when using test::test-fake providers.
+        # This actually injects test fakes as the resolved API implementations.
+        test_fake_config = test_fake_provider_config(run_config)
+        if test_fake_config:
+            impls = await resolve_test_fake_stack_impls(test_fake_config)
+
+        # Resolve implementations when using test::remote providers
+        if not impls:
+            remote_config = remote_provider_config(run_config)
+            if remote_config:
+                # we don't register resources for a remote stack as part of the fixture setup
+                # because the stack is already "up". if a test needs to register resources, it
+                # can do so manually always.
+                impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
+
+        # If none of the above applied, resolve implementations as normal providers
+        if not impls:
+            impls = await construct_stack(run_config, get_provider_registry())
 
         test_stack = TestStack(impls=impls, run_config=run_config)
     except ModuleNotFoundError as e:
@@ -75,6 +88,28 @@ async def construct_stack_for_test(
     return test_stack
 
 
+# When we want to run with (possibly multiple) test fakes, the stack must
+# contain only test fake providers, because resolve_test_fake_stack_impls(),
+# which is called right after this, cannot resolve any other provider type.
+# This could be refactored if we ever need to change that.
+def test_fake_provider_config(
+    run_config: StackRunConfig,
+) -> Optional[TestFakeProviderConfig]:
+    test_fake_config = None
+    has_non_test_fake = False
+    for api_providers in run_config.providers.values():
+        for provider in api_providers:
+            if provider.provider_type == "test::test-fake":
+                test_fake_config = TestFakeProviderConfig(**provider.config)
+            else:
+                has_non_test_fake = True
+
+    if test_fake_config:
+        assert not has_non_test_fake, "Test fake stack cannot have non-fake providers"
+
+    return test_fake_config
+
+
 def remote_provider_config(
     run_config: StackRunConfig,
 ) -> Optional[RemoteProviderConfig]:
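
Reviewer note: the sketch below shows how the pieces introduced in this patch compose end to end. It is illustrative only and not part of the patch; it assumes the modules land at the import paths shown in the diff, and the stub value is borrowed from the query_documents_stubs fixture above.

import asyncio

from llama_stack.apis.memory.memory import Chunk, QueryDocumentsResponse
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.resolver import resolve_test_fake_stack_impls
from llama_stack.providers.datatypes import TestFakeProviderConfig
from llama_stack.providers.tests.memory.fakes import (
    MemoryBanksTestFakeImpl,
    MemoryTestFakeImpl,
)


async def main() -> None:
    # Wire up the fakes the same way the memory_test_fake fixture does.
    memory_banks_impl = MemoryBanksTestFakeImpl()
    memory_impl = MemoryTestFakeImpl()
    memory_impl.set_memory_banks(memory_banks_impl)
    memory_impl.set_stubs(
        "query_documents",
        {
            "programming language": QueryDocumentsResponse(
                chunks=[Chunk(content="Python", token_count=1, document_id="")],
                scores=[0.1],
            ),
        },
    )

    # Map APIs to the fakes and "resolve" them without building a real stack.
    config = TestFakeProviderConfig(
        impls={Api.memory: memory_impl, Api.memory_banks: memory_banks_impl}
    )
    impls = await resolve_test_fake_stack_impls(config)

    # The resolved impls answer stubbed queries like a real memory provider.
    await impls[Api.memory_banks].register_memory_bank("test_bank", params=None)
    response = await impls[Api.memory].query_documents("test_bank", "programming language")
    assert response.chunks[0].content == "Python"


if __name__ == "__main__":
    asyncio.run(main())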