mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 22:09:27 +00:00
Summary: Part of https://github.com/meta-llama/llama-stack/issues/436. This change adds minimal support for creating a memory provider test fake. For more details about the approach, see the issue linked above. Test Plan: Run the tests selected by the "test_fake" marker: ``` pytest llama_stack/providers/tests/memory/test_memory.py -m "test_fake" /opt/homebrew/Caskroom/miniconda/base/envs/llama-stack/lib/python3.11/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ====================================================================================================== test session starts ====================================================================================================== platform darwin -- Python 3.11.10, pytest-8.3.3, pluggy-1.5.0 rootdir: /llama-stack configfile: pyproject.toml plugins: asyncio-0.24.0, anyio-4.6.2.post1 asyncio: mode=Mode.STRICT, default_loop_scope=None collected 18 items / 15 deselected / 3 selected llama_stack/providers/tests/memory/test_memory.py ... [100%] ========================================================================================= 3 passed, 15 deselected, 6 warnings in 0.03s ========================================================================================== ```
92 lines
3.1 KiB
Python
92 lines
3.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from llama_stack.apis.memory.memory import (
|
|
Memory,
|
|
MemoryBankDocument,
|
|
MemoryBankStore,
|
|
QueryDocumentsResponse,
|
|
)
|
|
from llama_stack.apis.memory_banks.memory_banks import KeyValueMemoryBank
|
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
|
|
|
|
|
# MemoryBanks test fake implementation to
# support behaviors tested in test_memory.py
class MemoryBanksTestFakeImpl(MemoryBanks):
    """Dict-backed fake of the ``MemoryBanks`` API used by test_memory.py.

    Banks live only in this object's ``memory_banks`` mapping; nothing is
    persisted and no real provider is contacted.
    """

    def __init__(self):
        # memory_bank_id -> the MemoryBank registered under that id.
        self.memory_banks: Dict[str, MemoryBank] = {}

    async def list_memory_banks(self) -> List[MemoryBank]:
        """Return a new list holding every currently registered bank."""
        registered = self.memory_banks.values()
        return list(registered)

    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
        """Return the bank registered as ``memory_bank_id``, or ``None`` if absent."""
        return self.memory_banks.get(memory_bank_id)

    async def register_memory_bank(
        self,
        memory_bank_id: str,
        params: BankParams,
        provider_id: Optional[str] = None,
        provider_memory_bank_id: Optional[str] = None,
    ) -> MemoryBank:
        """Create, store and return a bank under ``memory_bank_id``.

        ``params``, ``provider_id`` and ``provider_memory_bank_id`` are accepted
        for interface compatibility but ignored: the fake always builds a
        ``KeyValueMemoryBank`` owned by the ``"test::test-fake"`` provider.
        Registering an existing id replaces the previous bank.
        """
        new_bank = KeyValueMemoryBank(
            identifier=memory_bank_id,
            provider_id="test::test-fake",
        )
        self.memory_banks[memory_bank_id] = new_bank
        return new_bank

    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
        """Forget ``memory_bank_id``; unknown ids are silently ignored."""
        self.memory_banks.pop(memory_bank_id, None)
|
|
|
|
|
|
# Memory test fake implementation to
# support behaviors tested in test_memory.py
class MemoryTestFakeImpl(Memory):
    """Stub-driven fake of the ``Memory`` API used by test_memory.py.

    Documents given to ``insert_documents`` are discarded. Instead, tests
    register canned responses with ``set_stubs("query_documents", {...})`` and
    ``query_documents`` returns the response stored under the exact query.
    Bank ids are validated against the ``MemoryBanks`` object supplied through
    ``set_memory_banks``.
    """

    memory_bank_store: MemoryBankStore

    def __init__(self):
        # Registry used to validate bank ids; must be injected via set_memory_banks().
        self.memory_banks: Optional[MemoryBanks] = None
        # method name -> {request key: canned response}
        self.stubs: Dict[str, Any] = {}

    def set_memory_banks(self, memory_banks: MemoryBanks) -> None:
        """Attach the bank registry used to validate ``bank_id`` arguments."""
        self.memory_banks = memory_banks

    def set_stubs(self, method: str, stubs: Dict[str, Any]):
        """Register canned responses for ``method``, replacing earlier ones.

        For ``"query_documents"`` the mapping's keys are queries and its values
        are the ``QueryDocumentsResponse`` objects to return for them.
        """
        self.stubs[method] = stubs

    async def _ensure_bank_exists(self, bank_id: str) -> None:
        """Raise unless ``bank_id`` is known to the attached bank registry.

        Raises:
            RuntimeError: ``set_memory_banks`` has not been called.
            ValueError: the registry has no bank with this id.
        """
        if self.memory_banks is None:
            raise RuntimeError(
                "MemoryTestFakeImpl has no memory banks; call set_memory_banks() first."
            )
        if not await self.memory_banks.get_memory_bank(bank_id):
            raise ValueError(f"Bank {bank_id} not found")

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        """Validate ``bank_id`` and otherwise do nothing.

        The documents are intentionally ignored: this fake answers queries from
        stubs registered via ``set_stubs`` rather than from stored content.

        Raises:
            RuntimeError: no bank registry has been attached.
            ValueError: ``bank_id`` is not registered.
        """
        await self._ensure_bank_exists(bank_id)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        """Return the stubbed response registered for ``query``.

        ``params`` is accepted for interface compatibility and ignored.

        Raises:
            RuntimeError: no bank registry has been attached.
            ValueError: ``bank_id`` is not registered, or no stub exists for
                ``query`` (including when no ``query_documents`` stubs were set
                at all, or the query is unhashable and so cannot be a stub key).
        """
        await self._ensure_bank_exists(bank_id)
        # Missing stub table is treated like a missing stub, not a bare KeyError.
        query_stubs = self.stubs.get("query_documents", {})
        try:
            return query_stubs[query]
        except (KeyError, TypeError):
            # KeyError: no stub for this query. TypeError: the query is an
            # unhashable value (e.g. a list of media items) that can never be
            # a dict key, so there cannot be a matching stub either.
            raise ValueError(
                f"Stub not created for query {query}, please check your test setup."
            ) from None
|