mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
faiss provider implementation
This commit is contained in:
parent
14637bea66
commit
a08958c000
9 changed files with 401 additions and 3 deletions
|
@ -623,7 +623,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
for a in attachments
|
for a in attachments
|
||||||
]
|
]
|
||||||
await self.memory_api.insert_documents(bank_id, documents)
|
await self.memory_api.insert_documents(bank.bank_id, documents)
|
||||||
|
|
||||||
assert len(bank_ids) > 0, "No memory banks configured?"
|
assert len(bank_ids) > 0, "No memory banks configured?"
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
||||||
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
|
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
|
||||||
provider_specs={
|
provider_specs={
|
||||||
Api.inference: providers[Api.inference]["meta-reference"],
|
Api.inference: providers[Api.inference]["meta-reference"],
|
||||||
|
Api.memory: providers[Api.memory]["meta-reference-faiss"],
|
||||||
Api.safety: providers[Api.safety]["meta-reference"],
|
Api.safety: providers[Api.safety]["meta-reference"],
|
||||||
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
||||||
},
|
},
|
||||||
|
@ -50,6 +51,12 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
||||||
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
DistributionSpec(
|
||||||
|
spec_id="test-memory",
|
||||||
|
provider_specs={
|
||||||
|
Api.memory: providers[Api.memory]["meta-reference-faiss"],
|
||||||
|
},
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Protocol
|
from typing import List, Optional, Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
@ -35,6 +35,7 @@ class VectorMemoryBankConfig(BaseModel):
|
||||||
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||||
embedding_model: str
|
embedding_model: str
|
||||||
chunk_size_in_tokens: int
|
chunk_size_in_tokens: int
|
||||||
|
overlap_size_in_tokens: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class KeyValueMemoryBankConfig(BaseModel):
|
class KeyValueMemoryBankConfig(BaseModel):
|
||||||
|
@ -103,7 +104,7 @@ class Memory(Protocol):
|
||||||
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/get")
|
@webmethod(route="/memory_banks/get")
|
||||||
async def get_memory_bank(self, bank_id: str) -> MemoryBank: ...
|
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/drop", method="DELETE")
|
@webmethod(route="/memory_banks/drop", method="DELETE")
|
||||||
async def drop_memory_bank(
|
async def drop_memory_bank(
|
||||||
|
|
161
llama_toolchain/memory/client.py
Normal file
161
llama_toolchain/memory/client.py
Normal file
|
@ -0,0 +1,161 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# import json
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
# from termcolor import cprint
|
||||||
|
|
||||||
|
from .api import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
async def get_client_impl(base_url: str):
    """Build the HTTP-backed memory client that talks to ``base_url``."""
    memory_client = MemoryClient(base_url)
    return memory_client
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryClient(Memory):
    """HTTP client that implements the ``Memory`` API against a remote server.

    Every call opens a short-lived ``httpx.AsyncClient``, issues one request
    below ``base_url`` and rebuilds the API datatypes from the JSON reply.
    """

    def __init__(self, base_url: str):
        print(f"Initializing client for {base_url}")
        self.base_url = base_url

    async def initialize(self) -> None:
        """No-op; the client keeps no state that needs setting up."""

    async def shutdown(self) -> None:
        """No-op; each request closes its own HTTP connection."""

    @staticmethod
    def _jsonable(value):
        """Recursively convert models (anything with ``.dict()``) to plain data.

        ``httpx`` can only JSON-encode builtin containers and scalars, so
        request payloads are flattened before they are sent.
        """
        if callable(getattr(value, "dict", None)):
            return MemoryClient._jsonable(value.dict())
        if isinstance(value, dict):
            return {k: MemoryClient._jsonable(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [MemoryClient._jsonable(v) for v in value]
        return value

    async def _request(self, method: str, route: str, **kwargs):
        """Send one request and return the response.

        Raises:
            httpx.HTTPStatusError: if the server answers with a 4xx/5xx status.
        """
        async with httpx.AsyncClient() as client:
            # `client.request` returns a coroutine; it must be awaited rather
            # than used as an async context manager.
            response = await client.request(
                method,
                f"{self.base_url}{route}",
                headers={"Content-Type": "application/json"},
                timeout=20,
                **kwargs,
            )
            response.raise_for_status()
            return response

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        """Fetch the bank with ``bank_id``; return ``None`` for an empty reply."""
        response = await self._request(
            "GET",
            "/memory_banks/get",
            params={"bank_id": bank_id},
        )
        payload = response.json()
        # `not payload` also covers a JSON `null` body, where len() would fail.
        if not payload:
            return None
        return MemoryBank(**payload)

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        """Create a memory bank on the server and return its description.

        Returns ``None`` when the server replies with an empty body, as the
        previous implementation did.
        """
        response = await self._request(
            "POST",
            "/memory_banks/create",
            # `json=` sends a JSON body; `data=` would form-encode the payload.
            json=self._jsonable(
                {
                    "name": name,
                    "config": config,
                    "url": url,
                }
            ),
        )
        payload = response.json()
        if not payload:
            return None
        return MemoryBank(**payload)

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
    ) -> None:
        """Upload ``documents`` into the bank identified by ``bank_id``."""
        await self._request(
            "POST",
            "/memory_bank/insert",
            json=self._jsonable(
                {
                    "bank_id": bank_id,
                    "documents": documents,
                }
            ),
        )

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        """Run ``query`` against the bank and return the server's response."""
        response = await self._request(
            "POST",
            "/memory_bank/query",
            json=self._jsonable(
                {
                    "bank_id": bank_id,
                    "query": query,
                    "params": params,
                }
            ),
        )
        return QueryDocumentsResponse(**response.json())
|
||||||
|
|
||||||
|
|
||||||
|
async def run_main(host: str, port: int, stream: bool):
    """Smoke-test a running memory server: create a bank, insert, then query.

    Args:
        host: Host name of the server.
        port: Port the server listens on.
        stream: Accepted for CLI compatibility; not used by this script.
    """
    client = MemoryClient(f"http://{host}:{port}")

    # create a memory bank
    bank = await client.create_memory_bank(
        name="test_bank",
        config=VectorMemoryBankConfig(
            bank_id="test_bank",
            embedding_model="dragon-roberta-query-2",
            chunk_size_in_tokens=512,
            overlap_size_in_tokens=64,
        ),
    )
    print(bank)
    # The client returns None for an empty reply; fail here with a clear
    # assertion instead of an AttributeError on `bank.bank_id` below.
    assert bank is not None

    retrieved_bank = await client.get_memory_bank(bank.bank_id)
    assert retrieved_bank is not None
    # The embedding model is part of the bank's config, not of the bank itself.
    assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"

    # insert some documents
    await client.insert_documents(
        bank_id=bank.bank_id,
        documents=[
            MemoryBankDocument(
                document_id="1",
                content="hello world",
            ),
            MemoryBankDocument(
                document_id="2",
                content="goodbye world",
            ),
        ],
    )

    # query the documents
    response = await client.query_documents(
        bank_id=bank.bank_id,
        query=[
            "hello world",
        ],
    )
    print(response)
|
||||||
|
|
||||||
|
|
||||||
|
def main(host: str, port: int, stream: bool = True):
    """CLI entry point: drive ``run_main`` to completion on a fresh event loop."""
    coroutine = run_main(host, port, stream)
    asyncio.run(coroutine)


# Expose `main` as a command-line interface when executed as a script.
if __name__ == "__main__":
    fire.Fire(main)
|
5
llama_toolchain/memory/meta_reference/__init__.py
Normal file
5
llama_toolchain/memory/meta_reference/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
8
llama_toolchain/memory/meta_reference/faiss/__init__.py
Normal file
8
llama_toolchain/memory/meta_reference/faiss/__init__.py
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .config import FaissImplConfig # noqa
|
||||||
|
from .memory import get_provider_impl # noqa
|
13
llama_toolchain/memory/meta_reference/faiss/config.py
Normal file
13
llama_toolchain/memory/meta_reference/faiss/config.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
# Configuration model for the FAISS memory provider.
# It declares no fields of its own.
@json_schema_type
class FaissImplConfig(BaseModel):
    pass
|
179
llama_toolchain/memory/meta_reference/faiss/memory.py
Normal file
179
llama_toolchain/memory/meta_reference/faiss/memory.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import faiss
|
||||||
|
import httpx
|
||||||
|
import numpy as np
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
|
||||||
|
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||||
|
from llama_toolchain.memory.api import * # noqa: F403
|
||||||
|
from .config import FaissImplConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSpec]):
    """Create the FAISS memory provider for ``config`` and initialize it.

    ``_deps`` is accepted but not used by this provider.
    """
    assert isinstance(
        config, FaissImplConfig
    ), f"Unexpected config type: {type(config)}"

    provider = FaissMemoryImpl(config)
    await provider.initialize()
    return provider
|
||||||
|
|
||||||
|
|
||||||
|
async def content_from_doc(doc: MemoryBankDocument) -> str:
    """Return the text of ``doc``, downloading it when the content is a URL.

    String content is returned as is. List content is joined with single
    spaces, and every non-string item is replaced by the ``"<media>"``
    placeholder.

    Raises:
        httpx.HTTPStatusError: if fetching URL content yields a 4xx/5xx status.
    """
    content = doc.content

    if isinstance(content, URL):
        async with httpx.AsyncClient() as client:
            # Await the request first: `.text` lives on the response, not on
            # the coroutine returned by `client.get`.
            # NOTE(review): the URL object is handed to httpx unchanged, as
            # before - confirm httpx accepts this URL type.
            response = await client.get(content)
            # Do not index an error page as if it were the document body.
            response.raise_for_status()
            return response.text

    def _as_text(item) -> str:
        return item if isinstance(item, str) else "<media>"

    if isinstance(content, list):
        return " ".join(_as_text(item) for item in content)
    return _as_text(content)
|
||||||
|
|
||||||
|
|
||||||
|
def make_overlapped_chunks(
    text: str, window_len: int, overlap_len: int
) -> List[Tuple[str, int]]:
    """Split ``text`` into token windows that overlap by ``overlap_len`` tokens.

    Args:
        text: Text to tokenize and split.
        window_len: Maximum number of tokens per chunk; must be positive.
        overlap_len: Number of tokens shared by consecutive chunks; must
            satisfy ``0 <= overlap_len < window_len``.

    Returns:
        ``(chunk_text, token_count)`` pairs in document order.

    Raises:
        ValueError: if the window/overlap combination would not advance
            through the text.
    """
    # Validate before tokenizing: a non-positive stride would either make
    # range() fail with an unrelated message or silently yield no chunks.
    if window_len <= 0 or not 0 <= overlap_len < window_len:
        raise ValueError(
            "expected window_len > 0 and 0 <= overlap_len < window_len, "
            f"got window_len={window_len}, overlap_len={overlap_len}"
        )
    stride = window_len - overlap_len

    tokenizer = Tokenizer.get_instance()
    tokens = tokenizer.encode(text, bos=False, eos=False)

    chunks = []
    for start in range(0, len(tokens), stride):
        window = tokens[start : start + window_len]
        chunks.append((tokenizer.decode(window), len(window)))

    return chunks
|
||||||
|
|
||||||
|
|
||||||
|
class BankState(BaseModel):
    """In-memory state of one vector memory bank.

    Holds the FAISS index together with the maps that translate a row of the
    index back to the chunk stored there and to the document it came from.
    """

    class Config:
        # faiss.IndexFlatL2 is not a type pydantic knows how to validate;
        # arbitrary types must be allowed for the `index` field.
        arbitrary_types_allowed = True

    bank: MemoryBank
    # Created on first insert, once the embedding dimension is known.
    index: Optional[faiss.IndexFlatL2] = None
    doc_by_id: Dict[str, MemoryBankDocument] = Field(default_factory=dict)
    # Index row -> id of the document the row's chunk was cut from.
    id_by_index: Dict[int, str] = Field(default_factory=dict)
    # Index row -> the chunk embedded in that row.
    chunk_by_index: Dict[int, Chunk] = Field(default_factory=dict)

    async def insert_documents(
        self,
        model: SentenceTransformer,
        documents: List[MemoryBankDocument],
    ) -> None:
        """Chunk, embed and index ``documents``.

        Args:
            model: Encoder used to embed each chunk.
            documents: Documents to add to this bank.
        """
        window_len = self.bank.config.chunk_size_in_tokens
        # Fall back to a quarter-window overlap when none is configured.
        overlap_len = self.bank.config.overlap_size_in_tokens or (window_len // 4)

        for doc in documents:
            first_row = len(self.id_by_index)
            self.doc_by_id[doc.document_id] = doc

            content = await content_from_doc(doc)
            chunks = make_overlapped_chunks(content, window_len, overlap_len)
            if not chunks:
                # Nothing to embed (e.g. empty document); encoding an empty
                # batch has no second dimension to size the index with.
                continue

            embeddings = model.encode([text for text, _ in chunks]).astype(
                np.float32
            )
            await self._ensure_index(embeddings.shape[1])

            self.index.add(embeddings)
            for offset, (text, token_count) in enumerate(chunks):
                self.chunk_by_index[first_row + offset] = Chunk(
                    content=text,
                    token_count=token_count,
                )
                self.id_by_index[first_row + offset] = doc.document_id

    async def query_documents(
        self,
        model: SentenceTransformer,
        query: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[Chunk], List[float]]:
        """Return the chunks nearest to ``query`` and their scores.

        Args:
            model: Encoder used to embed the query.
            query: Query text.
            params: Optional settings; ``"max_chunks"`` (default 3) limits the
                number of results.

        Returns:
            ``(chunks, scores)`` of equal length, where each score is the
            inverse of the distance reported by the index. Both lists are
            empty when nothing has been inserted yet.
        """
        if self.index is None:
            return [], []

        k = (params or {}).get("max_chunks", 3)
        query_vector = model.encode([query])[0]
        distances, indices = self.index.search(
            query_vector.reshape(1, -1).astype(np.float32), k
        )

        # Clamp the distance so an exact match (distance 0) yields a large
        # finite score instead of a ZeroDivisionError.
        min_distance = float(np.finfo(np.float32).eps)

        chunks = []
        scores = []
        for distance, row in zip(distances[0], indices[0]):
            # FAISS pads the result with -1 when fewer than k vectors exist.
            if row < 0:
                continue
            chunks.append(self.chunk_by_index[int(row)])
            scores.append(1.0 / max(float(distance), min_distance))

        return chunks, scores

    async def _ensure_index(self, dimension: int) -> faiss.IndexFlatL2:
        """Create the index for ``dimension``-sized vectors on first use."""
        if self.index is None:
            self.index = faiss.IndexFlatL2(dimension)
        return self.index
|
||||||
|
|
||||||
|
|
||||||
|
class FaissMemoryImpl(Memory):
    """``Memory`` provider that keeps every bank in process memory with FAISS."""

    def __init__(self, config: FaissImplConfig) -> None:
        self.config = config
        # NOTE(review): one fixed encoder is used for every bank, whatever
        # `embedding_model` the bank's config names - confirm this is intended.
        self.model = SentenceTransformer("all-MiniLM-L6-v2")
        # bank_id -> state (index + bookkeeping) of that bank.
        self.states: Dict[str, BankState] = {}

    async def initialize(self) -> None: ...

    async def shutdown(self) -> None: ...

    @staticmethod
    def _query_text(query) -> str:
        """Flatten a query into one string.

        Strings pass through unchanged. List items are joined with spaces and
        non-string items become ``"<media>"``, the same convention
        ``content_from_doc`` applies to document content.
        """

        def _as_text(item) -> str:
            return item if isinstance(item, str) else "<media>"

        if isinstance(query, list):
            return " ".join(_as_text(item) for item in query)
        return _as_text(query)

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        """Register a new, empty vector bank under a freshly generated id.

        Raises:
            AssertionError: if ``url`` is given or ``config`` is not a vector
                bank config.
        """
        # Imported locally because the module's import block does not bring
        # `uuid` into scope.
        import uuid

        assert url is None, "URL is not supported for this implementation"
        assert (
            config.type == MemoryBankType.vector.value
        ), f"Only vector banks are supported {config.type}"

        bank_id = str(uuid.uuid4())
        bank = MemoryBank(
            bank_id=bank_id,
            name=name,
            config=config,
            url=url,
        )
        self.states[bank_id] = BankState(bank=bank)
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        """Return the bank registered as ``bank_id``, or ``None`` if unknown."""
        state = self.states.get(bank_id)
        if state is None:
            return None
        return state.bank

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
    ) -> None:
        """Index ``documents`` into the bank ``bank_id``.

        Raises:
            AssertionError: if the bank does not exist.
        """
        assert bank_id in self.states, f"Bank {bank_id} not found"
        state = self.states[bank_id]

        await state.insert_documents(self.model, documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        """Search the bank ``bank_id`` for the chunks closest to ``query``.

        Raises:
            AssertionError: if the bank does not exist.
        """
        assert bank_id in self.states, f"Bank {bank_id} not found"
        state = self.states[bank_id]

        # The bank state embeds a single string, so flatten list-form queries.
        chunks, scores = await state.query_documents(
            self.model, self._query_text(query), params
        )
        # NOTE(review): the keyword is `chunk`, kept as in the original -
        # verify it matches the field name declared on QueryDocumentsResponse.
        return QueryDocumentsResponse(chunk=chunks, scores=scores)
|
24
llama_toolchain/memory/providers.py
Normal file
24
llama_toolchain/memory/providers.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
|
def available_inference_providers() -> List[ProviderSpec]:
    """Return the provider specs this module offers for the memory API.

    NOTE(review): despite the name, every spec returned here is registered
    under ``Api.memory``; the name is kept so existing callers keep working.
    """
    return [
        InlineProviderSpec(
            api=Api.memory,
            provider_id="meta-reference-faiss",
            pip_packages=[
                # The maintained FAISS wheels are published as `faiss-cpu`.
                "faiss-cpu",
                "sentence-transformers",
            ],
            module="llama_toolchain.memory.meta_reference.faiss",
            config_class="llama_toolchain.memory.meta_reference.faiss.FaissImplConfig",
        ),
    ]
|
Loading…
Add table
Add a link
Reference in a new issue