Push registration methods onto the backing providers

This commit is contained in:
Ashwin Bharambe 2024-10-05 22:17:06 -07:00 committed by Ashwin Bharambe
parent 5a7b01d292
commit 4215cc9331
14 changed files with 269 additions and 220 deletions

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import logging
import uuid
from typing import Any, Dict, List, Optional
@ -72,38 +71,29 @@ class FaissMemoryImpl(Memory, RoutableProvider):
async def shutdown(self) -> None: ...
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
pass
async def create_memory_bank(
async def register_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
assert url is None, "URL is not supported for this implementation"
memory_bank: MemoryBankDef,
) -> None:
assert (
config.type == MemoryBankType.vector.value
), f"Only vector banks are supported {config.type}"
memory_bank.type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
index = BankWithIndex(
bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
)
index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
self.cache[bank_id] = index
self.cache[memory_bank.identifier] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
index = self.cache.get(bank_id)
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
index = self.cache.get(identifier)
if index is None:
return None
return index.bank
async def list_memory_banks(self) -> List[MemoryBankDef]:
return [x.bank for x in self.cache.values()]
async def insert_documents(
self,
bank_id: str,