mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 10:54:19 +00:00
Add Chroma and PGVector adapters (#56)
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
This commit is contained in:
parent
5de6ed946e
commit
3f090d1975
8 changed files with 628 additions and 119 deletions
15
llama_toolchain/memory/adapters/chroma/__init__.py
Normal file
15
llama_toolchain/memory/adapters/chroma/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
    """Build and initialize a Chroma-backed memory adapter.

    The adapter class is imported lazily so the chromadb dependency is only
    required when this adapter is actually instantiated.
    """
    from .chroma import ChromaMemoryAdapter

    adapter = ChromaMemoryAdapter(config.url)
    await adapter.initialize()
    return adapter
|
165
llama_toolchain/memory/adapters/chroma/chroma.py
Normal file
165
llama_toolchain/memory/adapters/chroma/chroma.py
Normal file
|
@ -0,0 +1,165 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from typing import List
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from llama_toolchain.memory.api import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
from llama_toolchain.memory.common.vector_store import BankWithIndex, EmbeddingIndex
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaIndex(EmbeddingIndex):
    """EmbeddingIndex backed by a single Chroma collection.

    Chunks are stored as JSON documents; each id combines the source
    document id with the chunk's position, so re-inserting the same
    document overwrites its previous entries.
    """

    def __init__(self, client: chromadb.AsyncHttpClient, collection):
        self.client = client
        self.collection = collection

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        """Insert chunks with their precomputed embeddings into the collection."""
        assert len(chunks) == len(
            embeddings
        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"

        for i, chunk in enumerate(chunks):
            print(f"Adding chunk #{i} tokens={chunk.token_count}")

        await self.collection.add(
            documents=[chunk.json() for chunk in chunks],
            embeddings=embeddings,
            ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
        )

    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        """Return the k nearest chunks, scored as inverse distance.

        Documents that fail to deserialize are logged and skipped rather
        than failing the whole query.
        """
        results = await self.collection.query(
            query_embeddings=[embedding.tolist()],
            n_results=k,
            include=["documents", "distances"],
        )
        distances = results["distances"][0]
        documents = results["documents"][0]

        chunks = []
        scores = []
        for dist, doc in zip(distances, documents):
            try:
                doc = json.loads(doc)
                chunk = Chunk(**doc)
            except Exception:
                import traceback

                traceback.print_exc()
                print(f"Failed to parse document: {doc}")
                continue

            chunks.append(chunk)
            # An exact match has distance 0; dividing by it would raise
            # ZeroDivisionError, so treat it as an infinitely good score.
            scores.append(1.0 / float(dist) if dist != 0 else float("inf"))

        return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaMemoryAdapter(Memory):
    # Memory API implementation backed by a remote Chroma server.  Each memory
    # bank maps to one Chroma collection named after the bank id, with the
    # serialized MemoryBank stored in the collection's metadata.

    def __init__(self, url: str) -> None:
        print(f"Initializing ChromaMemoryAdapter with url: {url}")
        url = url.rstrip("/")
        parsed = urlparse(url)

        # Only scheme://host:port is accepted; a path component would be
        # dropped silently below, so reject it up front.
        if parsed.path and parsed.path != "/":
            raise ValueError("URL should not contain a path")

        self.host = parsed.hostname
        self.port = parsed.port

        # Client is created lazily in initialize(); cache maps bank_id to
        # BankWithIndex so repeated lookups skip the server round-trip.
        self.client = None
        self.cache = {}

    async def initialize(self) -> None:
        # Connect to the Chroma server; failures surface as RuntimeError with
        # the original exception chained for debugging.
        try:
            print(f"Connecting to Chroma server at: {self.host}:{self.port}")
            self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
        except Exception as e:
            import traceback

            traceback.print_exc()
            raise RuntimeError("Could not connect to Chroma server") from e

    async def shutdown(self) -> None:
        # No explicit teardown needed for the HTTP client here.
        pass

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        # Create a bank with a fresh UUID and a matching Chroma collection.
        # The bank definition is persisted in the collection metadata so it
        # can be recovered by _get_and_cache_bank_index after a restart.
        bank_id = str(uuid.uuid4())
        bank = MemoryBank(
            bank_id=bank_id,
            name=name,
            config=config,
            url=url,
        )
        collection = await self.client.create_collection(
            name=bank_id,
            metadata={"bank": bank.json()},
        )
        bank_index = BankWithIndex(
            bank=bank, index=ChromaIndex(self.client, collection)
        )
        self.cache[bank_id] = bank_index
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        # Returns None when the bank does not exist on the server.
        bank_index = await self._get_and_cache_bank_index(bank_id)
        if bank_index is None:
            return None
        return bank_index.bank

    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        # Cache hit: skip the server round-trip entirely.
        if bank_id in self.cache:
            return self.cache[bank_id]

        # Cache miss: scan server collections for one named after the bank id
        # and rebuild the bank from its stored metadata.
        collections = await self.client.list_collections()
        for collection in collections:
            if collection.name == bank_id:
                print(collection.metadata)
                bank = MemoryBank(**json.loads(collection.metadata["bank"]))
                index = BankWithIndex(
                    bank=bank,
                    index=ChromaIndex(self.client, collection),
                )
                self.cache[bank_id] = index
                return index

        return None

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        # NOTE(review): ttl_seconds is accepted but not honored anywhere in
        # this adapter — confirm whether expiry is expected here.
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        await index.insert_documents(documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        # Delegates chunk retrieval/scoring to the bank's BankWithIndex.
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        return await index.query_documents(query, params)
|
15
llama_toolchain/memory/adapters/pgvector/__init__.py
Normal file
15
llama_toolchain/memory/adapters/pgvector/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .config import PGVectorConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: PGVectorConfig, _deps):
    """Construct and initialize the PGVector memory adapter.

    The adapter module is imported lazily so psycopg2 is only required when
    this adapter is actually used.
    """
    from .pgvector import PGVectorMemoryAdapter

    adapter = PGVectorMemoryAdapter(config)
    await adapter.initialize()
    return adapter
|
17
llama_toolchain/memory/adapters/pgvector/config.py
Normal file
17
llama_toolchain/memory/adapters/pgvector/config.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
class PGVectorConfig(BaseModel):
    # Connection settings for a Postgres server that has the pgvector
    # extension installed.  db/user/password have no defaults and must be
    # supplied by the deployment config.
    host: str = Field(default="localhost")  # database server hostname
    port: int = Field(default=5432)  # standard Postgres port
    db: str  # database name (required)
    user: str  # login role (required)
    password: str  # login password (required)
|
234
llama_toolchain/memory/adapters/pgvector/pgvector.py
Normal file
234
llama_toolchain/memory/adapters/pgvector/pgvector.py
Normal file
|
@ -0,0 +1,234 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import psycopg2
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from psycopg2 import sql
|
||||||
|
from psycopg2.extras import execute_values, Json
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from llama_toolchain.memory.api import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
from llama_toolchain.memory.common.vector_store import (
|
||||||
|
ALL_MINILM_L6_V2_DIMENSION,
|
||||||
|
BankWithIndex,
|
||||||
|
EmbeddingIndex,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .config import PGVectorConfig
|
||||||
|
|
||||||
|
|
||||||
|
def check_extension_version(cur):
    """Return the installed pgvector extension version, or None if absent."""
    cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
    row = cur.fetchone()
    if row:
        return row[0]
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
    """Insert or update (key, JSONB) rows in metadata_store in one batch.

    Models are serialized via pydantic's dict() and wrapped in psycopg2's
    Json adapter; conflicts on `key` replace the stored data.
    """
    upsert_sql = sql.SQL(
        """
        INSERT INTO metadata_store (key, data)
        VALUES %s
        ON CONFLICT (key) DO UPDATE
        SET data = EXCLUDED.data
    """
    )

    rows = []
    for key, model in keys_models:
        rows.append((key, Json(model.dict())))

    execute_values(cur, upsert_sql, rows, template="(%s, %s)")
|
||||||
|
|
||||||
|
|
||||||
|
def load_models(cur, keys: List[str], cls):
    """Load models of type `cls` stored as JSONB in metadata_store.

    Args:
        cur: an open database cursor.
        keys: keys to select; an empty list loads every row.
        cls: callable constructed with the stored JSON fields (e.g. a
            pydantic model class).

    Returns:
        A list of `cls` instances, one per matching row.
    """
    query = "SELECT key, data FROM metadata_store"
    if keys:
        placeholders = ",".join(["%s"] * len(keys))
        query += f" WHERE key IN ({placeholders})"
        cur.execute(query, keys)
    else:
        cur.execute(query)

    rows = cur.fetchall()
    # The default `conn.cursor()` used by PGVectorMemoryAdapter returns plain
    # tuples, for which row["data"] raises TypeError.  Support both tuple and
    # mapping-style (RealDictCursor) rows: `data` is the second column.
    models = []
    for row in rows:
        data = row["data"] if isinstance(row, dict) else row[1]
        models.append(cls(**data))
    return models
|
||||||
|
|
||||||
|
|
||||||
|
class PGVectorIndex(EmbeddingIndex):
    # EmbeddingIndex stored in a per-bank Postgres table via the pgvector
    # extension.  Rows hold the chunk id, the chunk as JSONB, and the vector.

    def __init__(self, bank: MemoryBank, dimension: int, cursor):
        self.cursor = cursor
        # NOTE(review): the table name interpolates bank.name without quoting
        # or escaping — a name containing spaces or SQL metacharacters would
        # break (or inject into) the DDL below.  Confirm bank names are
        # constrained upstream.
        self.table_name = f"vector_store_{bank.name}"

        # The table is created eagerly in the constructor so add/query can
        # assume it exists.
        self.cursor.execute(
            f"""
        CREATE TABLE IF NOT EXISTS {self.table_name} (
            id TEXT PRIMARY KEY,
            document JSONB,
            embedding vector({dimension})
        )
    """
        )

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        # One embedding per chunk is required.
        assert len(chunks) == len(
            embeddings
        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"

        values = []
        for i, chunk in enumerate(chunks):
            print(f"Adding chunk #{i} tokens={chunk.token_count}")
            values.append(
                (
                    f"{chunk.document_id}:chunk-{i}",
                    Json(chunk.dict()),
                    embeddings[i].tolist(),
                )
            )

        # Upsert so re-inserting a document replaces its previous chunks
        # instead of failing on the primary key.
        query = sql.SQL(
            f"""
        INSERT INTO {self.table_name} (id, document, embedding)
        VALUES %s
        ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
    """
        )
        execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")

    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        # `<->` is pgvector's L2 distance operator; closest rows come first.
        self.cursor.execute(
            f"""
        SELECT document, embedding <-> %s::vector AS distance
        FROM {self.table_name}
        ORDER BY distance
        LIMIT %s
    """,
            (embedding.tolist(), k),
        )
        results = self.cursor.fetchall()

        chunks = []
        scores = []
        for doc, dist in results:
            chunks.append(Chunk(**doc))
            # NOTE(review): a distance of exactly 0 (identical vector) would
            # raise ZeroDivisionError here.
            scores.append(1.0 / float(dist))

        return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
|
class PGVectorMemoryAdapter(Memory):
    # Memory API implementation backed by Postgres + pgvector.  Bank
    # definitions live in the metadata_store table; each bank's chunks live
    # in their own vector_store_<name> table (see PGVectorIndex).

    def __init__(self, config: PGVectorConfig) -> None:
        print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
        self.config = config
        # Connection/cursor are created in initialize(); cache maps bank_id
        # to BankWithIndex to avoid re-reading metadata_store.
        self.cursor = None
        self.conn = None
        self.cache = {}

    async def initialize(self) -> None:
        try:
            self.conn = psycopg2.connect(
                host=self.config.host,
                port=self.config.port,
                database=self.config.db,
                user=self.config.user,
                password=self.config.password,
            )
            self.cursor = self.conn.cursor()

            # pgvector must be installed in the target database; fail fast
            # otherwise rather than erroring on the first vector query.
            version = check_extension_version(self.cursor)
            if version:
                print(f"Vector extension version: {version}")
            else:
                raise RuntimeError("Vector extension is not installed.")

            # NOTE(review): psycopg2 connections are not autocommit by
            # default and no commit() is issued anywhere in this adapter —
            # confirm whether writes are expected to persist across
            # connections.
            self.cursor.execute(
                """
            CREATE TABLE IF NOT EXISTS metadata_store (
                key TEXT PRIMARY KEY,
                data JSONB
            )
        """
            )
        except Exception as e:
            import traceback

            traceback.print_exc()
            raise RuntimeError("Could not connect to PGVector database server") from e

    async def shutdown(self) -> None:
        # NOTE(review): the connection and cursor are left open; presumably
        # released on process exit — confirm if explicit close is needed.
        pass

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        # Persist the bank definition, then build its index table eagerly
        # (PGVectorIndex issues CREATE TABLE in its constructor).
        bank_id = str(uuid.uuid4())
        bank = MemoryBank(
            bank_id=bank_id,
            name=name,
            config=config,
            url=url,
        )
        upsert_models(
            self.cursor,
            [
                (bank.bank_id, bank),
            ],
        )
        index = BankWithIndex(
            bank=bank,
            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
        )
        self.cache[bank_id] = index
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        # Returns None when the bank is unknown.
        bank_index = await self._get_and_cache_bank_index(bank_id)
        if bank_index is None:
            return None
        return bank_index.bank

    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        if bank_id in self.cache:
            return self.cache[bank_id]

        # Cache miss: reload the bank definition from metadata_store and
        # rebuild its index wrapper.
        banks = load_models(self.cursor, [bank_id], MemoryBank)
        if not banks:
            return None

        bank = banks[0]
        index = BankWithIndex(
            bank=bank,
            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
        )
        self.cache[bank_id] = index
        return index

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        # NOTE(review): ttl_seconds is accepted but not honored here.
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        await index.insert_documents(documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        # Delegates chunk retrieval/scoring to the bank's BankWithIndex.
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        return await index.query_documents(query, params)
|
120
llama_toolchain/memory/common/vector_store.py
Normal file
120
llama_toolchain/memory/common/vector_store.py
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
|
||||||
|
from llama_toolchain.memory.api import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
ALL_MINILM_L6_V2_DIMENSION = 384
|
||||||
|
|
||||||
|
EMBEDDING_MODEL = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_model() -> "SentenceTransformer":
    """Return the process-wide embedding model, loading it on first use."""
    global EMBEDDING_MODEL

    if EMBEDDING_MODEL is not None:
        return EMBEDDING_MODEL

    print("Loading sentence transformer")

    # Imported lazily: sentence-transformers is heavy and only needed once
    # embeddings are actually computed.
    from sentence_transformers import SentenceTransformer

    EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
    return EMBEDDING_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
async def content_from_doc(doc: MemoryBankDocument) -> str:
    """Resolve a document's content to plain text, fetching URLs over HTTP."""
    content = doc.content
    if not isinstance(content, URL):
        # Inline content: flatten interleaved text/media to a string.
        return interleaved_text_media_as_str(content)

    async with httpx.AsyncClient() as client:
        response = await client.get(content.uri)
        return response.text
|
||||||
|
|
||||||
|
|
||||||
|
def make_overlapped_chunks(
    document_id: str, text: str, window_len: int, overlap_len: int
) -> List[Chunk]:
    """Split `text` into token windows of `window_len`, overlapping by `overlap_len`.

    Each window is decoded back to text and wrapped in a Chunk tagged with
    the originating document id.  The final chunk may be shorter than
    `window_len`.

    Raises:
        ValueError: if overlap_len >= window_len.  Previously a zero stride
            surfaced as an opaque `range()` error and a negative stride
            silently produced no chunks.
    """
    stride = window_len - overlap_len
    if stride <= 0:
        raise ValueError(
            f"overlap_len ({overlap_len}) must be smaller than window_len ({window_len})"
        )

    tokenizer = Tokenizer.get_instance()
    tokens = tokenizer.encode(text, bos=False, eos=False)

    chunks = []
    for i in range(0, len(tokens), stride):
        toks = tokens[i : i + window_len]
        chunk = tokenizer.decode(toks)
        chunks.append(
            Chunk(content=chunk, token_count=len(toks), document_id=document_id)
        )

    return chunks
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingIndex(ABC):
    # Abstract interface over a vector store; implemented in this change by
    # the Chroma and PGVector backends.

    @abstractmethod
    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        """Store chunks alongside their embedding vectors (one per chunk)."""
        raise NotImplementedError()

    @abstractmethod
    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        """Return up to k chunks closest to `embedding`, with scores."""
        raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BankWithIndex:
    # Pairs a MemoryBank definition with the vector index holding its chunks.
    # Shared by all backends: chunking/embedding happens here, storage and
    # nearest-neighbor search are delegated to the EmbeddingIndex.
    bank: MemoryBank
    index: EmbeddingIndex

    async def insert_documents(
        self,
        documents: List[MemoryBankDocument],
    ) -> None:
        # Chunk each document, embed the chunks, and add them to the index.
        model = get_embedding_model()
        for doc in documents:
            content = await content_from_doc(doc)
            chunks = make_overlapped_chunks(
                doc.document_id,
                content,
                self.bank.config.chunk_size_in_tokens,
                # Overlap defaults to a quarter of the window when unset/0.
                self.bank.config.overlap_size_in_tokens
                or (self.bank.config.chunk_size_in_tokens // 4),
            )
            embeddings = model.encode([x.content for x in chunks]).astype(np.float32)

            await self.index.add_chunks(chunks, embeddings)

    async def query_documents(
        self,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        if params is None:
            params = {}
        # Number of results, overridable via params["max_chunks"].
        k = params.get("max_chunks", 3)

        def _process(c) -> str:
            # Media items cannot be embedded as text; use a placeholder token.
            if isinstance(c, str):
                return c
            else:
                return "<media>"

        # Flatten interleaved text/media into a single query string before
        # embedding.
        if isinstance(query, list):
            query_str = " ".join([_process(c) for c in query])
        else:
            query_str = _process(query)

        model = get_embedding_model()
        query_vector = model.encode([query_str])[0].astype(np.float32)
        return await self.index.query(query_vector, k)
|
|
@ -5,108 +5,45 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import faiss
|
import faiss
|
||||||
import httpx
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
|
||||||
|
|
||||||
from llama_toolchain.memory.api import * # noqa: F403
|
from llama_toolchain.memory.api import * # noqa: F403
|
||||||
|
from llama_toolchain.memory.common.vector_store import (
|
||||||
|
ALL_MINILM_L6_V2_DIMENSION,
|
||||||
|
BankWithIndex,
|
||||||
|
EmbeddingIndex,
|
||||||
|
)
|
||||||
from .config import FaissImplConfig
|
from .config import FaissImplConfig
|
||||||
|
|
||||||
|
|
||||||
async def content_from_doc(doc: MemoryBankDocument) -> str:
|
class FaissIndex(EmbeddingIndex):
|
||||||
if isinstance(doc.content, URL):
|
id_by_index: Dict[int, str]
|
||||||
async with httpx.AsyncClient() as client:
|
chunk_by_index: Dict[int, str]
|
||||||
r = await client.get(doc.content.uri)
|
|
||||||
return r.text
|
|
||||||
|
|
||||||
return interleaved_text_media_as_str(doc.content)
|
def __init__(self, dimension: int):
|
||||||
|
self.index = faiss.IndexFlatL2(dimension)
|
||||||
|
self.id_by_index = {}
|
||||||
|
self.chunk_by_index = {}
|
||||||
|
|
||||||
|
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||||
def make_overlapped_chunks(
|
|
||||||
text: str, window_len: int, overlap_len: int
|
|
||||||
) -> List[Tuple[str, int]]:
|
|
||||||
tokenizer = Tokenizer.get_instance()
|
|
||||||
tokens = tokenizer.encode(text, bos=False, eos=False)
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
for i in range(0, len(tokens), window_len - overlap_len):
|
|
||||||
toks = tokens[i : i + window_len]
|
|
||||||
chunk = tokenizer.decode(toks)
|
|
||||||
chunks.append((chunk, len(toks)))
|
|
||||||
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BankState:
|
|
||||||
bank: MemoryBank
|
|
||||||
index: Optional[faiss.IndexFlatL2] = None
|
|
||||||
doc_by_id: Dict[str, MemoryBankDocument] = field(default_factory=dict)
|
|
||||||
id_by_index: Dict[int, str] = field(default_factory=dict)
|
|
||||||
chunk_by_index: Dict[int, str] = field(default_factory=dict)
|
|
||||||
|
|
||||||
async def insert_documents(
|
|
||||||
self,
|
|
||||||
model: "SentenceTransformer",
|
|
||||||
documents: List[MemoryBankDocument],
|
|
||||||
) -> None:
|
|
||||||
tokenizer = Tokenizer.get_instance()
|
|
||||||
chunk_size = self.bank.config.chunk_size_in_tokens
|
|
||||||
|
|
||||||
for doc in documents:
|
|
||||||
indexlen = len(self.id_by_index)
|
indexlen = len(self.id_by_index)
|
||||||
self.doc_by_id[doc.document_id] = doc
|
|
||||||
|
|
||||||
content = await content_from_doc(doc)
|
|
||||||
chunks = make_overlapped_chunks(
|
|
||||||
content,
|
|
||||||
self.bank.config.chunk_size_in_tokens,
|
|
||||||
self.bank.config.overlap_size_in_tokens
|
|
||||||
or (self.bank.config.chunk_size_in_tokens // 4),
|
|
||||||
)
|
|
||||||
embeddings = model.encode([x[0] for x in chunks]).astype(np.float32)
|
|
||||||
await self._ensure_index(embeddings.shape[1])
|
|
||||||
|
|
||||||
self.index.add(embeddings)
|
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
self.chunk_by_index[indexlen + i] = Chunk(
|
self.chunk_by_index[indexlen + i] = chunk
|
||||||
content=chunk[0],
|
print(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
|
||||||
token_count=chunk[1],
|
self.id_by_index[indexlen + i] = chunk.document_id
|
||||||
document_id=doc.document_id,
|
|
||||||
)
|
|
||||||
print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
|
|
||||||
self.id_by_index[indexlen + i] = doc.document_id
|
|
||||||
|
|
||||||
async def query_documents(
|
self.index.add(np.array(embeddings).astype(np.float32))
|
||||||
self,
|
|
||||||
model: "SentenceTransformer",
|
|
||||||
query: InterleavedTextMedia,
|
|
||||||
params: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> QueryDocumentsResponse:
|
|
||||||
if params is None:
|
|
||||||
params = {}
|
|
||||||
k = params.get("max_chunks", 3)
|
|
||||||
|
|
||||||
def _process(c) -> str:
|
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
||||||
if isinstance(c, str):
|
|
||||||
return c
|
|
||||||
else:
|
|
||||||
return "<media>"
|
|
||||||
|
|
||||||
if isinstance(query, list):
|
|
||||||
query_str = " ".join([_process(c) for c in query])
|
|
||||||
else:
|
|
||||||
query_str = _process(query)
|
|
||||||
|
|
||||||
query_vector = model.encode([query_str])[0]
|
|
||||||
distances, indices = self.index.search(
|
distances, indices = self.index.search(
|
||||||
query_vector.reshape(1, -1).astype(np.float32), k
|
embedding.reshape(1, -1).astype(np.float32), k
|
||||||
)
|
)
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
|
@ -119,17 +56,11 @@ class BankState:
|
||||||
|
|
||||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
|
|
||||||
if self.index is None:
|
|
||||||
self.index = faiss.IndexFlatL2(dimension)
|
|
||||||
return self.index
|
|
||||||
|
|
||||||
|
|
||||||
class FaissMemoryImpl(Memory):
|
class FaissMemoryImpl(Memory):
|
||||||
def __init__(self, config: FaissImplConfig) -> None:
|
def __init__(self, config: FaissImplConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = None
|
self.cache = {}
|
||||||
self.states = {}
|
|
||||||
|
|
||||||
async def initialize(self) -> None: ...
|
async def initialize(self) -> None: ...
|
||||||
|
|
||||||
|
@ -153,14 +84,15 @@ class FaissMemoryImpl(Memory):
|
||||||
config=config,
|
config=config,
|
||||||
url=url,
|
url=url,
|
||||||
)
|
)
|
||||||
state = BankState(bank=bank)
|
index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
|
||||||
self.states[bank_id] = state
|
self.cache[bank_id] = index
|
||||||
return bank
|
return bank
|
||||||
|
|
||||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||||
if bank_id not in self.states:
|
index = self.cache.get(bank_id)
|
||||||
|
if index is None:
|
||||||
return None
|
return None
|
||||||
return self.states[bank_id].bank
|
return index.bank
|
||||||
|
|
||||||
async def insert_documents(
|
async def insert_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -168,10 +100,11 @@ class FaissMemoryImpl(Memory):
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
ttl_seconds: Optional[int] = None,
|
ttl_seconds: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert bank_id in self.states, f"Bank {bank_id} not found"
|
index = self.cache.get(bank_id)
|
||||||
state = self.states[bank_id]
|
if index is None:
|
||||||
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
await state.insert_documents(self.get_model(), documents)
|
await index.insert_documents(documents)
|
||||||
|
|
||||||
async def query_documents(
|
async def query_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -179,16 +112,8 @@ class FaissMemoryImpl(Memory):
|
||||||
query: InterleavedTextMedia,
|
query: InterleavedTextMedia,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> QueryDocumentsResponse:
|
) -> QueryDocumentsResponse:
|
||||||
assert bank_id in self.states, f"Bank {bank_id} not found"
|
index = self.cache.get(bank_id)
|
||||||
state = self.states[bank_id]
|
if index is None:
|
||||||
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
return await state.query_documents(self.get_model(), query, params)
|
return await index.query_documents(query, params)
|
||||||
|
|
||||||
def get_model(self) -> "SentenceTransformer":
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
|
|
||||||
if self.model is None:
|
|
||||||
print("Loading sentence transformer")
|
|
||||||
self.model = SentenceTransformer("all-MiniLM-L6-v2")
|
|
||||||
|
|
||||||
return self.model
|
|
||||||
|
|
|
@ -6,7 +6,12 @@
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
|
from llama_toolchain.core.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
EMBEDDING_DEPS = [
|
||||||
|
"blobfile",
|
||||||
|
"sentence-transformers",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def available_memory_providers() -> List[ProviderSpec]:
|
def available_memory_providers() -> List[ProviderSpec]:
|
||||||
|
@ -14,12 +19,25 @@ def available_memory_providers() -> List[ProviderSpec]:
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.memory,
|
api=Api.memory,
|
||||||
provider_id="meta-reference-faiss",
|
provider_id="meta-reference-faiss",
|
||||||
pip_packages=[
|
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||||
"blobfile",
|
|
||||||
"faiss-cpu",
|
|
||||||
"sentence-transformers",
|
|
||||||
],
|
|
||||||
module="llama_toolchain.memory.meta_reference.faiss",
|
module="llama_toolchain.memory.meta_reference.faiss",
|
||||||
config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
|
config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.memory,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_id="chromadb",
|
||||||
|
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
|
||||||
|
module="llama_toolchain.memory.adapters.chroma",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.memory,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_id="pgvector",
|
||||||
|
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
|
||||||
|
module="llama_toolchain.memory.adapters.pgvector",
|
||||||
|
config_class="llama_toolchain.memory.adapters.pgvector.PGVectorConfig",
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue