pgvector fix

This commit is contained in:
Dinesh Yeduguru 2024-11-11 10:28:23 -08:00
parent 5cdcdbe074
commit b2b49fbdb9

View file

@ -52,7 +52,7 @@ def load_models(cur, cls):
class PGVectorIndex(EmbeddingIndex): class PGVectorIndex(EmbeddingIndex):
def __init__(self, bank: MemoryBankDef, dimension: int, cursor): def __init__(self, bank: VectorMemoryBank, dimension: int, cursor):
self.cursor = cursor self.cursor = cursor
self.table_name = f"vector_store_{bank.identifier}" self.table_name = f"vector_store_{bank.identifier}"
@ -157,7 +157,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def register_memory_bank( async def register_memory_bank(
self, self,
memory_bank: MemoryBankDef, memory_bank: VectorMemoryBank,
) -> None: ) -> None:
assert ( assert (
memory_bank.type == MemoryBankType.vector.value memory_bank.type == MemoryBankType.vector.value
@ -176,8 +176,8 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
) )
self.cache[memory_bank.identifier] = index self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBankDef]: async def list_memory_banks(self) -> List[VectorMemoryBank]:
banks = load_models(self.cursor, MemoryBankDef) banks = load_models(self.cursor, VectorMemoryBank)
for bank in banks: for bank in banks:
if bank.identifier not in self.cache: if bank.identifier not in self.cache:
index = BankWithIndex( index = BankWithIndex(