From ac1791f8b16c3d5d9c75f2719ebb7be7df08384f Mon Sep 17 00:00:00 2001 From: Vladimir Ivic Date: Wed, 20 Nov 2024 08:56:13 -0800 Subject: [PATCH] Adding memory provider mocks --- llama_stack/distribution/resolver.py | 21 +++- llama_stack/providers/tests/conftest.py | 17 +++ .../providers/tests/memory/fixtures.py | 16 ++- llama_stack/providers/tests/memory/mocks.py | 102 ++++++++++++++++++ 4 files changed, 153 insertions(+), 3 deletions(-) create mode 100644 llama_stack/providers/tests/memory/mocks.py diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 4c74b0d1f..cb9c1a8ef 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -8,6 +8,8 @@ import inspect from typing import Any, Dict, List, Set +from unittest.mock import MagicMock, Mock, NonCallableMagicMock, NonCallableMock + from termcolor import cprint from llama_stack.providers.datatypes import * # noqa: F403 @@ -33,6 +35,13 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type +ALLOWED_MOCK_IMPLEMENTATIONS = { + NonCallableMock, + Mock, + NonCallableMagicMock, + MagicMock, +} + class InvalidProviderError(Exception): pass @@ -346,8 +355,18 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: else: # Check if the method is actually implemented in the class method_owner = next( - (cls for cls in mro if name in cls.__dict__), None + ( + cls + for cls in mro + if name in cls.__dict__ + # Finding method owner for a mock object will fail + # implementaiont compliance check. In case we are holding + # a mock object we are going to skip this validation + or cls in ALLOWED_MOCK_IMPLEMENTATIONS + ), + None, ) + if ( method_owner is None or method_owner.__name__ == protocol.__name__ diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 8b73500d0..39faafd7e 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -73,6 +73,12 @@ def pytest_addoption(parser): parser.addoption( "--env", action="append", help="Set environment variables, e.g. --env KEY=value" ) + """Specify which providers will be set up as mocks""" + parser.addoption( + "--mock-overrides", + action="append", + help="Specify which providers will be set up as mocks, e.g. --mock-overrides memory=faiss,safety=meta-reference", + ) def make_provider_id(providers: Dict[str, str]) -> str: @@ -139,6 +145,17 @@ def parse_fixture_string( return fixtures +def should_use_mock_overrides(request, api_with_provider: str, mock_name: str) -> bool: + enabled_api_provider_overrides = request.config.getoption("--mock-overrides") + if enabled_api_provider_overrides is None: + return False + + if api_with_provider in enabled_api_provider_overrides: + print(f"Overriding {api_with_provider} with mocks from {mock_name}") + return True + return False + + def pytest_itemcollected(item): # Get all markers as a list filtered = ("asyncio", "parametrize") diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index c9559b61c..2f5d422bd 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -16,8 +16,9 @@ from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig -from ..conftest import ProviderFixture, remote_stack_fixture +from ..conftest import ProviderFixture, remote_stack_fixture, should_use_mock_overrides from ..env import get_env_or_fail +from .mocks import * # noqa F401 @pytest.fixture(scope="session") @@ -101,10 +102,21 @@ async def memory_stack(request): fixture_name = request.param fixture = request.getfixturevalue(f"memory_{fixture_name}") + # Setup mocks if they are specified via the command line and they are defined + if should_use_mock_overrides( + request, f"memory={fixture_name}", f"memory_{fixture_name}_mocks" + ): + try: + request.getfixturevalue(f"memory_{fixture_name}_mocks") + except pytest.FixtureLookupError: + print( + f"Fixture memory_{fixture_name}_mocks not implemented, skipping mocks." + ) + test_stack = await construct_stack_for_test( [Api.memory], {"memory": fixture.providers}, fixture.provider_data, ) - return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks] + yield test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks] diff --git a/llama_stack/providers/tests/memory/mocks.py b/llama_stack/providers/tests/memory/mocks.py new file mode 100644 index 000000000..717ea9c6a --- /dev/null +++ b/llama_stack/providers/tests/memory/mocks.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import create_autospec, patch + +import pytest + +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.memory.memory import * # noqa: F403 +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate + + +class MemoryImplFake(Memory, MemoryBanksProtocolPrivate): ... + + +class MemoryAdapterFake( + Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate +): ... + + +class MethodStubs: + QUERY_DOCUMENTS_RETURN_VALUES = [ + QueryDocumentsResponse( + chunks=[Chunk(content="Python", token_count=1, document_id="")], + scores=[0.1], + ), + QueryDocumentsResponse( + chunks=[Chunk(content="neural networks", token_count=2, document_id="")], + scores=[0.1], + ), + QueryDocumentsResponse( + chunks=[ + Chunk(content="chunk-1", token_count=1, document_id=""), + Chunk(content="chunk-2", token_count=1, document_id=""), + ], + scores=[0.1, 0.5], + ), + QueryDocumentsResponse( + chunks=[Chunk(content="Python", token_count=1, document_id="")], + scores=[0.5], + ), + ] + + +@pytest.fixture(scope="session") +def memory_faiss_mocks(request): + with patch( + "llama_stack.providers.inline.memory.faiss.get_provider_impl", + autospec=True, + ) as get_adapter_impl_mock: # noqa N806 + impl_mock = create_autospec(MemoryImplFake) + impl_mock.query_documents.side_effect = ( + MethodStubs.QUERY_DOCUMENTS_RETURN_VALUES + ) + get_adapter_impl_mock.return_value = impl_mock + yield + + +@pytest.fixture(scope="session") +def memory_pgvector_mocks(request): + with patch( + "llama_stack.providers.remote.memory.pgvector.get_adapter_impl", + autospec=True, + ) as get_adapter_impl_mock: # noqa N806 + adapter_mock = create_autospec(MemoryAdapterFake) + adapter_mock.query_documents.side_effect = ( + MethodStubs.QUERY_DOCUMENTS_RETURN_VALUES + ) + get_adapter_impl_mock.return_value = adapter_mock + yield + + +@pytest.fixture(scope="session") +def memory_weaviate_mocks(request): + with patch( + "llama_stack.providers.remote.memory.weaviate.get_adapter_impl", + autospec=True, + ) as get_adapter_impl_mock: # noqa N806 + adapter_mock = create_autospec(MemoryAdapterFake) + adapter_mock.query_documents.side_effect = ( + MethodStubs.QUERY_DOCUMENTS_RETURN_VALUES + ) + get_adapter_impl_mock.return_value = adapter_mock + yield + + +@pytest.fixture(scope="session") +def memory_chroma_mocks(request): + with patch( + "llama_stack.providers.remote.memory.chroma.get_adapter_impl", + autospec=True, + ) as get_adapter_impl_mock: # noqa N806 + adapter_mock = create_autospec(MemoryAdapterFake) + adapter_mock.query_documents.side_effect = ( + MethodStubs.QUERY_DOCUMENTS_RETURN_VALUES + ) + get_adapter_impl_mock.return_value = adapter_mock + yield