mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
Adding memory provider mocks
This commit is contained in:
parent
7693786322
commit
ac1791f8b1
4 changed files with 153 additions and 3 deletions
|
@ -8,6 +8,8 @@ import inspect
|
||||||
|
|
||||||
from typing import Any, Dict, List, Set
|
from typing import Any, Dict, List, Set
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, Mock, NonCallableMagicMock, NonCallableMock
|
||||||
|
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import * # noqa: F403
|
from llama_stack.providers.datatypes import * # noqa: F403
|
||||||
|
@ -33,6 +35,13 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
|
||||||
from llama_stack.distribution.store import DistributionRegistry
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
|
ALLOWED_MOCK_IMPLEMENTATIONS = {
|
||||||
|
NonCallableMock,
|
||||||
|
Mock,
|
||||||
|
NonCallableMagicMock,
|
||||||
|
MagicMock,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class InvalidProviderError(Exception):
|
class InvalidProviderError(Exception):
|
||||||
pass
|
pass
|
||||||
|
@ -346,8 +355,18 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
||||||
else:
|
else:
|
||||||
# Check if the method is actually implemented in the class
|
# Check if the method is actually implemented in the class
|
||||||
method_owner = next(
|
method_owner = next(
|
||||||
(cls for cls in mro if name in cls.__dict__), None
|
(
|
||||||
|
cls
|
||||||
|
for cls in mro
|
||||||
|
if name in cls.__dict__
|
||||||
|
# Finding method owner for a mock object will fail
|
||||||
|
# implementaiont compliance check. In case we are holding
|
||||||
|
# a mock object we are going to skip this validation
|
||||||
|
or cls in ALLOWED_MOCK_IMPLEMENTATIONS
|
||||||
|
),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
method_owner is None
|
method_owner is None
|
||||||
or method_owner.__name__ == protocol.__name__
|
or method_owner.__name__ == protocol.__name__
|
||||||
|
|
|
@ -73,6 +73,12 @@ def pytest_addoption(parser):
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--env", action="append", help="Set environment variables, e.g. --env KEY=value"
|
"--env", action="append", help="Set environment variables, e.g. --env KEY=value"
|
||||||
)
|
)
|
||||||
|
"""Specify which providers will be set up as mocks"""
|
||||||
|
parser.addoption(
|
||||||
|
"--mock-overrides",
|
||||||
|
action="append",
|
||||||
|
help="Specify which providers will be set up as mocks, e.g. --mock-overrides memory=faiss,safety=meta-reference",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def make_provider_id(providers: Dict[str, str]) -> str:
|
def make_provider_id(providers: Dict[str, str]) -> str:
|
||||||
|
@ -139,6 +145,17 @@ def parse_fixture_string(
|
||||||
return fixtures
|
return fixtures
|
||||||
|
|
||||||
|
|
||||||
|
def should_use_mock_overrides(request, api_with_provider: str, mock_name: str) -> bool:
|
||||||
|
enabled_api_provider_overrides = request.config.getoption("--mock-overrides")
|
||||||
|
if enabled_api_provider_overrides is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if api_with_provider in enabled_api_provider_overrides:
|
||||||
|
print(f"Overriding {api_with_provider} with mocks from {mock_name}")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def pytest_itemcollected(item):
|
def pytest_itemcollected(item):
|
||||||
# Get all markers as a list
|
# Get all markers as a list
|
||||||
filtered = ("asyncio", "parametrize")
|
filtered = ("asyncio", "parametrize")
|
||||||
|
|
|
@ -16,8 +16,9 @@ from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
||||||
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
||||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture, should_use_mock_overrides
|
||||||
from ..env import get_env_or_fail
|
from ..env import get_env_or_fail
|
||||||
|
from .mocks import * # noqa F401
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
|
@ -101,10 +102,21 @@ async def memory_stack(request):
|
||||||
fixture_name = request.param
|
fixture_name = request.param
|
||||||
fixture = request.getfixturevalue(f"memory_{fixture_name}")
|
fixture = request.getfixturevalue(f"memory_{fixture_name}")
|
||||||
|
|
||||||
|
# Setup mocks if they are specified via the command line and they are defined
|
||||||
|
if should_use_mock_overrides(
|
||||||
|
request, f"memory={fixture_name}", f"memory_{fixture_name}_mocks"
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
request.getfixturevalue(f"memory_{fixture_name}_mocks")
|
||||||
|
except pytest.FixtureLookupError:
|
||||||
|
print(
|
||||||
|
f"Fixture memory_{fixture_name}_mocks not implemented, skipping mocks."
|
||||||
|
)
|
||||||
|
|
||||||
test_stack = await construct_stack_for_test(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.memory],
|
[Api.memory],
|
||||||
{"memory": fixture.providers},
|
{"memory": fixture.providers},
|
||||||
fixture.provider_data,
|
fixture.provider_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
|
yield test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
|
||||||
|
|
102
llama_stack/providers/tests/memory/mocks.py
Normal file
102
llama_stack/providers/tests/memory/mocks.py
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from unittest.mock import create_autospec, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
from llama_stack.apis.memory.memory import * # noqa: F403
|
||||||
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
|
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryImplFake(Memory, MemoryBanksProtocolPrivate): ...
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryAdapterFake(
|
||||||
|
Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
|
||||||
|
): ...
|
||||||
|
|
||||||
|
|
||||||
|
class MethodStubs:
|
||||||
|
QUERY_DOCUMENTS_RETURN_VALUES = [
|
||||||
|
QueryDocumentsResponse(
|
||||||
|
chunks=[Chunk(content="Python", token_count=1, document_id="")],
|
||||||
|
scores=[0.1],
|
||||||
|
),
|
||||||
|
QueryDocumentsResponse(
|
||||||
|
chunks=[Chunk(content="neural networks", token_count=2, document_id="")],
|
||||||
|
scores=[0.1],
|
||||||
|
),
|
||||||
|
QueryDocumentsResponse(
|
||||||
|
chunks=[
|
||||||
|
Chunk(content="chunk-1", token_count=1, document_id=""),
|
||||||
|
Chunk(content="chunk-2", token_count=1, document_id=""),
|
||||||
|
],
|
||||||
|
scores=[0.1, 0.5],
|
||||||
|
),
|
||||||
|
QueryDocumentsResponse(
|
||||||
|
chunks=[Chunk(content="Python", token_count=1, document_id="")],
|
||||||
|
scores=[0.5],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def memory_faiss_mocks(request):
|
||||||
|
with patch(
|
||||||
|
"llama_stack.providers.inline.memory.faiss.get_provider_impl",
|
||||||
|
autospec=True,
|
||||||
|
) as get_adapter_impl_mock: # noqa N806
|
||||||
|
impl_mock = create_autospec(MemoryImplFake)
|
||||||
|
impl_mock.query_documents.side_effect = (
|
||||||
|
MethodStubs.QUERY_DOCUMENTS_RETURN_VALUES
|
||||||
|
)
|
||||||
|
get_adapter_impl_mock.return_value = impl_mock
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def memory_pgvector_mocks(request):
|
||||||
|
with patch(
|
||||||
|
"llama_stack.providers.remote.memory.pgvector.get_adapter_impl",
|
||||||
|
autospec=True,
|
||||||
|
) as get_adapter_impl_mock: # noqa N806
|
||||||
|
adapter_mock = create_autospec(MemoryAdapterFake)
|
||||||
|
adapter_mock.query_documents.side_effect = (
|
||||||
|
MethodStubs.QUERY_DOCUMENTS_RETURN_VALUES
|
||||||
|
)
|
||||||
|
get_adapter_impl_mock.return_value = adapter_mock
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def memory_weaviate_mocks(request):
|
||||||
|
with patch(
|
||||||
|
"llama_stack.providers.remote.memory.weaviate.get_adapter_impl",
|
||||||
|
autospec=True,
|
||||||
|
) as get_adapter_impl_mock: # noqa N806
|
||||||
|
adapter_mock = create_autospec(MemoryAdapterFake)
|
||||||
|
adapter_mock.query_documents.side_effect = (
|
||||||
|
MethodStubs.QUERY_DOCUMENTS_RETURN_VALUES
|
||||||
|
)
|
||||||
|
get_adapter_impl_mock.return_value = adapter_mock
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def memory_chroma_mocks(request):
|
||||||
|
with patch(
|
||||||
|
"llama_stack.providers.remote.memory.chroma.get_adapter_impl",
|
||||||
|
autospec=True,
|
||||||
|
) as get_adapter_impl_mock: # noqa N806
|
||||||
|
adapter_mock = create_autospec(MemoryAdapterFake)
|
||||||
|
adapter_mock.query_documents.side_effect = (
|
||||||
|
MethodStubs.QUERY_DOCUMENTS_RETURN_VALUES
|
||||||
|
)
|
||||||
|
get_adapter_impl_mock.return_value = adapter_mock
|
||||||
|
yield
|
Loading…
Add table
Add a link
Reference in a new issue