From 2f51af1bb753054e7ce9570e351936c625f1b865 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 4 Nov 2024 17:17:01 -0800 Subject: [PATCH] init --- .../impls/meta_reference/memory/config.py | 5 +-- .../impls/meta_reference/memory/faiss.py | 36 ++++++++++++++++--- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/memory/config.py b/llama_stack/providers/impls/meta_reference/memory/config.py index b1c94c889..ae005ac63 100644 --- a/llama_stack/providers/impls/meta_reference/memory/config.py +++ b/llama_stack/providers/impls/meta_reference/memory/config.py @@ -5,9 +5,10 @@ # the root directory of this source tree. from llama_models.schema_utils import json_schema_type - +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from pydantic import BaseModel @json_schema_type -class FaissImplConfig(BaseModel): ... +class FaissImplConfig(BaseModel): + kvstore: KVStoreConfig = SqliteKVStoreConfig() # Uses default SQLite config diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py index 02829f7be..afeb7207b 100644 --- a/llama_stack/providers/impls/meta_reference/memory/faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py @@ -25,9 +25,12 @@ from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.telemetry import tracing from .config import FaissImplConfig +from llama_stack.providers.utils.kvstore import kvstore_impl logger = logging.getLogger(__name__) +MEMORY_BANKS_PREFIX = "memory_banks:" + class FaissIndex(EmbeddingIndex): id_by_index: Dict[int, str] @@ -69,10 +72,26 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): def __init__(self, config: FaissImplConfig) -> None: self.config = config self.cache = {} + self.kvstore = None + + async def initialize(self) -> None: + self.kvstore = await kvstore_impl(self.config.kvstore) + # Load existing banks from kvstore + start_key = MEMORY_BANKS_PREFIX + end_key = f"{MEMORY_BANKS_PREFIX}\xff" + stored_banks = await self.kvstore.range(start_key, end_key) + + for bank_data in stored_banks: + bank = VectorMemoryBankDef.model_validate_json(bank_data) + index = BankWithIndex( + bank=bank, + index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) + ) + self.cache[bank.identifier] = index - async def initialize(self) -> None: ... - - async def shutdown(self) -> None: ... + async def shutdown(self) -> None: + # Cleanup if needed + pass async def register_memory_bank( self, @@ -82,8 +101,17 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): memory_bank.type == MemoryBankType.vector.value ), f"Only vector banks are supported {memory_bank.type}" + # Store in kvstore + key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}" + await self.kvstore.set( + key=key, + value=memory_bank.json(), + ) + + # Store in cache index = BankWithIndex( - bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) + bank=memory_bank, + index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) ) self.cache[memory_bank.identifier] = index