Fix pgvector, store source of truth in Chroma

commit fe0dabe596 (parent dd9d34cf7d)
Author: Ashwin Bharambe
Date:   2024-10-10 10:12:14 -07:00
4 changed files with 86 additions and 40 deletions
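
The change shared by both adapters: a memory bank's definition (a pydantic model) is now serialized into the backing store itself, as Chroma collection metadata on one side and a Postgres metadata_store table on the other, then parsed back with parse_obj_as when banks are listed. A minimal sketch of that round-trip, using a stand-in model since the real MemoryBankDef lives in llama_stack.apis.memory:

import json

from pydantic import BaseModel, parse_obj_as


class MemoryBankDef(BaseModel):
    # stand-in for llama_stack.apis.memory.MemoryBankDef; the real model has more fields
    identifier: str
    type: str


bank = MemoryBankDef(identifier="docs", type="vector")

stored = bank.json()  # what register_memory_bank now persists
restored = parse_obj_as(MemoryBankDef, json.loads(stored))  # what list_memory_banks recovers
assert restored == bank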

View file (Chroma memory adapter)

@@ -11,8 +11,11 @@ from urllib.parse import urlparse
 
 import chromadb
 from numpy.typing import NDArray
+from pydantic import parse_obj_as
 
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
@@ -63,7 +66,7 @@ class ChromaIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class ChromaMemoryAdapter(Memory):
+class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, url: str) -> None:
         print(f"Initializing ChromaMemoryAdapter with url: {url}")
         url = url.rstrip("/")
@@ -101,31 +104,33 @@ class ChromaMemoryAdapter(Memory):
         collection = await self.client.get_or_create_collection(
             name=memory_bank.identifier,
+            metadata={"bank": memory_bank.json()},
         )
         bank_index = BankWithIndex(
             bank=memory_bank, index=ChromaIndex(self.client, collection)
         )
         self.cache[memory_bank.identifier] = bank_index
 
-    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
-        if bank_id in self.cache:
-            return self.cache[bank_id]
-
-        bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        if bank is None:
-            raise ValueError(f"Bank {bank_id} not found")
-
+    async def list_memory_banks(self) -> List[MemoryBankDef]:
         collections = await self.client.list_collections()
         for collection in collections:
-            if collection.name == bank_id:
-                index = BankWithIndex(
-                    bank=bank,
-                    index=ChromaIndex(self.client, collection),
-                )
-                self.cache[bank_id] = index
-                return index
-
-        return None
+            try:
+                data = json.loads(collection.metadata["bank"])
+                bank = parse_obj_as(MemoryBankDef, data)
+            except Exception:
+                import traceback
+
+                traceback.print_exc()
+                print(f"Failed to parse bank: {collection.metadata}")
+                continue
+
+            index = BankWithIndex(
+                bank=bank,
+                index=ChromaIndex(self.client, collection),
+            )
+            self.cache[bank.identifier] = index
+
+        return [i.bank for i in self.cache.values()]
 
     async def insert_documents(
         self,
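
This works because Chroma persists collection metadata server-side, so the serialized bank survives a Stack restart. A hedged sketch of the same round-trip against the synchronous chromadb client of the era this commit targets (the adapter itself talks to an async HTTP client, and the metadata JSON here is a hand-written stand-in; newer chromadb releases changed list_collections to return names only):

import json

import chromadb

client = chromadb.Client()  # in-memory client for illustration only

# register: stash the serialized bank definition next to the vectors
client.get_or_create_collection(
    name="docs",
    metadata={"bank": json.dumps({"identifier": "docs", "type": "vector"})},
)

# list: every collection now carries enough metadata to rebuild its bank
for collection in client.list_collections():
    data = json.loads(collection.metadata["bank"])
    print(collection.name, "->", data["identifier"])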
@@ -133,7 +138,7 @@ class ChromaMemoryAdapter(Memory):
         documents: List[MemoryBankDocument],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        index = await self._get_and_cache_bank_index(bank_id)
+        index = self.cache.get(bank_id, None)
         if not index:
             raise ValueError(f"Bank {bank_id} not found")
@@ -145,7 +150,7 @@ class ChromaMemoryAdapter(Memory):
         query: InterleavedTextMedia,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
-        index = await self._get_and_cache_bank_index(bank_id)
+        index = self.cache.get(bank_id, None)
         if not index:
             raise ValueError(f"Bank {bank_id} not found")

View file (PGVector memory adapter)

@@ -4,15 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List
+from typing import List, Tuple
 
 import psycopg2
 from numpy.typing import NDArray
 from psycopg2 import sql
 from psycopg2.extras import execute_values, Json
+from pydantic import BaseModel, parse_obj_as
 
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+
 from llama_stack.providers.utils.memory.vector_store import (
     ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
@@ -28,10 +31,31 @@ def check_extension_version(cur):
     return result[0] if result else None
 
 
+def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
+    query = sql.SQL(
+        """
+        INSERT INTO metadata_store (key, data)
+        VALUES %s
+        ON CONFLICT (key) DO UPDATE
+        SET data = EXCLUDED.data
+        """
+    )
+
+    values = [(key, Json(model.dict())) for key, model in keys_models]
+    execute_values(cur, query, values, template="(%s, %s)")
+
+
+def load_models(cur, cls):
+    query = "SELECT key, data FROM metadata_store"
+    cur.execute(query)
+    rows = cur.fetchall()
+    return [parse_obj_as(cls, row["data"]) for row in rows]
+
+
 class PGVectorIndex(EmbeddingIndex):
-    def __init__(self, bank: MemoryBank, dimension: int, cursor):
+    def __init__(self, bank: MemoryBankDef, dimension: int, cursor):
         self.cursor = cursor
-        self.table_name = f"vector_store_{bank.name}"
+        self.table_name = f"vector_store_{bank.identifier}"
 
         self.cursor.execute(
             f"""
@@ -88,7 +112,7 @@ class PGVectorIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class PGVectorMemoryAdapter(Memory):
+class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, config: PGVectorConfig) -> None:
         print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
         self.config = config
@@ -113,6 +137,14 @@ class PGVectorMemoryAdapter(Memory):
             else:
                 raise RuntimeError("Vector extension is not installed.")
 
+            self.cursor.execute(
+                """
+                CREATE TABLE IF NOT EXISTS metadata_store (
+                    key TEXT PRIMARY KEY,
+                    data JSONB
+                )
+                """
+            )
         except Exception as e:
             import traceback
@@ -130,26 +162,29 @@ class PGVectorMemoryAdapter(Memory):
             memory_bank.type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.type}"
 
+        upsert_models(
+            self.cursor,
+            [
+                (memory_bank.identifier, memory_bank),
+            ],
+        )
+
         index = BankWithIndex(
             bank=memory_bank,
             index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
         )
-        self.cache[bank_id] = index
+        self.cache[memory_bank.identifier] = index
 
-    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
-        if bank_id in self.cache:
-            return self.cache[bank_id]
-
-        bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        if not bank:
-            raise ValueError(f"Bank {bank_id} not found")
-
-        index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[bank_id] = index
-        return index
+    async def list_memory_banks(self) -> List[MemoryBankDef]:
+        banks = load_models(self.cursor, MemoryBankDef)
+        for bank in banks:
+            if bank.identifier not in self.cache:
+                index = BankWithIndex(
+                    bank=bank,
+                    index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
+                )
+                self.cache[bank.identifier] = index
+        return banks
 
     async def insert_documents(
         self,
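
The payoff is that a restarted Stack can rebuild its bank cache from Postgres on the first list call instead of losing the banks. A hedged usage sketch; the initialize() hook and PGVectorConfig wiring are assumed from the surrounding provider code rather than shown in this diff:

# after a process restart the in-memory cache is empty, but the banks survive in Postgres
adapter = PGVectorMemoryAdapter(config)  # config: a PGVectorConfig for a running server
await adapter.initialize()               # assumed init hook; connects and creates tables

banks = await adapter.list_memory_banks()  # load_models() refills adapter.cache
for bank in banks:
    print(bank.identifier)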
@@ -157,7 +192,7 @@ class PGVectorMemoryAdapter(Memory):
         documents: List[MemoryBankDocument],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        index = await self._get_and_cache_bank_index(bank_id)
+        index = self.cache.get(bank_id, None)
         if not index:
             raise ValueError(f"Bank {bank_id} not found")
@@ -169,7 +204,7 @@ class PGVectorMemoryAdapter(Memory):
         query: InterleavedTextMedia,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
-        index = await self._get_and_cache_bank_index(bank_id)
+        index = self.cache.get(bank_id, None)
         if not index:
             raise ValueError(f"Bank {bank_id} not found")
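
One behavioral consequence of swapping _get_and_cache_bank_index() for a plain cache lookup, in both adapters: a bank must have been registered (or restored by list_memory_banks()) in the current process before documents can be inserted or queried. A short sketch of the new failure mode, with a made-up bank id:

# the lazy fetch-from-bank-store path is gone, so unknown ids now fail fast
try:
    await adapter.query_documents("never_registered", "what changed in this commit?")
except ValueError as e:
    print(e)  # Bank never_registered not found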

View file (Weaviate memory adapter)

@@ -140,6 +140,10 @@ class WeaviateMemoryAdapter(
         self.cache[memory_bank.identifier] = index
 
     async def list_memory_banks(self) -> List[MemoryBankDef]:
+        # TODO: right now the Llama Stack is the source of truth for these banks. That is
+        # not ideal. It should be Weaviate which is the source of truth. Unfortunately,
+        # list() happens at Stack startup when the Weaviate client (credentials) is not
+        # yet available. We need to figure out a way to make this work.
         return [i.bank for i in self.cache.values()]
 
     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
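
The TODO above is a bootstrapping problem: list_memory_banks() runs at Stack startup, before Weaviate credentials exist, so Weaviate cannot be the source of truth yet. One generic way out, sketched here with hypothetical names rather than anything from this commit, is to defer client construction to first use:

class LazyClientBanks:
    """Hypothetical sketch (not this commit): construct the remote client on
    first use, so listing can run even though credentials arrive after startup."""

    def __init__(self, get_credentials, connect):
        self._get_credentials = get_credentials  # callable; only invoked once creds exist
        self._connect = connect                  # e.g. wraps weaviate.Client(...)
        self._client = None

    def _ensure_client(self):
        if self._client is None:
            self._client = self._connect(self._get_credentials())
        return self._client

    async def list_memory_banks(self):
        # with a lazy client, the backing store can become the source of truth
        return self._ensure_client().list_banks()

Here list_banks() stands in for whatever read-back query the real adapter would issue against Weaviate.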

View file (memory provider tests)

@@ -81,6 +81,8 @@ async def register_memory_bank(banks_impl: MemoryBanks):
 
 @pytest.mark.asyncio
 async def test_banks_list(memory_settings):
+    # NOTE: this needs you to ensure that you are starting from a clean state
+    # but so far we don't have an unregister API unfortunately, so be careful
     banks_impl = memory_settings["memory_banks_impl"]
     response = await banks_impl.list_memory_banks()
     assert isinstance(response, list)
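
Since there is no unregister API yet, a hedged variant of this test could make the clean-state requirement explicit by skipping instead of failing when earlier runs left banks behind. Hypothetical, not part of this commit; it reuses the file's existing pytest import and memory_settings fixture:

@pytest.mark.asyncio
async def test_banks_list_clean_state(memory_settings):
    # hypothetical guard: skip when leftovers from earlier runs are present,
    # since there is no unregister API to reset the store
    banks_impl = memory_settings["memory_banks_impl"]
    response = await banks_impl.list_memory_banks()
    if len(response) > 1:
        pytest.skip(f"expected one registered bank, found {len(response)}")
    assert isinstance(response, list)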