mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
add _get_and_cache_bank_index
This commit is contained in:
parent
ac244d18c1
commit
8d5fdeedda
2 changed files with 30 additions and 29 deletions
|
@ -147,9 +147,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
ttl_seconds: Optional[int] = None,
|
ttl_seconds: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
index = self.cache.get(bank_id, None)
|
index = await self._get_and_cache_bank_index(bank_id)
|
||||||
if not index:
|
|
||||||
raise ValueError(f"Bank {bank_id} not found")
|
|
||||||
|
|
||||||
await index.insert_documents(documents)
|
await index.insert_documents(documents)
|
||||||
|
|
||||||
|
@ -159,16 +157,20 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
query: InterleavedTextMedia,
|
query: InterleavedTextMedia,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> QueryDocumentsResponse:
|
) -> QueryDocumentsResponse:
|
||||||
index = self.cache.get(bank_id, None)
|
index = await self._get_and_cache_bank_index(bank_id)
|
||||||
if not index:
|
|
||||||
# if not in cache, try to get from chroma directly
|
|
||||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
|
||||||
if not bank:
|
|
||||||
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
|
|
||||||
collection = await self.client.get_collection(bank_id)
|
|
||||||
if not collection:
|
|
||||||
raise ValueError(f"Bank {bank_id} not found in Chroma")
|
|
||||||
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
|
|
||||||
self.cache[bank_id] = index
|
|
||||||
|
|
||||||
return await index.query_documents(query, params)
|
return await index.query_documents(query, params)
|
||||||
|
|
||||||
|
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
|
||||||
|
if bank_id in self.cache:
|
||||||
|
return self.cache[bank_id]
|
||||||
|
|
||||||
|
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||||
|
if not bank:
|
||||||
|
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
|
||||||
|
collection = await self.client.get_collection(bank_id)
|
||||||
|
if not collection:
|
||||||
|
raise ValueError(f"Bank {bank_id} not found in Chroma")
|
||||||
|
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
|
||||||
|
self.cache[bank_id] = index
|
||||||
|
return index
|
||||||
|
|
|
@ -201,10 +201,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
ttl_seconds: Optional[int] = None,
|
ttl_seconds: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
index = self.cache.get(bank_id, None)
|
index = await self._get_and_cache_bank_index(bank_id)
|
||||||
if not index:
|
|
||||||
raise ValueError(f"Bank {bank_id} not found")
|
|
||||||
|
|
||||||
await index.insert_documents(documents)
|
await index.insert_documents(documents)
|
||||||
|
|
||||||
async def query_documents(
|
async def query_documents(
|
||||||
|
@ -213,15 +210,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
query: InterleavedTextMedia,
|
query: InterleavedTextMedia,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> QueryDocumentsResponse:
|
) -> QueryDocumentsResponse:
|
||||||
index = self.cache.get(bank_id, None)
|
index = await self._get_and_cache_bank_index(bank_id)
|
||||||
if not index:
|
|
||||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
|
||||||
if not bank:
|
|
||||||
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
|
|
||||||
index = BankWithIndex(
|
|
||||||
bank=bank,
|
|
||||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
|
||||||
)
|
|
||||||
self.cache[bank_id] = index
|
|
||||||
|
|
||||||
return await index.query_documents(query, params)
|
return await index.query_documents(query, params)
|
||||||
|
|
||||||
|
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
|
||||||
|
if bank_id in self.cache:
|
||||||
|
return self.cache[bank_id]
|
||||||
|
|
||||||
|
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||||
|
index = BankWithIndex(
|
||||||
|
bank=bank,
|
||||||
|
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||||
|
)
|
||||||
|
self.cache[bank_id] = index
|
||||||
|
return index
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue