Fix pgvector, store source of truth in Chroma

Ashwin Bharambe 2024-10-10 10:12:14 -07:00
parent dd9d34cf7d
commit fe0dabe596
4 changed files with 86 additions and 40 deletions
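
Summary of the change, as reflected in the hunks below: memory bank definitions are now persisted in the providers' own backing stores rather than only in the Llama Stack process. Chroma keeps each bank's JSON in the collection metadata, and pgvector gets a new metadata_store table written via upsert_models() and read back via load_models(). list_memory_banks() rehydrates the in-process cache from the backend, insert_documents() and query_documents() consult only that cache, and the old _get_and_cache_bank_index() helper is removed. Both adapters now also implement MemoryBanksProtocolPrivate.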


@@ -11,8 +11,11 @@ from urllib.parse import urlparse
 import chromadb
 from numpy.typing import NDArray
+from pydantic import parse_obj_as

 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
@@ -63,7 +66,7 @@ class ChromaIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)


-class ChromaMemoryAdapter(Memory):
+class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, url: str) -> None:
         print(f"Initializing ChromaMemoryAdapter with url: {url}")
         url = url.rstrip("/")
@@ -101,31 +104,33 @@ class ChromaMemoryAdapter(Memory):
         collection = await self.client.get_or_create_collection(
             name=memory_bank.identifier,
             metadata={"bank": memory_bank.json()},
         )
         bank_index = BankWithIndex(
             bank=memory_bank, index=ChromaIndex(self.client, collection)
         )
         self.cache[memory_bank.identifier] = bank_index

-    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
-        if bank_id in self.cache:
-            return self.cache[bank_id]
-
-        bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        if bank is None:
-            raise ValueError(f"Bank {bank_id} not found")
-
+    async def list_memory_banks(self) -> List[MemoryBankDef]:
         collections = await self.client.list_collections()
         for collection in collections:
-            if collection.name == bank_id:
-                index = BankWithIndex(
-                    bank=bank,
-                    index=ChromaIndex(self.client, collection),
-                )
-                self.cache[bank_id] = index
-                return index
-
-        return None
+            try:
+                data = json.loads(collection.metadata["bank"])
+                bank = parse_obj_as(MemoryBankDef, data)
+            except Exception:
+                import traceback
+
+                traceback.print_exc()
+                print(f"Failed to parse bank: {collection.metadata}")
+                continue
+
+            index = BankWithIndex(
+                bank=bank,
+                index=ChromaIndex(self.client, collection),
+            )
+            self.cache[bank.identifier] = index
+
+        return [i.bank for i in self.cache.values()]

     async def insert_documents(
         self,
@@ -133,7 +138,7 @@ class ChromaMemoryAdapter(Memory):
         documents: List[MemoryBankDocument],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        index = await self._get_and_cache_bank_index(bank_id)
+        index = self.cache.get(bank_id, None)
         if not index:
             raise ValueError(f"Bank {bank_id} not found")
@@ -145,7 +150,7 @@ class ChromaMemoryAdapter(Memory):
         query: InterleavedTextMedia,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
-        index = await self._get_and_cache_bank_index(bank_id)
+        index = self.cache.get(bank_id, None)
         if not index:
             raise ValueError(f"Bank {bank_id} not found")

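For reference, a minimal sketch of the round-trip the Chroma adapter now relies on (not part of the commit itself): the bank definition is serialized into the collection's metadata at registration time and parsed back out when banks are listed. BankDef is a hypothetical stand-in for MemoryBankDef.

import json

from pydantic import BaseModel, parse_obj_as


class BankDef(BaseModel):
    # hypothetical stand-in for MemoryBankDef, trimmed to two fields
    identifier: str
    type: str = "vector"


bank = BankDef(identifier="docs")

# register_memory_bank: the full definition rides along as collection metadata
metadata = {"bank": bank.json()}

# list_memory_banks: recover the typed definition from the stored JSON
restored = parse_obj_as(BankDef, json.loads(metadata["bank"]))
assert restored == bank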

@@ -4,15 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List
+from typing import List, Tuple

 import psycopg2
 from numpy.typing import NDArray
 from psycopg2 import sql
+from psycopg2.extras import execute_values, Json
+from pydantic import BaseModel, parse_obj_as

 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
@@ -28,10 +31,31 @@ def check_extension_version(cur):
     return result[0] if result else None


+def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
+    query = sql.SQL(
+        """
+        INSERT INTO metadata_store (key, data)
+        VALUES %s
+        ON CONFLICT (key) DO UPDATE
+        SET data = EXCLUDED.data
+        """
+    )
+    values = [(key, Json(model.dict())) for key, model in keys_models]
+    execute_values(cur, query, values, template="(%s, %s)")
+
+
+def load_models(cur, cls):
+    query = "SELECT key, data FROM metadata_store"
+    cur.execute(query)
+    rows = cur.fetchall()
+    return [parse_obj_as(cls, row["data"]) for row in rows]
+
+
 class PGVectorIndex(EmbeddingIndex):
-    def __init__(self, bank: MemoryBank, dimension: int, cursor):
+    def __init__(self, bank: MemoryBankDef, dimension: int, cursor):
         self.cursor = cursor
-        self.table_name = f"vector_store_{bank.name}"
+        self.table_name = f"vector_store_{bank.identifier}"
         self.cursor.execute(
             f"""
@@ -88,7 +112,7 @@ class PGVectorIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)


-class PGVectorMemoryAdapter(Memory):
+class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(self, config: PGVectorConfig) -> None:
         print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
         self.config = config
@@ -113,6 +137,14 @@ class PGVectorMemoryAdapter(Memory):
             else:
                 raise RuntimeError("Vector extension is not installed.")

+            self.cursor.execute(
+                """
+                CREATE TABLE IF NOT EXISTS metadata_store (
+                    key TEXT PRIMARY KEY,
+                    data JSONB
+                )
+                """
+            )
         except Exception as e:
             import traceback
@@ -130,26 +162,29 @@ class PGVectorMemoryAdapter(Memory):
             memory_bank.type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.type}"

+        upsert_models(
+            self.cursor,
+            [
+                (memory_bank.identifier, memory_bank),
+            ],
+        )
+
         index = BankWithIndex(
             bank=memory_bank,
             index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
         )
-        self.cache[bank_id] = index
-
-    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
-        if bank_id in self.cache:
-            return self.cache[bank_id]
-
-        bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        if not bank:
-            raise ValueError(f"Bank {bank_id} not found")
-
-        index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[bank_id] = index
-        return index
+        self.cache[memory_bank.identifier] = index
+
+    async def list_memory_banks(self) -> List[MemoryBankDef]:
+        banks = load_models(self.cursor, MemoryBankDef)
+        for bank in banks:
+            if bank.identifier not in self.cache:
+                index = BankWithIndex(
+                    bank=bank,
+                    index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
+                )
+                self.cache[bank.identifier] = index
+        return banks

     async def insert_documents(
         self,
@@ -157,7 +192,7 @@ class PGVectorMemoryAdapter(Memory):
         documents: List[MemoryBankDocument],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        index = await self._get_and_cache_bank_index(bank_id)
+        index = self.cache.get(bank_id, None)
         if not index:
             raise ValueError(f"Bank {bank_id} not found")
@@ -169,7 +204,7 @@ class PGVectorMemoryAdapter(Memory):
         query: InterleavedTextMedia,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
-        index = await self._get_and_cache_bank_index(bank_id)
+        index = self.cache.get(bank_id, None)
         if not index:
             raise ValueError(f"Bank {bank_id} not found")

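A self-contained sketch of the metadata_store round-trip introduced above, assuming a reachable Postgres with psycopg2 installed; the DSN and the BankDef model are placeholders. A dict-style cursor is used so rows support row["data"] the way load_models() expects.

import psycopg2
from psycopg2.extras import DictCursor, Json, execute_values
from pydantic import BaseModel, parse_obj_as


class BankDef(BaseModel):
    # hypothetical stand-in for MemoryBankDef
    identifier: str
    type: str = "vector"


conn = psycopg2.connect(host="localhost", dbname="postgres")  # placeholder DSN
cur = conn.cursor(cursor_factory=DictCursor)
cur.execute(
    "CREATE TABLE IF NOT EXISTS metadata_store (key TEXT PRIMARY KEY, data JSONB)"
)

# upsert_models: one VALUES tuple per (key, model); last write wins per key
execute_values(
    cur,
    "INSERT INTO metadata_store (key, data) VALUES %s"
    " ON CONFLICT (key) DO UPDATE SET data = EXCLUDED.data",
    [("docs", Json(BankDef(identifier="docs").dict()))],
    template="(%s, %s)",
)

# load_models: each JSONB row parses back into the pydantic model
cur.execute("SELECT key, data FROM metadata_store")
banks = [parse_obj_as(BankDef, row["data"]) for row in cur.fetchall()]
assert any(b.identifier == "docs" for b in banks)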

@@ -140,6 +140,10 @@ class WeaviateMemoryAdapter(
         self.cache[memory_bank.identifier] = index

     async def list_memory_banks(self) -> List[MemoryBankDef]:
+        # TODO: right now the Llama Stack is the source of truth for these banks. That is
+        # not ideal. It should be Weaviate which is the source of truth. Unfortunately,
+        # list() happens at Stack startup when the Weaviate client (credentials) is not
+        # yet available. We need to figure out a way to make this work.
         return [i.bank for i in self.cache.values()]

     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:


@@ -81,6 +81,8 @@ async def register_memory_bank(banks_impl: MemoryBanks):

 @pytest.mark.asyncio
 async def test_banks_list(memory_settings):
+    # NOTE: this needs you to ensure that you are starting from a clean state
+    # but so far we don't have an unregister API unfortunately, so be careful
     banks_impl = memory_settings["memory_banks_impl"]
     response = await banks_impl.list_memory_banks()
     assert isinstance(response, list)
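
Until an unregister API exists, one way to keep this test meaningful on a dirty state (a sketch, assuming the register_memory_bank helper defined earlier in this file) is to register a bank first and assert on membership rather than on an exact count:

@pytest.mark.asyncio
async def test_banks_list_after_register(memory_settings):
    banks_impl = memory_settings["memory_banks_impl"]
    await register_memory_bank(banks_impl)
    response = await banks_impl.list_memory_banks()
    # membership-style assertions tolerate banks left over from earlier runs
    assert isinstance(response, list)
    assert len(response) >= 1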