add _get_and_cache_bank_index

This commit is contained in:
Dinesh Yeduguru 2024-11-20 10:21:55 -08:00
parent ac244d18c1
commit 8d5fdeedda
2 changed files with 30 additions and 29 deletions

View file

@ -147,9 +147,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
documents: List[MemoryBankDocument], documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None, ttl_seconds: Optional[int] = None,
) -> None: ) -> None:
index = self.cache.get(bank_id, None) index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents) await index.insert_documents(documents)
@ -159,16 +157,20 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
query: InterleavedTextMedia, query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ) -> QueryDocumentsResponse:
index = self.cache.get(bank_id, None) index = await self._get_and_cache_bank_index(bank_id)
if not index:
# if not in cache, try to get from chroma directly
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
collection = await self.client.get_collection(bank_id)
if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma")
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
self.cache[bank_id] = index
return await index.query_documents(query, params) return await index.query_documents(query, params)
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
collection = await self.client.get_collection(bank_id)
if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma")
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
self.cache[bank_id] = index
return index

View file

@ -201,10 +201,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
documents: List[MemoryBankDocument], documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None, ttl_seconds: Optional[int] = None,
) -> None: ) -> None:
index = self.cache.get(bank_id, None) index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents) await index.insert_documents(documents)
async def query_documents( async def query_documents(
@ -213,15 +210,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
query: InterleavedTextMedia, query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ) -> QueryDocumentsResponse:
index = self.cache.get(bank_id, None) index = await self._get_and_cache_bank_index(bank_id)
if not index:
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return await index.query_documents(query, params) return await index.query_documents(query, params)
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return index