fall to back to read from chroma/pgvector when not in cache (#489)

# What does this PR do?

The chroma provider maintains a cache but does not sync up with chroma
on a cold start. this change adds a fallback to read from chroma on a
cache miss.


## Test Plan
```bash
#start stack
llama stack run /Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml
# Add documents
PYTHONPATH=. python -m examples.agents.rag_with_memory_bank localhost 5000

No available shields. Disable safety.
Using model: Llama3.1-8B-Instruct
Created session_id=b951b14f-a9d2-43a3-8b80-d80114d58322 for Agent(0687a251-6906-4081-8d4c-f52e19db9dd7)
memory_retrieval> Retrieved context from banks: ['test_bank'].
====
Here are the retrieved documents for relevant context:
=== START-RETRIEVED-CONTEXT ===
 id:num-1; content:_
the template from Llama2 to better support multiturn conversations. The same text
in the Lla...
>
inference> Based on the retrieved documentation, the top 5 topics that were explained are:
...............

# Kill stack
# Bootup stack
llama stack run /Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml
# Run a RAG app with just the agent flow. it discovers the previously added documents
No available shields. Disable safety.
Using model: Llama3.1-8B-Instruct
Created session_id=7a30c1a7-c87e-4787-936c-d0306589fe5d for Agent(b30420f3-c928-498a-887b-d084f0f3806c)
memory_retrieval> Retrieved context from banks: ['test_bank'].
====
Here are the retrieved documents for relevant context:
=== START-RETRIEVED-CONTEXT ===
 id:num-1; content:_
the template from Llama2 to better support multiturn conversations. The same text
in the Lla...
>
inference> Based on the provided documentation, the top 5 topics that were explained are:
.....
```
This commit is contained in:
Dinesh Yeduguru 2024-11-20 10:30:23 -08:00 committed by GitHub
parent ae49a4cb97
commit 91e7efbc91
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 30 additions and 14 deletions

View file

@ -147,9 +147,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
await index.insert_documents(documents)
@ -159,8 +157,20 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params)
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
collection = await self.client.get_collection(bank_id)
if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma")
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
self.cache[bank_id] = index
return index

View file

@ -201,10 +201,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
await index.insert_documents(documents)
async def query_documents(
@ -213,8 +210,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params)
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return index