llama_toolchain -> llama_stack

This commit is contained in:
Ashwin Bharambe 2024-09-16 17:21:08 -07:00
parent f372355409
commit 2cf731faea
175 changed files with 300 additions and 279 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.core.datatypes import RemoteProviderConfig
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
from .chroma import ChromaMemoryAdapter
impl = ChromaMemoryAdapter(config.url)
await impl.initialize()
return impl

View file

@ -0,0 +1,165 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import uuid
from typing import List
from urllib.parse import urlparse
import chromadb
from numpy.typing import NDArray
from llama_stack.memory.api import * # noqa: F403
from llama_stack.memory.common.vector_store import BankWithIndex, EmbeddingIndex
class ChromaIndex(EmbeddingIndex):
def __init__(self, client: chromadb.AsyncHttpClient, collection):
self.client = client
self.collection = collection
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
for i, chunk in enumerate(chunks):
print(f"Adding chunk #{i} tokens={chunk.token_count}")
await self.collection.add(
documents=[chunk.json() for chunk in chunks],
embeddings=embeddings,
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
results = await self.collection.query(
query_embeddings=[embedding.tolist()],
n_results=k,
include=["documents", "distances"],
)
distances = results["distances"][0]
documents = results["documents"][0]
chunks = []
scores = []
for dist, doc in zip(distances, documents):
try:
doc = json.loads(doc)
chunk = Chunk(**doc)
except Exception:
import traceback
traceback.print_exc()
print(f"Failed to parse document: {doc}")
continue
chunks.append(chunk)
scores.append(1.0 / float(dist))
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class ChromaMemoryAdapter(Memory):
def __init__(self, url: str) -> None:
print(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/")
parsed = urlparse(url)
if parsed.path and parsed.path != "/":
raise ValueError("URL should not contain a path")
self.host = parsed.hostname
self.port = parsed.port
self.client = None
self.cache = {}
async def initialize(self) -> None:
try:
print(f"Connecting to Chroma server at: {self.host}:{self.port}")
self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError("Could not connect to Chroma server") from e
async def shutdown(self) -> None:
pass
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
collection = await self.client.create_collection(
name=bank_id,
metadata={"bank": bank.json()},
)
bank_index = BankWithIndex(
bank=bank, index=ChromaIndex(self.client, collection)
)
self.cache[bank_id] = bank_index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
collections = await self.client.list_collections()
for collection in collections:
if collection.name == bank_id:
print(collection.metadata)
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
index = BankWithIndex(
bank=bank,
index=ChromaIndex(self.client, collection),
)
self.cache[bank_id] = index
return index
return None
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import PGVectorConfig
async def get_adapter_impl(config: PGVectorConfig, _deps):
from .pgvector import PGVectorMemoryAdapter
impl = PGVectorMemoryAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class PGVectorConfig(BaseModel):
host: str = Field(default="localhost")
port: int = Field(default=5432)
db: str
user: str
password: str

View file

@ -0,0 +1,234 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import List, Tuple
import psycopg2
from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import execute_values, Json
from pydantic import BaseModel
from llama_stack.memory.api import * # noqa: F403
from llama_stack.memory.common.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex,
EmbeddingIndex,
)
from .config import PGVectorConfig
def check_extension_version(cur):
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
result = cur.fetchone()
return result[0] if result else None
def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
query = sql.SQL(
"""
INSERT INTO metadata_store (key, data)
VALUES %s
ON CONFLICT (key) DO UPDATE
SET data = EXCLUDED.data
"""
)
values = [(key, Json(model.dict())) for key, model in keys_models]
execute_values(cur, query, values, template="(%s, %s)")
def load_models(cur, keys: List[str], cls):
query = "SELECT key, data FROM metadata_store"
if keys:
placeholders = ",".join(["%s"] * len(keys))
query += f" WHERE key IN ({placeholders})"
cur.execute(query, keys)
else:
cur.execute(query)
rows = cur.fetchall()
return [cls(**row["data"]) for row in rows]
class PGVectorIndex(EmbeddingIndex):
def __init__(self, bank: MemoryBank, dimension: int, cursor):
self.cursor = cursor
self.table_name = f"vector_store_{bank.name}"
self.cursor.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
id TEXT PRIMARY KEY,
document JSONB,
embedding vector({dimension})
)
"""
)
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
values = []
for i, chunk in enumerate(chunks):
print(f"Adding chunk #{i} tokens={chunk.token_count}")
values.append(
(
f"{chunk.document_id}:chunk-{i}",
Json(chunk.dict()),
embeddings[i].tolist(),
)
)
query = sql.SQL(
f"""
INSERT INTO {self.table_name} (id, document, embedding)
VALUES %s
ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
"""
)
execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
self.cursor.execute(
f"""
SELECT document, embedding <-> %s::vector AS distance
FROM {self.table_name}
ORDER BY distance
LIMIT %s
""",
(embedding.tolist(), k),
)
results = self.cursor.fetchall()
chunks = []
scores = []
for doc, dist in results:
chunks.append(Chunk(**doc))
scores.append(1.0 / float(dist))
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class PGVectorMemoryAdapter(Memory):
def __init__(self, config: PGVectorConfig) -> None:
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
self.config = config
self.cursor = None
self.conn = None
self.cache = {}
async def initialize(self) -> None:
try:
self.conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
database=self.config.db,
user=self.config.user,
password=self.config.password,
)
self.cursor = self.conn.cursor()
version = check_extension_version(self.cursor)
if version:
print(f"Vector extension version: {version}")
else:
raise RuntimeError("Vector extension is not installed.")
self.cursor.execute(
"""
CREATE TABLE IF NOT EXISTS metadata_store (
key TEXT PRIMARY KEY,
data JSONB
)
"""
)
except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError("Could not connect to PGVector database server") from e
async def shutdown(self) -> None:
pass
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
upsert_models(
self.cursor,
[
(bank.bank_id, bank),
],
)
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
banks = load_models(self.cursor, [bank_id], MemoryBank)
if not banks:
return None
bank = banks[0]
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return index
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api import * # noqa: F401 F403

View file

@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
@json_schema_type
class MemoryBankDocument(BaseModel):
document_id: str
content: InterleavedTextMedia | URL
mime_type: str | None = None
metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class MemoryBankType(Enum):
vector = "vector"
keyvalue = "keyvalue"
keyword = "keyword"
graph = "graph"
class VectorMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
class KeyValueMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
class KeywordMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
class GraphMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class Chunk(BaseModel):
content: InterleavedTextMedia
token_count: int
document_id: str
@json_schema_type
class QueryDocumentsResponse(BaseModel):
chunks: List[Chunk]
scores: List[float]
@json_schema_type
class QueryAPI(Protocol):
@webmethod(route="/query_documents")
def query_documents(
self,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@json_schema_type
class MemoryBank(BaseModel):
bank_id: str
name: str
config: MemoryBankConfig
# if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI
url: Optional[URL] = None
class Memory(Protocol):
@webmethod(route="/memory_banks/create")
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank: ...
@webmethod(route="/memory_banks/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get", method="GET")
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory_banks/drop", method="DELETE")
async def drop_memory_bank(
self,
bank_id: str,
) -> str: ...
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@webmethod(route="/memory_bank/insert")
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None: ...
@webmethod(route="/memory_bank/update")
async def update_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/query")
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse: ...
@webmethod(route="/memory_bank/documents/get", method="GET")
async def get_documents(
self,
bank_id: str,
document_ids: List[str],
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/documents/delete", method="DELETE")
async def delete_documents(
self,
bank_id: str,
document_ids: List[str],
) -> None: ...

View file

@ -0,0 +1,196 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
import fire
import httpx
from termcolor import cprint
from llama_stack.core.datatypes import RemoteProviderConfig
from .api import * # noqa: F403
from .common.file_utils import data_url_from_file
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:
return MemoryClient(config.url)
class MemoryClient(Memory):
def __init__(self, base_url: str):
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
async with httpx.AsyncClient() as client:
r = await client.get(
f"{self.base_url}/memory_banks/get",
params={
"bank_id": bank_id,
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
async with httpx.AsyncClient() as client:
r = await client.post(
f"{self.base_url}/memory_banks/create",
json={
"name": name,
"config": config.dict(),
"url": url,
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
d = r.json()
if not d:
return None
return MemoryBank(**d)
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None:
async with httpx.AsyncClient() as client:
r = await client.post(
f"{self.base_url}/memory_bank/insert",
json={
"bank_id": bank_id,
"documents": [d.dict() for d in documents],
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
async with httpx.AsyncClient() as client:
r = await client.post(
f"{self.base_url}/memory_bank/query",
json={
"bank_id": bank_id,
"query": query,
"params": params,
},
headers={"Content-Type": "application/json"},
timeout=20,
)
r.raise_for_status()
return QueryDocumentsResponse(**r.json())
async def run_main(host: str, port: int, stream: bool):
client = MemoryClient(f"http://{host}:{port}")
# create a memory bank
bank = await client.create_memory_bank(
name="test_bank",
config=VectorMemoryBankConfig(
bank_id="test_bank",
embedding_model="dragon-roberta-query-2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
)
cprint(json.dumps(bank.dict(), indent=4), "green")
retrieved_bank = await client.get_memory_bank(bank.bank_id)
assert retrieved_bank is not None
assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
documents = [
MemoryBankDocument(
document_id=f"num-{i}",
content=URL(
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
),
mime_type="text/plain",
)
for i, url in enumerate(urls)
]
this_dir = os.path.dirname(__file__)
files = [Path(this_dir).parent.parent / "CONTRIBUTING.md"]
documents += [
MemoryBankDocument(
document_id=f"num-{i}",
content=data_url_from_file(path),
)
for i, path in enumerate(files)
]
# insert some documents
await client.insert_documents(
bank_id=bank.bank_id,
documents=documents,
)
# query the documents
response = await client.query_documents(
bank_id=bank.bank_id,
query=[
"How do I use Lora?",
],
)
for chunk, score in zip(response.chunks, response.scores):
print(f"Score: {score}")
print(f"Chunk:\n========\n{chunk}\n========\n")
response = await client.query_documents(
bank_id=bank.bank_id,
query=[
"Tell me more about llama3 and torchtune",
],
)
for chunk, score in zip(response.chunks, response.scores):
print(f"Score: {score}")
print(f"Chunk:\n========\n{chunk}\n========\n")
def main(host: str, port: int, stream: bool = True):
asyncio.run(run_main(host, port, stream))
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import mimetypes
import os
from llama_models.llama3.api.datatypes import URL
def data_url_from_file(file_path: str) -> URL:
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, "rb") as file:
file_content = file.read()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)
data_url = f"data:{mime_type};base64,{base64_content}"
return URL(uri=data_url)

View file

@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import io
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from urllib.parse import unquote
import chardet
import httpx
import numpy as np
from numpy.typing import NDArray
from pypdf import PdfReader
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.memory.api import * # noqa: F403
ALL_MINILM_L6_V2_DIMENSION = 384
EMBEDDING_MODEL = None
def get_embedding_model() -> "SentenceTransformer":
global EMBEDDING_MODEL
if EMBEDDING_MODEL is None:
print("Loading sentence transformer")
from sentence_transformers import SentenceTransformer
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
return EMBEDDING_MODEL
def parse_data_url(data_url: str):
data_url_pattern = re.compile(
r"^"
r"data:"
r"(?P<mimetype>[\w/\-+.]+)"
r"(?P<charset>;charset=(?P<encoding>[\w-]+))?"
r"(?P<base64>;base64)?"
r",(?P<data>.*)"
r"$",
re.DOTALL,
)
match = data_url_pattern.match(data_url)
if not match:
raise ValueError("Invalid Data URL format")
parts = match.groupdict()
parts["is_base64"] = bool(parts["base64"])
return parts
def content_from_data(data_url: str) -> str:
parts = parse_data_url(data_url)
data = parts["data"]
if parts["is_base64"]:
data = base64.b64decode(data)
else:
data = unquote(data)
encoding = parts["encoding"] or "utf-8"
data = data.encode(encoding)
encoding = parts["encoding"]
if not encoding:
detected = chardet.detect(data)
encoding = detected["encoding"]
mime_type = parts["mimetype"]
mime_category = mime_type.split("/")[0]
if mime_category == "text":
# For text-based files (including CSV, MD)
return data.decode(encoding)
elif mime_type == "application/pdf":
# For PDF and DOC/DOCX files, we can't reliably convert to string)
pdf_bytes = io.BytesIO(data)
pdf_reader = PdfReader(pdf_bytes)
return "\n".join([page.extract_text() for page in pdf_reader.pages])
else:
cprint("Could not extract content from data_url properly.", color="red")
return ""
async def content_from_doc(doc: MemoryBankDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
return content_from_data(doc.content.uri)
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
return r.text
return interleaved_text_media_as_str(doc.content)
def make_overlapped_chunks(
document_id: str, text: str, window_len: int, overlap_len: int
) -> List[Chunk]:
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
chunks = []
for i in range(0, len(tokens), window_len - overlap_len):
toks = tokens[i : i + window_len]
chunk = tokenizer.decode(toks)
chunks.append(
Chunk(content=chunk, token_count=len(toks), document_id=document_id)
)
return chunks
class EmbeddingIndex(ABC):
@abstractmethod
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
raise NotImplementedError()
@abstractmethod
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
raise NotImplementedError()
@dataclass
class BankWithIndex:
bank: MemoryBank
index: EmbeddingIndex
async def insert_documents(
self,
documents: List[MemoryBankDocument],
) -> None:
model = get_embedding_model()
for doc in documents:
content = await content_from_doc(doc)
chunks = make_overlapped_chunks(
doc.document_id,
content,
self.bank.config.chunk_size_in_tokens,
self.bank.config.overlap_size_in_tokens
or (self.bank.config.chunk_size_in_tokens // 4),
)
embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
await self.index.add_chunks(chunks, embeddings)
async def query_documents(
self,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
if params is None:
params = {}
k = params.get("max_chunks", 3)
def _process(c) -> str:
if isinstance(c, str):
return c
else:
return "<media>"
if isinstance(query, list):
query_str = " ".join([_process(c) for c in query])
else:
query_str = _process(query)
model = get_embedding_model()
query_vector = model.encode([query_str])[0].astype(np.float32)
return await self.index.query(query_vector, k)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import FaissImplConfig
async def get_provider_impl(config: FaissImplConfig, _deps):
from .faiss import FaissMemoryImpl
assert isinstance(
config, FaissImplConfig
), f"Unexpected config type: {type(config)}"
impl = FaissMemoryImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class FaissImplConfig(BaseModel): ...

View file

@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import uuid
from typing import Any, Dict, List, Optional
import faiss
import numpy as np
from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.memory.api import * # noqa: F403
from llama_stack.memory.common.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex,
EmbeddingIndex,
)
from llama_stack.telemetry import tracing
from .config import FaissImplConfig
logger = logging.getLogger(__name__)
class FaissIndex(EmbeddingIndex):
id_by_index: Dict[int, str]
chunk_by_index: Dict[int, str]
def __init__(self, dimension: int):
self.index = faiss.IndexFlatL2(dimension)
self.id_by_index = {}
self.chunk_by_index = {}
@tracing.span(name="add_chunks")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk
logger.info(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
self.id_by_index[indexlen + i] = chunk.document_id
self.index.add(np.array(embeddings).astype(np.float32))
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
distances, indices = self.index.search(
embedding.reshape(1, -1).astype(np.float32), k
)
chunks = []
scores = []
for d, i in zip(distances[0], indices[0]):
if i < 0:
continue
chunks.append(self.chunk_by_index[int(i)])
scores.append(1.0 / float(d))
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class FaissMemoryImpl(Memory):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.cache = {}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
assert url is None, "URL is not supported for this implementation"
assert (
config.type == MemoryBankType.vector.value
), f"Only vector banks are supported {config.type}"
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION))
self.cache[bank_id] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
index = self.cache.get(bank_id)
if index is None:
return None
return index.bank
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = self.cache.get(bank_id)
if index is None:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id)
if index is None:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)

View file

@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.core.datatypes import * # noqa: F403
EMBEDDING_DEPS = [
"blobfile",
"chardet",
"pypdf",
"sentence-transformers",
]
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.memory,
provider_id="meta-reference-faiss",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.memory.meta_reference.faiss",
config_class="llama_stack.memory.meta_reference.faiss.FaissImplConfig",
),
remote_provider_spec(
Api.memory,
AdapterSpec(
adapter_id="chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
module="llama_stack.memory.adapters.chroma",
),
),
remote_provider_spec(
Api.memory,
AdapterSpec(
adapter_id="pgvector",
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
module="llama_stack.memory.adapters.pgvector",
config_class="llama_stack.memory.adapters.pgvector.PGVectorConfig",
),
),
]

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Tuple
from llama_stack.core.datatypes import Api
async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
from .router import MemoryRouterImpl
impl = MemoryRouterImpl(inner_impls, deps)
await impl.initialize()
return impl

View file

@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Tuple
from llama_stack.core.datatypes import Api
from llama_stack.memory.api import * # noqa: F403
class MemoryRouterImpl(Memory):
"""Routes to an provider based on the memory bank type"""
def __init__(
self,
inner_impls: List[Tuple[str, Any]],
deps: List[Api],
) -> None:
self.deps = deps
bank_types = [v.value for v in MemoryBankType]
self.providers = {}
for routing_key, provider_impl in inner_impls:
if routing_key not in bank_types:
raise ValueError(
f"Unknown routing key `{routing_key}` for memory bank type"
)
self.providers[routing_key] = provider_impl
self.bank_id_to_type = {}
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
for p in self.providers.values():
await p.shutdown()
def get_provider(self, bank_type):
if bank_type not in self.providers:
raise ValueError(f"Memory bank type {bank_type} not supported")
return self.providers[bank_type]
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
provider = self.get_provider(config.type)
bank = await provider.create_memory_bank(name, config, url)
self.bank_id_to_type[bank.bank_id] = config.type
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
provider = self.get_provider(bank_type)
return await provider.get_memory_bank(bank_id)
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
provider = self.get_provider(bank_type)
return await provider.insert_documents(bank_id, documents, ttl_seconds)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
provider = self.get_provider(bank_type)
return await provider.query_documents(bank_id, query, params)