remove mixin and test fixes

This commit is contained in:
Dinesh Yeduguru 2024-12-09 15:00:12 -08:00
parent 5bbeb985ca
commit 0e451525e5
9 changed files with 140 additions and 69 deletions

View file

@ -21,7 +21,6 @@ from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
InferenceEmbeddingMixin,
)
from .config import PGVectorConfig
@ -120,9 +119,7 @@ class PGVectorIndex(EmbeddingIndex):
self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
class PGVectorMemoryAdapter(
InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate
):
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
@ -171,8 +168,8 @@ class PGVectorMemoryAdapter(
upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
self.cache[memory_bank.identifier] = self._create_bank_with_index(
memory_bank, index
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank, index, self.inference_api
)
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
@ -183,9 +180,10 @@ class PGVectorMemoryAdapter(
banks = load_models(self.cursor, VectorMemoryBank)
for bank in banks:
if bank.identifier not in self.cache:
index = self._create_bank_with_index(
index = BankWithIndex(
bank,
PGVectorIndex(bank, bank.embedding_dimension, self.cursor),
self.inference_api,
)
self.cache[bank.identifier] = index
return banks
@ -216,5 +214,5 @@ class PGVectorMemoryAdapter(
bank = await self.memory_bank_store.get_memory_bank(bank_id)
index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
self.cache[bank_id] = self._create_bank_with_index(bank, index)
self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
return self.cache[bank_id]