Testing - Memory provider fakes

Summary:
Implementing Memory provider fakes as discussed in this draft https://github.com/meta-llama/llama-stack/pull/490#issuecomment-2492877393.

High level changes:
* Fake provider is specified via the "fake" mark
* Test config will set up a fake fixture for the run of the test
* Test resolver checks fixtures and upon finding a fake provider it injects InlineProviderSpec for fake provider
* Fake provider gets resolved through path/naming convention
* Fake provider implementation is confined to the tests/ directory and implements stubs and method fakes with minimal functionality to simulate a real provider

Instructions for creating a fake
* Create the "fakes" module inside the provider test directory
* Inside the module implement `get_provider_impl` that will return fake implementation object
* Name the fake implementation class to match the fake provider id (e.g. memory_fake -> MemoryFakeImpl)
  * Same rule for the config (e.g. memory_fake -> MemoryFakeConfig)
* Add a fake fixture (in fixtures.py) and set up the method stubs there

Test Plan:
Run memory tests
```
pytest llama_stack/providers/tests/memory/test_memory.py -m "fake" -v -s --tb=short

====================================================================================================== test session starts ======================================================================================================
platform darwin -- Python 3.11.10, pytest-8.3.3, pluggy-1.5.0 -- /opt/homebrew/Caskroom/miniconda/base/envs/llama-stack/bin/python
cachedir: .pytest_cache
rootdir: /Users/vivic/Code/llama-stack
configfile: pyproject.toml
plugins: asyncio-0.24.0, anyio-4.6.2.post1
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 18 items / 15 deselected / 3 selected

llama_stack/providers/tests/memory/test_memory.py::TestMemory::test_banks_list[fake] PASSED
llama_stack/providers/tests/memory/test_memory.py::TestMemory::test_banks_register[fake] PASSED
llama_stack/providers/tests/memory/test_memory.py::TestMemory::test_query_documents[fake] The scores are: [0.5]
PASSED

========================================================================================= 3 passed, 15 deselected, 10 warnings in 0.46s =========================================================================================

```
This commit is contained in:
Vladimir Ivic 2024-11-25 10:24:52 -08:00
parent 4e6c984c26
commit e2d1b712e2
3 changed files with 189 additions and 3 deletions

View file

@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks import MemoryBank
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
# Configuration model for the fake Memory provider used in tests.
# It declares no fields; `get_provider_impl` only checks that the config it
# receives is an instance of this class.
@json_schema_type
class InlineMemoryFakeConfig(BaseModel):
    pass
class InlineMemoryFakeImpl(Memory, MemoryBanksProtocolPrivate):
    """Fake Memory provider for tests.

    Memory banks are kept in a plain dict and document queries are answered
    from canned responses registered through ``stub_method``.

    All state lives on the class rather than on instances, so stubs registered
    before an instance exists are visible to any instance created later.
    """

    # method name -> {input matcher -> canned return value}
    method_stubs: Dict[str, Any] = {}
    # memory_bank_id -> registered memory bank
    memory_banks: Dict[str, MemoryBank] = {}

    @staticmethod
    def stub_method(method_name: str, return_value_matchers: Dict[str, Any]) -> None:
        """Register canned return values for ``method_name``.

        Args:
            method_name: Name of the provider method being stubbed.
            return_value_matchers: Maps an expected input to the value the
                stubbed method should return for it. Entries are merged into
                any matchers already registered for the same method; on a key
                clash the newer entry wins.
        """
        # Merge into a dict owned by the fake so that later registrations never
        # mutate the dict object the caller passed in.
        InlineMemoryFakeImpl.method_stubs.setdefault(method_name, {}).update(
            return_value_matchers
        )

    def __init__(self, config: InlineMemoryFakeConfig) -> None:
        """Accept ``config`` for interface compatibility; it is not used."""
        pass

    async def initialize(self) -> None:
        """No-op: the fake has nothing to set up."""
        pass

    async def shutdown(self) -> None:
        """No-op: the fake holds no resources to release."""
        pass

    async def register_memory_bank(
        self,
        memory_bank: MemoryBank,
    ) -> None:
        """Store ``memory_bank`` under its ``memory_bank_id``, replacing any previous entry."""
        InlineMemoryFakeImpl.memory_banks[memory_bank.memory_bank_id] = memory_bank

    async def list_memory_banks(self) -> List[MemoryBank]:
        """Return all currently registered memory banks."""
        return list(InlineMemoryFakeImpl.memory_banks.values())

    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
        """Remove a registered memory bank.

        Raises:
            ValueError: If no bank with ``memory_bank_id`` is registered.
        """
        if memory_bank_id not in InlineMemoryFakeImpl.memory_banks:
            raise ValueError(f"Bank {memory_bank_id} not found.")
        del InlineMemoryFakeImpl.memory_banks[memory_bank_id]

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        """No-op: documents are discarded; query results come only from stubs."""
        pass

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        """Return the canned response registered for ``query``.

        ``bank_id`` and ``params`` are ignored; only ``query`` selects the stub.

        Raises:
            ValueError: If no stub has been registered for ``query``, including
                the case where nothing was stubbed for ``query_documents`` at all.
        """
        stubs = InlineMemoryFakeImpl.method_stubs.get("query_documents", {})
        try:
            return stubs[query]
        except (KeyError, TypeError):
            # KeyError: nothing stubbed for this query.
            # TypeError: the query is unhashable (e.g. a list), so it cannot
            # equal any stub key either. Both mean "no stub found".
            raise ValueError(
                f"Stub for query '{query}' not found, please set up expected result"
            ) from None
async def get_provider_impl(config: InlineMemoryFakeConfig, _deps: Any):
    """Create and initialize the fake Memory provider.

    Args:
        config: Must be an ``InlineMemoryFakeConfig`` instance.
        _deps: Unused.

    Returns:
        An ``InlineMemoryFakeImpl`` whose ``initialize()`` has been awaited.
    """
    assert isinstance(config, InlineMemoryFakeConfig), (
        f"Unexpected config type: {type(config)}"
    )

    fake_provider = InlineMemoryFakeImpl(config)
    await fake_provider.initialize()
    return fake_provider

View file

@ -10,6 +10,7 @@ import tempfile
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
@ -18,6 +19,48 @@ from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail from ..env import get_env_or_fail
from .fakes import InlineMemoryFakeImpl
@pytest.fixture(scope="session")
def memory_fake() -> ProviderFixture:
    """Session fixture that selects the fake Memory provider.

    Registers canned ``query_documents`` responses on ``InlineMemoryFakeImpl``
    and returns a ``ProviderFixture`` with a single ``test::fake`` provider.
    """

    def _chunk(text, tokens):
        # Every canned chunk uses an empty document id.
        return Chunk(content=text, token_count=tokens, document_id="")

    # query text -> response the fake should return for it
    canned_responses = {
        "programming language": QueryDocumentsResponse(
            chunks=[_chunk("Python", 1)],
            scores=[0.1],
        ),
        "AI and brain-inspired computing": QueryDocumentsResponse(
            chunks=[_chunk("neural networks", 2)],
            scores=[0.1],
        ),
        "computer": QueryDocumentsResponse(
            chunks=[_chunk("test-chunk-1", 1), _chunk("test-chunk-2", 1)],
            scores=[0.1, 0.5],
        ),
        "quantum computing": QueryDocumentsResponse(
            chunks=[_chunk("Python", 1)],
            scores=[0.5],
        ),
    }
    InlineMemoryFakeImpl.stub_method(
        method_name="query_documents",
        return_value_matchers=canned_responses,
    )

    fake_provider = Provider(
        provider_id="inline_memory_fake",
        provider_type="test::fake",
        config={},
    )
    return ProviderFixture(providers=[fake_provider])
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -93,7 +136,7 @@ def memory_chroma() -> ProviderFixture:
) )
MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"] MEMORY_FIXTURES = ["fake", "faiss", "pgvector", "weaviate", "remote", "chroma"]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")

View file

@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import importlib.util
import json import json
import tempfile import tempfile
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
@ -51,8 +53,11 @@ async def construct_stack_for_test(
try: try:
remote_config = remote_provider_config(run_config) remote_config = remote_provider_config(run_config)
if not remote_config: if not remote_config:
# TODO: add to provider registry by creating interesting mocks or fakes # Here we create instance of registry with optional fake provider
impls = await construct_stack(run_config, get_provider_registry()) provider_registry = setup_provider_registry_for_test(
run_config, get_provider_registry()
)
impls = await construct_stack(run_config, provider_registry)
else: else:
# we don't register resources for a remote stack as part of the fixture setup # we don't register resources for a remote stack as part of the fixture setup
# because the stack is already "up". if a test needs to register resources, it # because the stack is already "up". if a test needs to register resources, it
@ -73,6 +78,60 @@ async def construct_stack_for_test(
return test_stack return test_stack
def setup_provider_registry_for_test(
    run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
) -> Dict[Api, Dict[str, ProviderSpec]]:
    """Add a spec to ``provider_registry`` for each ``test::fake`` provider in ``run_config``.

    For every provider whose ``provider_type`` is ``"test::fake"``, the fake is
    located by convention in ``llama_stack.providers.tests.<api_name>.fakes``,
    and its config class name is derived from the provider id
    (e.g. ``example_provider`` -> ``ExampleProviderConfig``).

    Args:
        run_config: Run configuration whose ``providers`` mapping is scanned.
        provider_registry: Registry to extend. It is updated in place.

    Returns:
        The same ``provider_registry`` object, with an ``InlineProviderSpec``
        stored under ``"test::fake"`` for each API that uses a fake provider.

    Raises:
        ValueError: If the fakes module for an API cannot be found, or if it
            does not define the expected config class.
    """
    fake_provider_type = "test::fake"

    for api_name, providers in run_config.providers.items():
        for provider in providers:
            if provider.provider_type != fake_provider_type:
                continue

            # Check that the fakes module exists for the API that is trying to
            # use the test::fake provider_type.
            provider_fake_module_name = f"llama_stack.providers.tests.{api_name}.fakes"
            try:
                provider_fake_module_spec = importlib.util.find_spec(
                    provider_fake_module_name
                )
            except ModuleNotFoundError:
                # find_spec imports the parent package first and raises when it
                # is missing; treat that the same as "module not found".
                provider_fake_module_spec = None
            if provider_fake_module_spec is None:
                raise ValueError(
                    f"Fake provider module {provider_fake_module_name} does not exist. "
                    f"The module must be defined inside the providers/tests/{api_name}/fakes.py file."
                )

            # Import the module so we can validate that the config class exists.
            provider_fake_module = importlib.import_module(provider_fake_module_name)

            # The class name is derived from the provider id, e.g.
            # provider_id: "example_provider" -> class_name: "ExampleProviderConfig"
            provider_fake_config_class_name = (
                f"{provider.provider_id}_config".title().replace("_", "")
            )
            if not hasattr(provider_fake_module, provider_fake_config_class_name):
                raise ValueError(
                    f"Fake provider config class {provider_fake_config_class_name} "
                    f"does not exist in module {provider_fake_module_name}. "
                    f"The config class must be defined inside the providers/tests/{api_name}/fakes.py file."
                )

            api = getattr(Api, api_name)
            provider_registry[api][fake_provider_type] = InlineProviderSpec(
                api=api,
                provider_type=fake_provider_type,
                pip_packages=[],
                module=provider_fake_module_name,
                config_class=f"{provider_fake_module_name}.{provider_fake_config_class_name}",
            )

    return provider_registry
def remote_provider_config( def remote_provider_config(
run_config: StackRunConfig, run_config: StackRunConfig,
) -> Optional[RemoteProviderConfig]: ) -> Optional[RemoteProviderConfig]: