diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index 0d188d944..720c7083d 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -52,7 +52,7 @@ def load_models(cur, cls): class PGVectorIndex(EmbeddingIndex): - def __init__(self, bank: MemoryBankDef, dimension: int, cursor): + def __init__(self, bank: VectorMemoryBank, dimension: int, cursor): self.cursor = cursor self.table_name = f"vector_store_{bank.identifier}" @@ -157,7 +157,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): async def register_memory_bank( self, - memory_bank: MemoryBankDef, + memory_bank: VectorMemoryBank, ) -> None: assert ( memory_bank.type == MemoryBankType.vector.value @@ -176,8 +176,8 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): ) self.cache[memory_bank.identifier] = index - async def list_memory_banks(self) -> List[MemoryBankDef]: - banks = load_models(self.cursor, MemoryBankDef) + async def list_memory_banks(self) -> List[VectorMemoryBank]: + banks = load_models(self.cursor, VectorMemoryBank) for bank in banks: if bank.identifier not in self.cache: index = BankWithIndex(