Adding memory provider mocks

This commit is contained in:
Vladimir Ivic 2024-11-20 08:56:13 -08:00
parent 7693786322
commit ac1791f8b1
4 changed files with 153 additions and 3 deletions

View file

@ -8,6 +8,8 @@ import inspect
from typing import Any, Dict, List, Set from typing import Any, Dict, List, Set
from unittest.mock import MagicMock, Mock, NonCallableMagicMock, NonCallableMock
from termcolor import cprint from termcolor import cprint
from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import * # noqa: F403
@ -33,6 +35,13 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
ALLOWED_MOCK_IMPLEMENTATIONS = {
NonCallableMock,
Mock,
NonCallableMagicMock,
MagicMock,
}
class InvalidProviderError(Exception): class InvalidProviderError(Exception):
pass pass
@ -346,8 +355,18 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
else: else:
# Check if the method is actually implemented in the class # Check if the method is actually implemented in the class
method_owner = next( method_owner = next(
(cls for cls in mro if name in cls.__dict__), None (
cls
for cls in mro
if name in cls.__dict__
# Finding method owner for a mock object will fail
# implementaiont compliance check. In case we are holding
# a mock object we are going to skip this validation
or cls in ALLOWED_MOCK_IMPLEMENTATIONS
),
None,
) )
if ( if (
method_owner is None method_owner is None
or method_owner.__name__ == protocol.__name__ or method_owner.__name__ == protocol.__name__

View file

@ -73,6 +73,12 @@ def pytest_addoption(parser):
parser.addoption( parser.addoption(
"--env", action="append", help="Set environment variables, e.g. --env KEY=value" "--env", action="append", help="Set environment variables, e.g. --env KEY=value"
) )
"""Specify which providers will be set up as mocks"""
parser.addoption(
"--mock-overrides",
action="append",
help="Specify which providers will be set up as mocks, e.g. --mock-overrides memory=faiss,safety=meta-reference",
)
def make_provider_id(providers: Dict[str, str]) -> str: def make_provider_id(providers: Dict[str, str]) -> str:
@ -139,6 +145,17 @@ def parse_fixture_string(
return fixtures return fixtures
def should_use_mock_overrides(request, api_with_provider: str, mock_name: str) -> bool:
enabled_api_provider_overrides = request.config.getoption("--mock-overrides")
if enabled_api_provider_overrides is None:
return False
if api_with_provider in enabled_api_provider_overrides:
print(f"Overriding {api_with_provider} with mocks from {mock_name}")
return True
return False
def pytest_itemcollected(item): def pytest_itemcollected(item):
# Get all markers as a list # Get all markers as a list
filtered = ("asyncio", "parametrize") filtered = ("asyncio", "parametrize")

View file

@ -16,8 +16,9 @@ from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture from ..conftest import ProviderFixture, remote_stack_fixture, should_use_mock_overrides
from ..env import get_env_or_fail from ..env import get_env_or_fail
from .mocks import * # noqa F401
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -101,10 +102,21 @@ async def memory_stack(request):
fixture_name = request.param fixture_name = request.param
fixture = request.getfixturevalue(f"memory_{fixture_name}") fixture = request.getfixturevalue(f"memory_{fixture_name}")
# Setup mocks if they are specified via the command line and they are defined
if should_use_mock_overrides(
request, f"memory={fixture_name}", f"memory_{fixture_name}_mocks"
):
try:
request.getfixturevalue(f"memory_{fixture_name}_mocks")
except pytest.FixtureLookupError:
print(
f"Fixture memory_{fixture_name}_mocks not implemented, skipping mocks."
)
test_stack = await construct_stack_for_test( test_stack = await construct_stack_for_test(
[Api.memory], [Api.memory],
{"memory": fixture.providers}, {"memory": fixture.providers},
fixture.provider_data, fixture.provider_data,
) )
return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks] yield test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]

View file

@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import create_autospec, patch
import pytest
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory.memory import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
class MemoryImplFake(Memory, MemoryBanksProtocolPrivate): ...
class MemoryAdapterFake(
Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
): ...
class MethodStubs:
QUERY_DOCUMENTS_RETURN_VALUES = [
QueryDocumentsResponse(
chunks=[Chunk(content="Python", token_count=1, document_id="")],
scores=[0.1],
),
QueryDocumentsResponse(
chunks=[Chunk(content="neural networks", token_count=2, document_id="")],
scores=[0.1],
),
QueryDocumentsResponse(
chunks=[
Chunk(content="chunk-1", token_count=1, document_id=""),
Chunk(content="chunk-2", token_count=1, document_id=""),
],
scores=[0.1, 0.5],
),
QueryDocumentsResponse(
chunks=[Chunk(content="Python", token_count=1, document_id="")],
scores=[0.5],
),
]
@pytest.fixture(scope="session")
def memory_faiss_mocks(request):
with patch(
"llama_stack.providers.inline.memory.faiss.get_provider_impl",
autospec=True,
) as get_adapter_impl_mock: # noqa N806
impl_mock = create_autospec(MemoryImplFake)
impl_mock.query_documents.side_effect = (
MethodStubs.QUERY_DOCUMENTS_RETURN_VALUES
)
get_adapter_impl_mock.return_value = impl_mock
yield
@pytest.fixture(scope="session")
def memory_pgvector_mocks(request):
with patch(
"llama_stack.providers.remote.memory.pgvector.get_adapter_impl",
autospec=True,
) as get_adapter_impl_mock: # noqa N806
adapter_mock = create_autospec(MemoryAdapterFake)
adapter_mock.query_documents.side_effect = (
MethodStubs.QUERY_DOCUMENTS_RETURN_VALUES
)
get_adapter_impl_mock.return_value = adapter_mock
yield
@pytest.fixture(scope="session")
def memory_weaviate_mocks(request):
with patch(
"llama_stack.providers.remote.memory.weaviate.get_adapter_impl",
autospec=True,
) as get_adapter_impl_mock: # noqa N806
adapter_mock = create_autospec(MemoryAdapterFake)
adapter_mock.query_documents.side_effect = (
MethodStubs.QUERY_DOCUMENTS_RETURN_VALUES
)
get_adapter_impl_mock.return_value = adapter_mock
yield
@pytest.fixture(scope="session")
def memory_chroma_mocks(request):
with patch(
"llama_stack.providers.remote.memory.chroma.get_adapter_impl",
autospec=True,
) as get_adapter_impl_mock: # noqa N806
adapter_mock = create_autospec(MemoryAdapterFake)
adapter_mock.query_documents.side_effect = (
MethodStubs.QUERY_DOCUMENTS_RETURN_VALUES
)
get_adapter_impl_mock.return_value = adapter_mock
yield