From 91e7efbc91c729d74c5cf9b3947d3e8acc1fbb71 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 20 Nov 2024 10:30:23 -0800 Subject: [PATCH] fall to back to read from chroma/pgvector when not in cache (#489) # What does this PR do? The chroma provider maintains a cache but does not sync up with chroma on a cold start. this change adds a fallback to read from chroma on a cache miss. ## Test Plan ```bash #start stack llama stack run /Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml # Add documents PYTHONPATH=. python -m examples.agents.rag_with_memory_bank localhost 5000 No available shields. Disable safety. Using model: Llama3.1-8B-Instruct Created session_id=b951b14f-a9d2-43a3-8b80-d80114d58322 for Agent(0687a251-6906-4081-8d4c-f52e19db9dd7) memory_retrieval> Retrieved context from banks: ['test_bank']. ==== Here are the retrieved documents for relevant context: === START-RETRIEVED-CONTEXT === id:num-1; content:_ the template from Llama2 to better support multiturn conversations. The same text in the Lla... > inference> Based on the retrieved documentation, the top 5 topics that were explained are: ............... # Kill stack # Bootup stack llama stack run /Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml # Run a RAG app with just the agent flow. it discovers the previously added documents No available shields. Disable safety. Using model: Llama3.1-8B-Instruct Created session_id=7a30c1a7-c87e-4787-936c-d0306589fe5d for Agent(b30420f3-c928-498a-887b-d084f0f3806c) memory_retrieval> Retrieved context from banks: ['test_bank']. ==== Here are the retrieved documents for relevant context: === START-RETRIEVED-CONTEXT === id:num-1; content:_ the template from Llama2 to better support multiturn conversations. The same text in the Lla... > inference> Based on the provided documentation, the top 5 topics that were explained are: ..... ``` --- .../providers/remote/memory/chroma/chroma.py | 22 ++++++++++++++----- .../remote/memory/pgvector/pgvector.py | 22 ++++++++++++------- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index ac00fc749..3ccd6a534 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -147,9 +147,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): documents: List[MemoryBankDocument], ttl_seconds: Optional[int] = None, ) -> None: - index = self.cache.get(bank_id, None) - if not index: - raise ValueError(f"Bank {bank_id} not found") + index = await self._get_and_cache_bank_index(bank_id) await index.insert_documents(documents) @@ -159,8 +157,20 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): query: InterleavedTextMedia, params: Optional[Dict[str, Any]] = None, ) -> QueryDocumentsResponse: - index = self.cache.get(bank_id, None) - if not index: - raise ValueError(f"Bank {bank_id} not found") + index = await self._get_and_cache_bank_index(bank_id) return await index.query_documents(query, params) + + async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex: + if bank_id in self.cache: + return self.cache[bank_id] + + bank = await self.memory_bank_store.get_memory_bank(bank_id) + if not bank: + raise ValueError(f"Bank {bank_id} not found in Llama Stack") + collection = await self.client.get_collection(bank_id) + if not collection: + raise ValueError(f"Bank {bank_id} not found in Chroma") + index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection)) + self.cache[bank_id] = index + return index diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index 44c2a8fe1..bd27509d6 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -201,10 +201,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): documents: List[MemoryBankDocument], ttl_seconds: Optional[int] = None, ) -> None: - index = self.cache.get(bank_id, None) - if not index: - raise ValueError(f"Bank {bank_id} not found") - + index = await self._get_and_cache_bank_index(bank_id) await index.insert_documents(documents) async def query_documents( @@ -213,8 +210,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): query: InterleavedTextMedia, params: Optional[Dict[str, Any]] = None, ) -> QueryDocumentsResponse: - index = self.cache.get(bank_id, None) - if not index: - raise ValueError(f"Bank {bank_id} not found") - + index = await self._get_and_cache_bank_index(bank_id) return await index.query_documents(query, params) + + async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex: + if bank_id in self.cache: + return self.cache[bank_id] + + bank = await self.memory_bank_store.get_memory_bank(bank_id) + index = BankWithIndex( + bank=bank, + index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), + ) + self.cache[bank_id] = index + return index