Fix pgvector, store source of truth in Chroma

This commit is contained in:
Ashwin Bharambe 2024-10-10 10:12:14 -07:00
parent dd9d34cf7d
commit fe0dabe596
4 changed files with 86 additions and 40 deletions

View file

@ -11,8 +11,11 @@ from urllib.parse import urlparse
import chromadb
from numpy.typing import NDArray
from pydantic import parse_obj_as
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
@ -63,7 +66,7 @@ class ChromaIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class ChromaMemoryAdapter(Memory):
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, url: str) -> None:
print(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/")
@ -101,31 +104,33 @@ class ChromaMemoryAdapter(Memory):
collection = await self.client.get_or_create_collection(
name=memory_bank.identifier,
metadata={"bank": memory_bank.json()},
)
bank_index = BankWithIndex(
bank=memory_bank, index=ChromaIndex(self.client, collection)
)
self.cache[memory_bank.identifier] = bank_index
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if bank is None:
raise ValueError(f"Bank {bank_id} not found")
async def list_memory_banks(self) -> List[MemoryBankDef]:
collections = await self.client.list_collections()
for collection in collections:
if collection.name == bank_id:
index = BankWithIndex(
bank=bank,
index=ChromaIndex(self.client, collection),
)
self.cache[bank_id] = index
return index
try:
data = json.loads(collection.metadata["bank"])
bank = parse_obj_as(MemoryBankDef, data)
except Exception:
import traceback
return None
traceback.print_exc()
print(f"Failed to parse bank: {collection.metadata}")
continue
index = BankWithIndex(
bank=bank,
index=ChromaIndex(self.client, collection),
)
self.cache[bank.identifier] = index
return [i.bank for i in self.cache.values()]
async def insert_documents(
self,
@ -133,7 +138,7 @@ class ChromaMemoryAdapter(Memory):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
@ -145,7 +150,7 @@ class ChromaMemoryAdapter(Memory):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")