Add Chroma and PGVector adapters (#56)

Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
Ashwin Bharambe 2024-09-06 18:53:17 -07:00 committed by GitHub
parent 5de6ed946e
commit 3f090d1975
8 changed files with 628 additions and 119 deletions

llama_toolchain/memory/adapters/chroma/__init__.py View file

@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_toolchain.core.datatypes import RemoteProviderConfig
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
from .chroma import ChromaMemoryAdapter
impl = ChromaMemoryAdapter(config.url)
await impl.initialize()
return impl
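For context, a minimal sketch of how this entry point would be exercised. The URL is illustrative, and constructing RemoteProviderConfig directly with its url field is an assumption here:

# Illustrative caller, not part of the commit itself.
import asyncio

from llama_toolchain.core.datatypes import RemoteProviderConfig
from llama_toolchain.memory.adapters.chroma import get_adapter_impl


async def main() -> None:
    # Assumes a Chroma server listening on localhost:8000.
    config = RemoteProviderConfig(url="http://localhost:8000")
    adapter = await get_adapter_impl(config, None)  # initialize() is awaited inside
    print(type(adapter).__name__)  # ChromaMemoryAdapter


asyncio.run(main())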

llama_toolchain/memory/adapters/chroma/chroma.py View file

@@ -0,0 +1,165 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import uuid
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import chromadb
from numpy.typing import NDArray
from llama_toolchain.memory.api import * # noqa: F403
from llama_toolchain.memory.common.vector_store import BankWithIndex, EmbeddingIndex
class ChromaIndex(EmbeddingIndex):
def __init__(self, client: chromadb.AsyncHttpClient, collection):
self.client = client
self.collection = collection
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
for i, chunk in enumerate(chunks):
print(f"Adding chunk #{i} tokens={chunk.token_count}")
await self.collection.add(
documents=[chunk.json() for chunk in chunks],
embeddings=embeddings,
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
results = await self.collection.query(
query_embeddings=[embedding.tolist()],
n_results=k,
include=["documents", "distances"],
)
distances = results["distances"][0]
documents = results["documents"][0]
chunks = []
scores = []
for dist, doc in zip(distances, documents):
try:
doc = json.loads(doc)
chunk = Chunk(**doc)
except Exception:
import traceback
traceback.print_exc()
print(f"Failed to parse document: {doc}")
continue
chunks.append(chunk)
scores.append(1.0 / float(dist))
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class ChromaMemoryAdapter(Memory):
def __init__(self, url: str) -> None:
print(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/")
parsed = urlparse(url)
if parsed.path and parsed.path != "/":
raise ValueError("URL should not contain a path")
self.host = parsed.hostname
self.port = parsed.port
self.client = None
self.cache = {}
async def initialize(self) -> None:
try:
print(f"Connecting to Chroma server at: {self.host}:{self.port}")
self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError("Could not connect to Chroma server") from e
async def shutdown(self) -> None:
pass
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
collection = await self.client.create_collection(
name=bank_id,
metadata={"bank": bank.json()},
)
bank_index = BankWithIndex(
bank=bank, index=ChromaIndex(self.client, collection)
)
self.cache[bank_id] = bank_index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
collections = await self.client.list_collections()
for collection in collections:
if collection.name == bank_id:
print(collection.metadata)
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
index = BankWithIndex(
bank=bank,
index=ChromaIndex(self.client, collection),
)
self.cache[bank_id] = index
return index
return None
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)
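Together with the shared BankWithIndex helpers this supports a full create/insert/query round trip. A sketch, assuming a running Chroma server; VectorMemoryBankConfig is an assumed class name, since the adapter only requires a config carrying chunk_size_in_tokens (and optionally overlap_size_in_tokens):

# End-to-end sketch; flagged names below are assumptions.
import asyncio

from llama_toolchain.memory.api import *  # noqa: F403
from llama_toolchain.memory.adapters.chroma.chroma import ChromaMemoryAdapter


async def main() -> None:
    adapter = ChromaMemoryAdapter("http://localhost:8000")
    await adapter.initialize()
    bank = await adapter.create_memory_bank(
        name="docs",
        # VectorMemoryBankConfig is assumed; any MemoryBankConfig exposing
        # chunk_size_in_tokens works with BankWithIndex.
        config=VectorMemoryBankConfig(chunk_size_in_tokens=512),
    )
    await adapter.insert_documents(
        bank.bank_id,
        [MemoryBankDocument(document_id="doc-1", content="Chroma stores the chunks.")],
    )
    resp = await adapter.query_documents(bank.bank_id, "what stores the chunks?")
    for chunk, score in zip(resp.chunks, resp.scores):
        print(f"{score:.3f} {chunk.content}")


asyncio.run(main())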

llama_toolchain/memory/adapters/pgvector/__init__.py View file

@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import PGVectorConfig
async def get_adapter_impl(config: PGVectorConfig, _deps):
from .pgvector import PGVectorMemoryAdapter
impl = PGVectorMemoryAdapter(config)
await impl.initialize()
return impl

llama_toolchain/memory/adapters/pgvector/config.py View file

@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class PGVectorConfig(BaseModel):
host: str = Field(default="localhost")
port: int = Field(default=5432)
db: str
user: str
password: str
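Only host and port carry defaults, so db, user, and password must always be supplied. An example instance with illustrative values:

# Illustrative values only.
from llama_toolchain.memory.adapters.pgvector import PGVectorConfig

config = PGVectorConfig(db="llama_memory", user="llama", password="secret")
assert config.host == "localhost" and config.port == 5432
print(config.json())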

llama_toolchain/memory/adapters/pgvector/pgvector.py View file

@@ -0,0 +1,234 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import Any, Dict, List, Optional, Tuple
import psycopg2
from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import DictCursor, execute_values, Json
from pydantic import BaseModel
from llama_toolchain.memory.api import * # noqa: F403
from llama_toolchain.memory.common.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex,
EmbeddingIndex,
)
from .config import PGVectorConfig
def check_extension_version(cur):
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
result = cur.fetchone()
return result[0] if result else None
def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
query = sql.SQL(
"""
INSERT INTO metadata_store (key, data)
VALUES %s
ON CONFLICT (key) DO UPDATE
SET data = EXCLUDED.data
"""
)
values = [(key, Json(model.dict())) for key, model in keys_models]
execute_values(cur, query, values, template="(%s, %s)")
def load_models(cur, keys: List[str], cls):
query = "SELECT key, data FROM metadata_store"
if keys:
placeholders = ",".join(["%s"] * len(keys))
query += f" WHERE key IN ({placeholders})"
cur.execute(query, keys)
else:
cur.execute(query)
rows = cur.fetchall()
return [cls(**row["data"]) for row in rows]
class PGVectorIndex(EmbeddingIndex):
def __init__(self, bank: MemoryBank, dimension: int, cursor):
self.cursor = cursor
self.table_name = f"vector_store_{bank.name}"
self.cursor.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
id TEXT PRIMARY KEY,
document JSONB,
embedding vector({dimension})
)
"""
)
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
values = []
for i, chunk in enumerate(chunks):
print(f"Adding chunk #{i} tokens={chunk.token_count}")
values.append(
(
f"{chunk.document_id}:chunk-{i}",
Json(chunk.dict()),
embeddings[i].tolist(),
)
)
query = sql.SQL(
f"""
INSERT INTO {self.table_name} (id, document, embedding)
VALUES %s
ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
"""
)
execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
self.cursor.execute(
f"""
SELECT document, embedding <-> %s::vector AS distance
FROM {self.table_name}
ORDER BY distance
LIMIT %s
""",
(embedding.tolist(), k),
)
results = self.cursor.fetchall()
chunks = []
scores = []
for doc, dist in results:
chunks.append(Chunk(**doc))
scores.append(1.0 / float(dist))
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class PGVectorMemoryAdapter(Memory):
def __init__(self, config: PGVectorConfig) -> None:
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
self.config = config
self.cursor = None
self.conn = None
self.cache = {}
async def initialize(self) -> None:
try:
self.conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
database=self.config.db,
user=self.config.user,
password=self.config.password,
)
# load_models() reads rows by column name, so use a dict-style cursor
self.cursor = self.conn.cursor(cursor_factory=DictCursor)
version = check_extension_version(self.cursor)
if version:
print(f"Vector extension version: {version}")
else:
raise RuntimeError("Vector extension is not installed.")
self.cursor.execute(
"""
CREATE TABLE IF NOT EXISTS metadata_store (
key TEXT PRIMARY KEY,
data JSONB
)
"""
)
except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError("Could not connect to PGVector database server") from e
async def shutdown(self) -> None:
pass
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
upsert_models(
self.cursor,
[
(bank.bank_id, bank),
],
)
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
banks = load_models(self.cursor, [bank_id], MemoryBank)
if not banks:
return None
bank = banks[0]
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return index
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)
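Note that initialize() refuses to start unless the vector extension is already installed (check_extension_version returns None otherwise). A one-time preparation along these lines would satisfy that probe; connection values are illustrative and the pgvector package must be installed on the server:

# One-time database setup sketch, not part of the commit itself.
import psycopg2

conn = psycopg2.connect(
    host="localhost",
    port=5432,
    database="llama_memory",
    user="llama",
    password="secret",
)
conn.autocommit = True  # apply the DDL immediately
with conn.cursor() as cur:
    cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
    cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
    print("pgvector version:", cur.fetchone()[0])
conn.close()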

llama_toolchain/memory/common/vector_store.py View file

@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import httpx
import numpy as np
from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_toolchain.memory.api import * # noqa: F403
ALL_MINILM_L6_V2_DIMENSION = 384
EMBEDDING_MODEL = None
def get_embedding_model() -> "SentenceTransformer":
global EMBEDDING_MODEL
if EMBEDDING_MODEL is None:
print("Loading sentence transformer")
from sentence_transformers import SentenceTransformer
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
return EMBEDDING_MODEL
async def content_from_doc(doc: MemoryBankDocument) -> str:
if isinstance(doc.content, URL):
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
return r.text
return interleaved_text_media_as_str(doc.content)
def make_overlapped_chunks(
document_id: str, text: str, window_len: int, overlap_len: int
) -> List[Chunk]:
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
chunks = []
for i in range(0, len(tokens), window_len - overlap_len):
toks = tokens[i : i + window_len]
chunk = tokenizer.decode(toks)
chunks.append(
Chunk(content=chunk, token_count=len(toks), document_id=document_id)
)
return chunks
class EmbeddingIndex(ABC):
@abstractmethod
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
raise NotImplementedError()
@abstractmethod
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
raise NotImplementedError()
@dataclass
class BankWithIndex:
bank: MemoryBank
index: EmbeddingIndex
async def insert_documents(
self,
documents: List[MemoryBankDocument],
) -> None:
model = get_embedding_model()
for doc in documents:
content = await content_from_doc(doc)
chunks = make_overlapped_chunks(
doc.document_id,
content,
self.bank.config.chunk_size_in_tokens,
self.bank.config.overlap_size_in_tokens
or (self.bank.config.chunk_size_in_tokens // 4),
)
embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
await self.index.add_chunks(chunks, embeddings)
async def query_documents(
self,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
if params is None:
params = {}
k = params.get("max_chunks", 3)
def _process(c) -> str:
if isinstance(c, str):
return c
else:
return "<media>"
if isinstance(query, list):
query_str = " ".join([_process(c) for c in query])
else:
query_str = _process(query)
model = get_embedding_model()
query_vector = model.encode([query_str])[0].astype(np.float32)
return await self.index.query(query_vector, k)
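One detail worth spelling out in make_overlapped_chunks above: the window advances by window_len - overlap_len tokens per step, so a 512-token window with the default overlap of 512 // 4 = 128 strides 384 tokens. A toy version over index spans illustrates the arithmetic:

# Toy illustration of the overlapped-window arithmetic (spans of token indices).
def window_spans(n_tokens: int, window_len: int, overlap_len: int):
    stride = window_len - overlap_len
    return [(i, min(i + window_len, n_tokens)) for i in range(0, n_tokens, stride)]

print(window_spans(1000, 512, 128))
# [(0, 512), (384, 896), (768, 1000)]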

llama_toolchain/memory/meta_reference/faiss.py View file

@@ -5,108 +5,45 @@
 # the root directory of this source tree.
 import uuid
-from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 import faiss
-import httpx
 import numpy as np
+from numpy.typing import NDArray
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_toolchain.memory.api import *  # noqa: F403
+from llama_toolchain.memory.common.vector_store import (
+    ALL_MINILM_L6_V2_DIMENSION,
+    BankWithIndex,
+    EmbeddingIndex,
+)
 from .config import FaissImplConfig
-async def content_from_doc(doc: MemoryBankDocument) -> str:
-    if isinstance(doc.content, URL):
-        async with httpx.AsyncClient() as client:
-            r = await client.get(doc.content.uri)
-            return r.text
-    return interleaved_text_media_as_str(doc.content)
+class FaissIndex(EmbeddingIndex):
+    id_by_index: Dict[int, str]
+    chunk_by_index: Dict[int, str]
+    def __init__(self, dimension: int):
+        self.index = faiss.IndexFlatL2(dimension)
+        self.id_by_index = {}
+        self.chunk_by_index = {}
+    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        indexlen = len(self.id_by_index)
+        for i, chunk in enumerate(chunks):
+            self.chunk_by_index[indexlen + i] = chunk
+            print(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
+            self.id_by_index[indexlen + i] = chunk.document_id
+        self.index.add(np.array(embeddings).astype(np.float32))
-def make_overlapped_chunks(
-    text: str, window_len: int, overlap_len: int
-) -> List[Tuple[str, int]]:
-    tokenizer = Tokenizer.get_instance()
-    tokens = tokenizer.encode(text, bos=False, eos=False)
-    chunks = []
-    for i in range(0, len(tokens), window_len - overlap_len):
-        toks = tokens[i : i + window_len]
-        chunk = tokenizer.decode(toks)
-        chunks.append((chunk, len(toks)))
-    return chunks
-@dataclass
-class BankState:
-    bank: MemoryBank
-    index: Optional[faiss.IndexFlatL2] = None
-    doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
-    id_by_index: Dict[int, str] = field(default_factory=dict)
-    chunk_by_index: Dict[int, str] = field(default_factory=dict)
-    async def insert_documents(
-        self,
-        model: "SentenceTransformer",
-        documents: List[MemoryBankDocument],
-    ) -> None:
-        tokenizer = Tokenizer.get_instance()
-        chunk_size = self.bank.config.chunk_size_in_tokens
-        for doc in documents:
-            indexlen = len(self.id_by_index)
-            self.doc_by_id[doc.document_id] = doc
-            content = await content_from_doc(doc)
-            chunks = make_overlapped_chunks(
-                content,
-                self.bank.config.chunk_size_in_tokens,
-                self.bank.config.overlap_size_in_tokens
-                or (self.bank.config.chunk_size_in_tokens // 4),
-            )
-            embeddings = model.encode([x[0] for x in chunks]).astype(np.float32)
-            await self._ensure_index(embeddings.shape[1])
-            self.index.add(embeddings)
-            for i, chunk in enumerate(chunks):
-                self.chunk_by_index[indexlen + i] = Chunk(
-                    content=chunk[0],
-                    token_count=chunk[1],
-                    document_id=doc.document_id,
-                )
-                print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
-                self.id_by_index[indexlen + i] = doc.document_id
-    async def query_documents(
-        self,
-        model: "SentenceTransformer",
-        query: InterleavedTextMedia,
-        params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
-        if params is None:
-            params = {}
-        k = params.get("max_chunks", 3)
-        def _process(c) -> str:
-            if isinstance(c, str):
-                return c
-            else:
-                return "<media>"
-        if isinstance(query, list):
-            query_str = " ".join([_process(c) for c in query])
-        else:
-            query_str = _process(query)
-        query_vector = model.encode([query_str])[0]
+    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
         distances, indices = self.index.search(
-            query_vector.reshape(1, -1).astype(np.float32), k
+            embedding.reshape(1, -1).astype(np.float32), k
         )
         chunks = []
@@ -119,17 +56,11 @@ class BankState:
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
-    async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
-        if self.index is None:
-            self.index = faiss.IndexFlatL2(dimension)
-        return self.index
 class FaissMemoryImpl(Memory):
     def __init__(self, config: FaissImplConfig) -> None:
         self.config = config
-        self.model = None
-        self.states = {}
+        self.cache = {}
     async def initialize(self) -> None: ...
@@ -153,14 +84,15 @@ class FaissMemoryImpl(Memory):
             config=config,
             url=url,
         )
-        state = BankState(bank=bank)
-        self.states[bank_id] = state
+        index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
+        self.cache[bank_id] = index
         return bank
     async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        if bank_id not in self.states:
+        index = self.cache.get(bank_id)
+        if index is None:
             return None
-        return self.states[bank_id].bank
+        return index.bank
     async def insert_documents(
         self,
@@ -168,10 +100,11 @@ class FaissMemoryImpl(Memory):
         documents: List[MemoryBankDocument],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        assert bank_id in self.states, f"Bank {bank_id} not found"
-        state = self.states[bank_id]
-        await state.insert_documents(self.get_model(), documents)
+        index = self.cache.get(bank_id)
+        if index is None:
+            raise ValueError(f"Bank {bank_id} not found")
+        await index.insert_documents(documents)
     async def query_documents(
         self,
@@ -179,16 +112,8 @@ class FaissMemoryImpl(Memory):
         query: InterleavedTextMedia,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
-        assert bank_id in self.states, f"Bank {bank_id} not found"
-        state = self.states[bank_id]
-        return await state.query_documents(self.get_model(), query, params)
-    def get_model(self) -> "SentenceTransformer":
-        from sentence_transformers import SentenceTransformer
-        if self.model is None:
-            print("Loading sentence transformer")
-            self.model = SentenceTransformer("all-MiniLM-L6-v2")
-        return self.model
+        index = self.cache.get(bank_id)
+        if index is None:
+            raise ValueError(f"Bank {bank_id} not found")
+        return await index.query_documents(query, params)
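The refactor replaces the per-bank BankState with the shared BankWithIndex wrapping a FaissIndex. The faiss calls it relies on reduce to the following standalone sketch, with random vectors standing in for sentence-transformer embeddings:

# Standalone sketch of the faiss calls behind FaissIndex.
import faiss
import numpy as np

dim = 384  # matches ALL_MINILM_L6_V2_DIMENSION
index = faiss.IndexFlatL2(dim)

embeddings = np.random.rand(10, dim).astype(np.float32)
index.add(embeddings)  # as in FaissIndex.add_chunks

query = np.random.rand(1, dim).astype(np.float32)
distances, indices = index.search(query, 3)  # as in FaissIndex.query
print(indices[0], distances[0])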

llama_toolchain/memory/providers.py View file

@@ -6,7 +6,12 @@
 from typing import List
-from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_toolchain.core.datatypes import *  # noqa: F403
+EMBEDDING_DEPS = [
+    "blobfile",
+    "sentence-transformers",
+]
 def available_memory_providers() -> List[ProviderSpec]:
@@ -14,12 +19,25 @@ def available_memory_providers() -> List[ProviderSpec]:
         InlineProviderSpec(
             api=Api.memory,
             provider_id="meta-reference-faiss",
-            pip_packages=[
-                "blobfile",
-                "faiss-cpu",
-                "sentence-transformers",
-            ],
+            pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_toolchain.memory.meta_reference.faiss",
             config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
         ),
+        remote_provider_spec(
+            api=Api.memory,
+            adapter=AdapterSpec(
+                adapter_id="chromadb",
+                pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
+                module="llama_toolchain.memory.adapters.chroma",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.memory,
+            adapter=AdapterSpec(
+                adapter_id="pgvector",
+                pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
+                module="llama_toolchain.memory.adapters.pgvector",
+                config_class="llama_toolchain.memory.adapters.pgvector.PGVectorConfig",
+            ),
+        ),
     ]
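A quick way to sanity-check the expanded registry is to enumerate it; the attribute fallbacks below are assumptions, since the exact shape of the remote specs lives in llama_toolchain.core.datatypes:

# Sketch: list the registered memory providers and their pip dependencies.
from llama_toolchain.memory.providers import available_memory_providers

for spec in available_memory_providers():
    adapter = getattr(spec, "adapter", None)  # remote specs (assumed attribute)
    provider_id = getattr(spec, "provider_id", None) or getattr(adapter, "adapter_id", "?")
    pip_packages = getattr(spec, "pip_packages", None) or getattr(adapter, "pip_packages", [])
    print(provider_id, pip_packages)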