pgvector fix

This commit is contained in:
Dinesh Yeduguru 2024-11-11 10:28:23 -08:00
parent 5cdcdbe074
commit b2b49fbdb9

View file

@ -52,7 +52,7 @@ def load_models(cur, cls):
class PGVectorIndex(EmbeddingIndex):
def __init__(self, bank: MemoryBankDef, dimension: int, cursor):
def __init__(self, bank: VectorMemoryBank, dimension: int, cursor):
self.cursor = cursor
self.table_name = f"vector_store_{bank.identifier}"
@ -157,7 +157,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def register_memory_bank(
self,
memory_bank: MemoryBankDef,
memory_bank: VectorMemoryBank,
) -> None:
assert (
memory_bank.type == MemoryBankType.vector.value
@ -176,8 +176,8 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
)
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBankDef]:
banks = load_models(self.cursor, MemoryBankDef)
async def list_memory_banks(self) -> List[VectorMemoryBank]:
banks = load_models(self.cursor, VectorMemoryBank)
for bank in banks:
if bank.identifier not in self.cache:
index = BankWithIndex(