From fe0dabe5960d40bfdd34408a29e2c8a9aa330a43 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 10 Oct 2024 10:12:14 -0700 Subject: [PATCH] Fix pgvector, store source of truth in Chroma --- .../adapters/memory/chroma/chroma.py | 43 ++++++----- .../adapters/memory/pgvector/pgvector.py | 77 ++++++++++++++----- .../adapters/memory/weaviate/weaviate.py | 4 + .../providers/tests/memory/test_memory.py | 2 + 4 files changed, 86 insertions(+), 40 deletions(-) diff --git a/llama_stack/providers/adapters/memory/chroma/chroma.py b/llama_stack/providers/adapters/memory/chroma/chroma.py index f8af9ac5c..954acc09b 100644 --- a/llama_stack/providers/adapters/memory/chroma/chroma.py +++ b/llama_stack/providers/adapters/memory/chroma/chroma.py @@ -11,8 +11,11 @@ from urllib.parse import urlparse import chromadb from numpy.typing import NDArray +from pydantic import parse_obj_as + from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, EmbeddingIndex, @@ -63,7 +66,7 @@ class ChromaIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) -class ChromaMemoryAdapter(Memory): +class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): def __init__(self, url: str) -> None: print(f"Initializing ChromaMemoryAdapter with url: {url}") url = url.rstrip("/") @@ -101,31 +104,33 @@ class ChromaMemoryAdapter(Memory): collection = await self.client.get_or_create_collection( name=memory_bank.identifier, + metadata={"bank": memory_bank.json()}, ) bank_index = BankWithIndex( bank=memory_bank, index=ChromaIndex(self.client, collection) ) self.cache[memory_bank.identifier] = bank_index - async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: - if bank_id in self.cache: - return self.cache[bank_id] - - bank = await self.memory_bank_store.get_memory_bank(bank_id) - if bank is None: - raise ValueError(f"Bank {bank_id} not found") - + async def list_memory_banks(self) -> List[MemoryBankDef]: collections = await self.client.list_collections() for collection in collections: - if collection.name == bank_id: - index = BankWithIndex( - bank=bank, - index=ChromaIndex(self.client, collection), - ) - self.cache[bank_id] = index - return index + try: + data = json.loads(collection.metadata["bank"]) + bank = parse_obj_as(MemoryBankDef, data) + except Exception: + import traceback - return None + traceback.print_exc() + print(f"Failed to parse bank: {collection.metadata}") + continue + + index = BankWithIndex( + bank=bank, + index=ChromaIndex(self.client, collection), + ) + self.cache[bank.identifier] = index + + return [i.bank for i in self.cache.values()] async def insert_documents( self, @@ -133,7 +138,7 @@ class ChromaMemoryAdapter(Memory): documents: List[MemoryBankDocument], ttl_seconds: Optional[int] = None, ) -> None: - index = await self._get_and_cache_bank_index(bank_id) + index = self.cache.get(bank_id, None) if not index: raise ValueError(f"Bank {bank_id} not found") @@ -145,7 +150,7 @@ class ChromaMemoryAdapter(Memory): query: InterleavedTextMedia, params: Optional[Dict[str, Any]] = None, ) -> QueryDocumentsResponse: - index = await self._get_and_cache_bank_index(bank_id) + index = self.cache.get(bank_id, None) if not index: raise ValueError(f"Bank {bank_id} not found") diff --git a/llama_stack/providers/adapters/memory/pgvector/pgvector.py b/llama_stack/providers/adapters/memory/pgvector/pgvector.py index c5dc1f4be..251402b46 100644 --- a/llama_stack/providers/adapters/memory/pgvector/pgvector.py +++ b/llama_stack/providers/adapters/memory/pgvector/pgvector.py @@ -4,15 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List +from typing import List, Tuple import psycopg2 from numpy.typing import NDArray from psycopg2 import sql from psycopg2.extras import execute_values, Json +from pydantic import BaseModel, parse_obj_as + from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( ALL_MINILM_L6_V2_DIMENSION, BankWithIndex, @@ -28,10 +31,31 @@ def check_extension_version(cur): return result[0] if result else None +def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]): + query = sql.SQL( + """ + INSERT INTO metadata_store (key, data) + VALUES %s + ON CONFLICT (key) DO UPDATE + SET data = EXCLUDED.data + """ + ) + + values = [(key, Json(model.dict())) for key, model in keys_models] + execute_values(cur, query, values, template="(%s, %s)") + + +def load_models(cur, cls): + query = "SELECT key, data FROM metadata_store" + cur.execute(query) + rows = cur.fetchall() + return [parse_obj_as(cls, row["data"]) for row in rows] + + class PGVectorIndex(EmbeddingIndex): - def __init__(self, bank: MemoryBank, dimension: int, cursor): + def __init__(self, bank: MemoryBankDef, dimension: int, cursor): self.cursor = cursor - self.table_name = f"vector_store_{bank.name}" + self.table_name = f"vector_store_{bank.identifier}" self.cursor.execute( f""" @@ -88,7 +112,7 @@ class PGVectorIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) -class PGVectorMemoryAdapter(Memory): +class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): def __init__(self, config: PGVectorConfig) -> None: print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}") self.config = config @@ -113,6 +137,14 @@ class PGVectorMemoryAdapter(Memory): else: raise RuntimeError("Vector extension is not installed.") + self.cursor.execute( + """ + CREATE TABLE IF NOT EXISTS metadata_store ( + key TEXT PRIMARY KEY, + data JSONB + ) + """ + ) except Exception as e: import traceback @@ -130,26 +162,29 @@ class PGVectorMemoryAdapter(Memory): memory_bank.type == MemoryBankType.vector.value ), f"Only vector banks are supported {memory_bank.type}" + upsert_models( + self.cursor, + [ + (memory_bank.identifier, memory_bank), + ], + ) + index = BankWithIndex( bank=memory_bank, index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), ) - self.cache[bank_id] = index + self.cache[memory_bank.identifier] = index - async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: - if bank_id in self.cache: - return self.cache[bank_id] - - bank = await self.memory_bank_store.get_memory_bank(bank_id) - if not bank: - raise ValueError(f"Bank {bank_id} not found") - - index = BankWithIndex( - bank=bank, - index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), - ) - self.cache[bank_id] = index - return index + async def list_memory_banks(self) -> List[MemoryBankDef]: + banks = load_models(self.cursor, MemoryBankDef) + for bank in banks: + if bank.identifier not in self.cache: + index = BankWithIndex( + bank=bank, + index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), + ) + self.cache[bank.identifier] = index + return banks async def insert_documents( self, @@ -157,7 +192,7 @@ class PGVectorMemoryAdapter(Memory): documents: List[MemoryBankDocument], ttl_seconds: Optional[int] = None, ) -> None: - index = await self._get_and_cache_bank_index(bank_id) + index = self.cache.get(bank_id, None) if not index: raise ValueError(f"Bank {bank_id} not found") @@ -169,7 +204,7 @@ class PGVectorMemoryAdapter(Memory): query: InterleavedTextMedia, params: Optional[Dict[str, Any]] = None, ) -> QueryDocumentsResponse: - index = await self._get_and_cache_bank_index(bank_id) + index = self.cache.get(bank_id, None) if not index: raise ValueError(f"Bank {bank_id} not found") diff --git a/llama_stack/providers/adapters/memory/weaviate/weaviate.py b/llama_stack/providers/adapters/memory/weaviate/weaviate.py index 1a04527f7..3580b95f8 100644 --- a/llama_stack/providers/adapters/memory/weaviate/weaviate.py +++ b/llama_stack/providers/adapters/memory/weaviate/weaviate.py @@ -140,6 +140,10 @@ class WeaviateMemoryAdapter( self.cache[memory_bank.identifier] = index async def list_memory_banks(self) -> List[MemoryBankDef]: + # TODO: right now the Llama Stack is the source of truth for these banks. That is + # not ideal. It should be Weaviate which is the source of truth. Unfortunately, + # list() happens at Stack startup when the Weaviate client (credentials) is not + # yet available. We need to figure out a way to make this work. return [i.bank for i in self.cache.values()] async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index 2566199ae..c5ebdf9c7 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -81,6 +81,8 @@ async def register_memory_bank(banks_impl: MemoryBanks): @pytest.mark.asyncio async def test_banks_list(memory_settings): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful banks_impl = memory_settings["memory_banks_impl"] response = await banks_impl.list_memory_banks() assert isinstance(response, list)