faiss provider implementation

This commit is contained in:
Ashwin Bharambe 2024-08-23 20:58:27 -07:00
parent 14637bea66
commit a08958c000
9 changed files with 401 additions and 3 deletions

View file

@ -623,7 +623,7 @@ class ChatAgent(ShieldRunnerMixin):
)
for a in attachments
]
await self.memory_api.insert_documents(bank_id, documents)
await self.memory_api.insert_documents(bank.bank_id, documents)
assert len(bank_ids) > 0, "No memory banks configured?"

View file

@ -32,6 +32,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
provider_specs={
Api.inference: providers[Api.inference]["meta-reference"],
Api.memory: providers[Api.memory]["meta-reference-faiss"],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
},
@ -50,6 +51,12 @@ def available_distribution_specs() -> List[DistributionSpec]:
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
},
),
DistributionSpec(
spec_id="test-memory",
provider_specs={
Api.memory: providers[Api.memory]["meta-reference-faiss"],
},
),
]

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Protocol
from typing import List, Optional, Protocol
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@ -35,6 +35,7 @@ class VectorMemoryBankConfig(BaseModel):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
class KeyValueMemoryBankConfig(BaseModel):
@ -103,7 +104,7 @@ class Memory(Protocol):
async def list_memory_banks(self) -> List[MemoryBank]: ...
@webmethod(route="/memory_banks/get")
async def get_memory_bank(self, bank_id: str) -> MemoryBank: ...
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
@webmethod(route="/memory_banks/drop", method="DELETE")
async def drop_memory_bank(

View file

@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
# import json
from typing import Dict, List, Optional
import fire
import httpx
# from termcolor import cprint
from .api import * # noqa: F403
async def get_client_impl(base_url: str):
return MemoryClient(base_url)
class MemoryClient(Memory):
def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
async with httpx.AsyncClient() as client:
async with client.get(
f"{self.base_url}/memory_banks/get",
params={
"bank_id": bank_id,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
d = r.json()
if len(d) == 0:
return None
return MemoryBank(**d)
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
async with httpx.AsyncClient() as client:
async with client.post(
f"{self.base_url}/memory_banks/create",
data={
"name": name,
"config": config.dict(),
"url": url,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
d = r.json()
if len(d) == 0:
return None
return MemoryBank(**d)
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None:
async with httpx.AsyncClient() as client:
async with client.post(
f"{self.base_url}/memory_bank/insert",
data={
"bank_id": bank_id,
"documents": documents,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
async with httpx.AsyncClient() as client:
async with client.post(
f"{self.base_url}/memory_bank/query",
data={
"bank_id": bank_id,
"query": query,
"params": params,
},
headers={"Content-Type": "application/json"},
timeout=20,
) as r:
r.raise_for_status()
return QueryDocumentsResponse(**r.json())
async def run_main(host: str, port: int, stream: bool):
client = MemoryClient(f"http://{host}:{port}")
# create a memory bank
bank = await client.create_memory_bank(
name="test_bank",
config=VectorMemoryBankConfig(
bank_id="test_bank",
embedding_model="dragon-roberta-query-2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
)
print(bank)
retrieved_bank = await client.get_memory_bank(bank.bank_id)
assert retrieved_bank is not None
assert retrieved_bank.embedding_model == "dragon-roberta-query-2"
# insert some documents
await client.insert_documents(
bank_id=bank.bank_id,
documents=[
MemoryBankDocument(
document_id="1",
content="hello world",
),
MemoryBankDocument(
document_id="2",
content="goodbye world",
),
],
)
# query the documents
response = await client.query_documents(
bank_id=bank.bank_id,
query=[
"hello world",
],
)
print(response)
def main(host: str, port: int, stream: bool = True):
asyncio.run(run_main(host, port, stream))
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import FaissImplConfig # noqa
from .memory import get_provider_impl # noqa

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class FaissImplConfig(BaseModel): ...

View file

@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Tuple
import faiss
import httpx
import numpy as np
from sentence_transformers import SentenceTransformer
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.memory.api import * # noqa: F403
from .config import FaissImplConfig
async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSpec]):
assert isinstance(
config, FaissImplConfig
), f"Unexpected config type: {type(config)}"
impl = FaissMemoryImpl(config)
await impl.initialize()
return impl
async def content_from_doc(doc: MemoryBankDocument) -> str:
if isinstance(doc.content, URL):
async with httpx.AsyncClient() as client:
return await client.get(doc.content).text
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
if isinstance(doc.content, list):
return " ".join([_process(c) for c in doc.content])
else:
return _process(doc.content)
def make_overlapped_chunks(
text: str, window_len: int, overlap_len: int
) -> List[Tuple[str, int]]:
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
chunks = []
for i in range(0, len(tokens), window_len - overlap_len):
toks = tokens[i : i + window_len]
chunk = tokenizer.decode(toks)
chunks.append((chunk, len(toks)))
return chunks
class BankState(BaseModel):
bank: MemoryBank
index: Optional[faiss.IndexFlatL2] = None
doc_by_id: Dict[str, MemoryBankDocument] = Field(default_factory=dict)
id_by_index: Dict[int, str] = Field(default_factory=dict)
chunk_by_index: Dict[int, str] = Field(default_factory=dict)
async def insert_documents(
self,
model: SentenceTransformer,
documents: List[MemoryBankDocument],
) -> None:
tokenizer = Tokenizer.get_instance()
chunk_size = self.bank.config.chunk_size_in_tokens
for doc in documents:
indexlen = len(self.id_by_index)
self.doc_by_id[doc.document_id] = doc
content = await content_from_doc(doc)
chunks = make_overlapped_chunks(
content,
self.bank.config.chunk_size_in_tokens,
self.bank.config.overlap_size_in_tokens
or (self.bank.config.chunk_size_in_tokens // 4),
)
embeddings = model.encode([x[0] for x in chunks]).astype(np.float32)
await self._ensure_index(embeddings.shape[1])
self.index.add(embeddings)
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = Chunk(
content=chunk[0],
token_count=chunk[1],
)
self.id_by_index[indexlen + i] = doc.document_id
async def query_documents(
self, model: SentenceTransformer, query: str, params: Dict[str, Any]
) -> Tuple[List[Chunk], List[float]]:
k = params.get("max_chunks", 3)
query_vector = model.encode([query])[0]
distances, indices = self.index.search(
query_vector.reshape(1, -1).astype(np.float32), k
)
chunks = [self.chunk_by_index[int(i)] for i in indices[0]]
scores = [1.0 / float(d) for d in distances[0]]
return chunks, scores
async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
if self.index is None:
self.index = faiss.IndexFlatL2(dimension)
return self.index
class FaissMemoryImpl(Memory):
def __init__(self, config: FaissImplConfig) -> None:
self.config = config
self.model = SentenceTransformer("all-MiniLM-L6-v2")
self.states = {}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
assert url is None, "URL is not supported for this implementation"
assert (
config.type == MemoryBankType.vector.value
), f"Only vector banks are supported {config.type}"
id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=id,
name=name,
config=config,
url=url,
)
state = BankState(bank=bank)
self.states[id] = state
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
if bank_id not in self.states:
return None
return self.states[bank_id].bank
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None:
assert bank_id in self.states, f"Bank {bank_id} not found"
state = self.states[bank_id]
await state.insert_documents(self.model, documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
assert bank_id in self.states, f"Bank {bank_id} not found"
state = self.states[bank_id]
chunks, scores = await state.query_documents(self.model, query, params)
return QueryDocumentsResponse(chunk=chunks, scores=scores)

View file

@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_inference_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.memory,
provider_id="meta-reference-faiss",
pip_packages=[
"faiss",
"sentence-transformers",
],
module="llama_toolchain.memory.meta_reference.faiss",
config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
),
]