diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index ac00fc749..3ccd6a534 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -147,9 +147,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): documents: List[MemoryBankDocument], ttl_seconds: Optional[int] = None, ) -> None: - index = self.cache.get(bank_id, None) - if not index: - raise ValueError(f"Bank {bank_id} not found") + index = await self._get_and_cache_bank_index(bank_id) await index.insert_documents(documents) @@ -159,8 +157,20 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): query: InterleavedTextMedia, params: Optional[Dict[str, Any]] = None, ) -> QueryDocumentsResponse: - index = self.cache.get(bank_id, None) - if not index: - raise ValueError(f"Bank {bank_id} not found") + index = await self._get_and_cache_bank_index(bank_id) return await index.query_documents(query, params) + + async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex: + if bank_id in self.cache: + return self.cache[bank_id] + + bank = await self.memory_bank_store.get_memory_bank(bank_id) + if not bank: + raise ValueError(f"Bank {bank_id} not found in Llama Stack") + collection = await self.client.get_collection(bank_id) + if not collection: + raise ValueError(f"Bank {bank_id} not found in Chroma") + index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection)) + self.cache[bank_id] = index + return index diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index 44c2a8fe1..bd27509d6 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -201,10 +201,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): documents: List[MemoryBankDocument], ttl_seconds: Optional[int] = None, ) -> None: - index = self.cache.get(bank_id, None) - if not index: - raise ValueError(f"Bank {bank_id} not found") - + index = await self._get_and_cache_bank_index(bank_id) await index.insert_documents(documents) async def query_documents( @@ -213,8 +210,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): query: InterleavedTextMedia, params: Optional[Dict[str, Any]] = None, ) -> QueryDocumentsResponse: - index = self.cache.get(bank_id, None) - if not index: - raise ValueError(f"Bank {bank_id} not found") - + index = await self._get_and_cache_bank_index(bank_id) return await index.query_documents(query, params) + + async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex: + if bank_id in self.cache: + return self.cache[bank_id] + + bank = await self.memory_bank_store.get_memory_bank(bank_id) + index = BankWithIndex( + bank=bank, + index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), + ) + self.cache[bank_id] = index + return index