migrate memory banks to Resource and new registration (#411)

* migrate memory banks to Resource and new registration

* address feedback

* address feedback

* fix tests

* pgvector fix

* pgvector fix v2

* remove auto discovery

* change register signature to make params required

* update client

* client fix

* use annotated union to parse

* remove base MemoryBank inheritence

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-11 17:10:44 -08:00 committed by GitHub
parent 6b9850e11b
commit 38cce97597
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 240 additions and 129 deletions

View file

@ -83,7 +83,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
stored_banks = await self.kvstore.range(start_key, end_key)
for bank_data in stored_banks:
bank = VectorMemoryBankDef.model_validate_json(bank_data)
bank = VectorMemoryBank.model_validate_json(bank_data)
index = BankWithIndex(
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
)
@ -95,10 +95,10 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
async def register_memory_bank(
self,
memory_bank: MemoryBankDef,
memory_bank: MemoryBank,
) -> None:
assert (
memory_bank.type == MemoryBankType.vector.value
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
# Store in kvstore
@ -114,7 +114,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
)
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBankDef]:
async def list_memory_banks(self) -> List[MemoryBank]:
return [i.bank for i in self.cache.values()]
async def insert_documents(