From a08958c0001355f8756b3d2d9adcdd4f2a5c0baa Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Fri, 23 Aug 2024 20:58:27 -0700
Subject: [PATCH] faiss provider implementation

---
 .../meta_reference/agent_instance.py          |   2 +-
 llama_toolchain/distribution/registry.py      |   7 +
 llama_toolchain/memory/api/endpoints.py       |   5 +-
 llama_toolchain/memory/client.py              | 161 ++++++++++++++++
 .../memory/meta_reference/__init__.py         |   5 +
 .../memory/meta_reference/faiss/__init__.py   |   8 +
 .../memory/meta_reference/faiss/config.py     |  13 ++
 .../memory/meta_reference/faiss/memory.py     | 179 ++++++++++++++++++
 llama_toolchain/memory/providers.py           |  24 +++
 9 files changed, 401 insertions(+), 3 deletions(-)
 create mode 100644 llama_toolchain/memory/client.py
 create mode 100644 llama_toolchain/memory/meta_reference/__init__.py
 create mode 100644 llama_toolchain/memory/meta_reference/faiss/__init__.py
 create mode 100644 llama_toolchain/memory/meta_reference/faiss/config.py
 create mode 100644 llama_toolchain/memory/meta_reference/faiss/memory.py
 create mode 100644 llama_toolchain/memory/providers.py

diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index 0cb5c3a0e..37d05e8a2 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -623,7 +623,7 @@ class ChatAgent(ShieldRunnerMixin):
             )
             for a in attachments
         ]
-        await self.memory_api.insert_documents(bank_id, documents)
+        await self.memory_api.insert_documents(bank.bank_id, documents)
 
         assert len(bank_ids) > 0, "No memory banks configured?"
 

diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py
index b208abf9c..acba4e874 100644
--- a/llama_toolchain/distribution/registry.py
+++ b/llama_toolchain/distribution/registry.py
@@ -32,6 +32,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
             description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
             provider_specs={
                 Api.inference: providers[Api.inference]["meta-reference"],
+                Api.memory: providers[Api.memory]["meta-reference-faiss"],
                 Api.safety: providers[Api.safety]["meta-reference"],
                 Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
             },
@@ -50,6 +51,12 @@ def available_distribution_specs() -> List[DistributionSpec]:
                 Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
             },
         ),
+        DistributionSpec(
+            spec_id="test-memory",
+            provider_specs={
+                Api.memory: providers[Api.memory]["meta-reference-faiss"],
+            },
+        ),
     ]

diff --git a/llama_toolchain/memory/api/endpoints.py b/llama_toolchain/memory/api/endpoints.py
index 615014b55..d4f1d5e20 100644
--- a/llama_toolchain/memory/api/endpoints.py
+++ b/llama_toolchain/memory/api/endpoints.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Protocol
+from typing import List, Optional, Protocol
 
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
@@ -35,6 +35,7 @@ class VectorMemoryBankConfig(BaseModel):
     type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
     embedding_model: str
     chunk_size_in_tokens: int
+    overlap_size_in_tokens: Optional[int] = None
 
 
 class KeyValueMemoryBankConfig(BaseModel):
@@ -103,7 +104,7 @@ class Memory(Protocol):
     async def list_memory_banks(self) -> List[MemoryBank]: ...
 
     @webmethod(route="/memory_banks/get")
-    async def get_memory_bank(self, bank_id: str) -> MemoryBank: ...
+    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
 
     @webmethod(route="/memory_banks/drop", method="DELETE")
     async def drop_memory_bank(

diff --git a/llama_toolchain/memory/client.py b/llama_toolchain/memory/client.py
new file mode 100644
index 000000000..128d7cdd7
--- /dev/null
+++ b/llama_toolchain/memory/client.py
@@ -0,0 +1,161 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+from typing import Any, Dict, List, Optional
+
+import fire
+import httpx
+
+from .api import *  # noqa: F403
+
+
+async def get_client_impl(base_url: str):
+    return MemoryClient(base_url)
+
+
+class MemoryClient(Memory):
+    def __init__(self, base_url: str):
+        print(f"Initializing client for {base_url}")
+        self.base_url = base_url
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
+        async with httpx.AsyncClient() as client:
+            # httpx request methods return a coroutine, not an async context
+            # manager; await the response directly
+            r = await client.get(
+                f"{self.base_url}/memory_banks/get",
+                params={
+                    "bank_id": bank_id,
+                },
+                headers={"Content-Type": "application/json"},
+                timeout=20,
+            )
+            r.raise_for_status()
+            d = r.json()
+            if not d:
+                return None
+            return MemoryBank(**d)
+
+    async def create_memory_bank(
+        self,
+        name: str,
+        config: MemoryBankConfig,
+        url: Optional[URL] = None,
+    ) -> MemoryBank:
+        async with httpx.AsyncClient() as client:
+            # use json= (not data=) so the payload is JSON-serialized
+            r = await client.post(
+                f"{self.base_url}/memory_banks/create",
+                json={
+                    "name": name,
+                    "config": config.dict(),
+                    "url": url,
+                },
+                headers={"Content-Type": "application/json"},
+                timeout=20,
+            )
+            r.raise_for_status()
+            d = r.json()
+            if not d:
+                return None
+            return MemoryBank(**d)
+
+    async def insert_documents(
+        self,
+        bank_id: str,
+        documents: List[MemoryBankDocument],
+    ) -> None:
+        async with httpx.AsyncClient() as client:
+            r = await client.post(
+                f"{self.base_url}/memory_bank/insert",
+                json={
+                    "bank_id": bank_id,
+                    "documents": [d.dict() for d in documents],
+                },
+                headers={"Content-Type": "application/json"},
+                timeout=20,
+            )
+            r.raise_for_status()
+
+    async def query_documents(
+        self,
+        bank_id: str,
+        query: InterleavedTextMedia,
+        params: Optional[Dict[str, Any]] = None,
+    ) -> QueryDocumentsResponse:
+        async with httpx.AsyncClient() as client:
+            r = await client.post(
+                f"{self.base_url}/memory_bank/query",
+                json={
+                    "bank_id": bank_id,
+                    "query": query,
+                    "params": params,
+                },
+                headers={"Content-Type": "application/json"},
+                timeout=20,
+            )
+            r.raise_for_status()
+            return QueryDocumentsResponse(**r.json())
+
+
+async def run_main(host: str, port: int, stream: bool):
+    client = MemoryClient(f"http://{host}:{port}")
+
+    # create a memory bank
+    bank = await client.create_memory_bank(
+        name="test_bank",
+        config=VectorMemoryBankConfig(
+            embedding_model="dragon-roberta-query-2",
+            chunk_size_in_tokens=512,
+            overlap_size_in_tokens=64,
+        ),
+    )
+    print(bank)
+
+    retrieved_bank = await client.get_memory_bank(bank.bank_id)
+    assert retrieved_bank is not None
+    assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
+
+    # insert some documents
+    await client.insert_documents(
+        bank_id=bank.bank_id,
+        documents=[
+            MemoryBankDocument(
+                document_id="1",
+                content="hello world",
+            ),
+            MemoryBankDocument(
+                document_id="2",
+                content="goodbye world",
+            ),
+        ],
+    )
+
+    # query the documents
+    response = await client.query_documents(
+        bank_id=bank.bank_id,
+        query=[
+            "hello world",
+        ],
+    )
+    print(response)
+
+
+def main(host: str, port: int, stream: bool = True):
+    asyncio.run(run_main(host, port, stream))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)

diff --git a/llama_toolchain/memory/meta_reference/__init__.py b/llama_toolchain/memory/meta_reference/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_toolchain/memory/meta_reference/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

diff --git a/llama_toolchain/memory/meta_reference/faiss/__init__.py b/llama_toolchain/memory/meta_reference/faiss/__init__.py
new file mode 100644
index 000000000..dda96f370
--- /dev/null
+++ b/llama_toolchain/memory/meta_reference/faiss/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import FaissImplConfig  # noqa
+from .memory import get_provider_impl  # noqa

diff --git a/llama_toolchain/memory/meta_reference/faiss/config.py b/llama_toolchain/memory/meta_reference/faiss/config.py
new file mode 100644
index 000000000..b1c94c889
--- /dev/null
+++ b/llama_toolchain/memory/meta_reference/faiss/config.py
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_models.schema_utils import json_schema_type
+
+from pydantic import BaseModel
+
+
+@json_schema_type
+class FaissImplConfig(BaseModel): ...

diff --git a/llama_toolchain/memory/meta_reference/faiss/memory.py b/llama_toolchain/memory/meta_reference/faiss/memory.py
new file mode 100644
index 000000000..2322b8519
--- /dev/null
+++ b/llama_toolchain/memory/meta_reference/faiss/memory.py
@@ -0,0 +1,179 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import uuid
+from typing import Any, Dict, List, Optional, Tuple
+
+import faiss
+import httpx
+import numpy as np
+from pydantic import BaseModel, Field
+from sentence_transformers import SentenceTransformer
+
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.tokenizer import Tokenizer
+
+from llama_toolchain.distribution.datatypes import Api, ProviderSpec
+from llama_toolchain.memory.api import *  # noqa: F403
+
+from .config import FaissImplConfig
+
+
+async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSpec]):
+    assert isinstance(
+        config, FaissImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = FaissMemoryImpl(config)
+    await impl.initialize()
+    return impl
+
+
+async def content_from_doc(doc: MemoryBankDocument) -> str:
+    if isinstance(doc.content, URL):
+        async with httpx.AsyncClient() as client:
+            # the request must be awaited before the response text is read
+            r = await client.get(doc.content.uri)
+            return r.text
+
+    def _process(c):
+        if isinstance(c, str):
+            return c
+        else:
+            return ""
+
+    if isinstance(doc.content, list):
+        return " ".join([_process(c) for c in doc.content])
+    else:
+        return _process(doc.content)
+
+
+def make_overlapped_chunks(
+    text: str, window_len: int, overlap_len: int
+) -> List[Tuple[str, int]]:
+    tokenizer = Tokenizer.get_instance()
+    tokens = tokenizer.encode(text, bos=False, eos=False)
+
+    chunks = []
+    # stride by (window_len - overlap_len) so consecutive chunks share
+    # overlap_len tokens
+    for i in range(0, len(tokens), window_len - overlap_len):
+        toks = tokens[i : i + window_len]
+        chunk = tokenizer.decode(toks)
+        chunks.append((chunk, len(toks)))
+
+    return chunks
+
+
+class BankState(BaseModel):
+    bank: MemoryBank
+    index: Optional[faiss.IndexFlatL2] = None
+    doc_by_id: Dict[str, MemoryBankDocument] = Field(default_factory=dict)
+    id_by_index: Dict[int, str] = Field(default_factory=dict)
+    chunk_by_index: Dict[int, Chunk] = Field(default_factory=dict)
+
+    # faiss.IndexFlatL2 is not a pydantic-native type
+    class Config:
+        arbitrary_types_allowed = True
+
+    async def insert_documents(
+        self,
+        model: SentenceTransformer,
+        documents: List[MemoryBankDocument],
+    ) -> None:
+        for doc in documents:
+            indexlen = len(self.id_by_index)
+            self.doc_by_id[doc.document_id] = doc
+
+            content = await content_from_doc(doc)
+            chunks = make_overlapped_chunks(
+                content,
+                self.bank.config.chunk_size_in_tokens,
+                self.bank.config.overlap_size_in_tokens
+                or (self.bank.config.chunk_size_in_tokens // 4),
+            )
+            embeddings = model.encode([x[0] for x in chunks]).astype(np.float32)
+            await self._ensure_index(embeddings.shape[1])
+
+            self.index.add(embeddings)
+            for i, chunk in enumerate(chunks):
+                self.chunk_by_index[indexlen + i] = Chunk(
+                    content=chunk[0],
+                    token_count=chunk[1],
+                )
+                self.id_by_index[indexlen + i] = doc.document_id
+
+    async def query_documents(
+        self, model: SentenceTransformer, query: str, params: Dict[str, Any]
+    ) -> Tuple[List[Chunk], List[float]]:
+        if self.index is None:
+            # nothing has been inserted into this bank yet
+            return [], []
+
+        k = params.get("max_chunks", 3)
+        query_vector = model.encode([query])[0]
+        distances, indices = self.index.search(
+            query_vector.reshape(1, -1).astype(np.float32), k
+        )
+
+        chunks = [self.chunk_by_index[int(i)] for i in indices[0]]
+        # convert L2 distance to a similarity-style score; guard against a
+        # zero distance from an exact match
+        scores = [1.0 / max(float(d), 1e-6) for d in distances[0]]
+
+        return chunks, scores
+
+    async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
+        if self.index is None:
+            self.index = faiss.IndexFlatL2(dimension)
+        return self.index
+
+
+class FaissMemoryImpl(Memory):
+    def __init__(self, config: FaissImplConfig) -> None:
+        self.config = config
+        self.model = SentenceTransformer("all-MiniLM-L6-v2")
+        self.states: Dict[str, BankState] = {}
+
+    async def initialize(self) -> None: ...
+
+    async def shutdown(self) -> None: ...
+
+    async def create_memory_bank(
+        self,
+        name: str,
+        config: MemoryBankConfig,
+        url: Optional[URL] = None,
+    ) -> MemoryBank:
+        assert url is None, "URL is not supported for this implementation"
+        assert (
+            config.type == MemoryBankType.vector.value
+        ), f"Only vector banks are supported, got {config.type}"
+
+        bank_id = str(uuid.uuid4())
+        bank = MemoryBank(
+            bank_id=bank_id,
+            name=name,
+            config=config,
+            url=url,
+        )
+        state = BankState(bank=bank)
+        self.states[bank_id] = state
+        return bank
+
+    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
+        if bank_id not in self.states:
+            return None
+        return self.states[bank_id].bank
+
+    async def insert_documents(
+        self,
+        bank_id: str,
+        documents: List[MemoryBankDocument],
+    ) -> None:
+        assert bank_id in self.states, f"Bank {bank_id} not found"
+        state = self.states[bank_id]
+
+        await state.insert_documents(self.model, documents)
+
+    async def query_documents(
+        self,
+        bank_id: str,
+        query: InterleavedTextMedia,
+        params: Optional[Dict[str, Any]] = None,
+    ) -> QueryDocumentsResponse:
+        assert bank_id in self.states, f"Bank {bank_id} not found"
+        state = self.states[bank_id]
+
+        # InterleavedTextMedia may be a list of fragments; flatten the text
+        # pieces into a single string before embedding
+        if isinstance(query, list):
+            query = " ".join(q for q in query if isinstance(q, str))
+
+        chunks, scores = await state.query_documents(self.model, query, params or {})
+        return QueryDocumentsResponse(chunks=chunks, scores=scores)

diff --git a/llama_toolchain/memory/providers.py b/llama_toolchain/memory/providers.py
new file mode 100644
index 000000000..bfa098d36
--- /dev/null
+++ b/llama_toolchain/memory/providers.py
@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
+
+
+def available_memory_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.memory,
+            provider_id="meta-reference-faiss",
+            pip_packages=[
+                "faiss-cpu",
+                "sentence-transformers",
+            ],
+            module="llama_toolchain.memory.meta_reference.faiss",
+            config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
+        ),
+    ]
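
Note (editor, not part of the patch): make_overlapped_chunks in memory.py walks
the token stream with a stride of window_len - overlap_len, so each chunk
shares its last overlap_len tokens with the start of the next chunk. A minimal
sketch of that arithmetic, using a whitespace split as a stand-in tokenizer so
it runs without the llama3 Tokenizer (all names below are illustrative):

    # stride = window_len - overlap_len; consecutive chunks overlap
    tokens = "the quick brown fox jumps over the lazy dog once more".split()
    window_len, overlap_len = 6, 2
    chunks = []
    for i in range(0, len(tokens), window_len - overlap_len):
        toks = tokens[i : i + window_len]
        chunks.append((" ".join(toks), len(toks)))
    # with 11 tokens this yields windows starting at offsets 0, 4, 8; each
    # window repeats the last overlap_len=2 tokens of the previous one
    print(chunks)

The same stride logic explains the default in BankState.insert_documents:
when overlap_size_in_tokens is unset, the provider falls back to a quarter of
chunk_size_in_tokens, i.e. a 512-token window advances 384 tokens per chunk.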