[memory refactor][2/n] Update faiss and make it pass tests (#830)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

Second part:

- updates the routing table / router code (the dispatch shape is sketched below)
- updates the faiss implementation
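
For orientation: after this refactor the router's job is to resolve a `vector_db_id` to the provider that owns it and forward `insert_chunks` / `query_chunks`. A minimal sketch of that dispatch shape, assuming a `routing_table.get_provider_impl` lookup (illustrative names, not the exact router code in this PR):

```python
from typing import Any, Dict, List, Optional

from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO


class VectorIORouter(VectorIO):
    """Forwards vector_io calls to whichever provider owns the vector DB."""

    def __init__(self, routing_table) -> None:
        # routing_table resolves a vector_db_id to a provider implementation;
        # get_provider_impl is an assumed name, used here for illustration.
        self.routing_table = routing_table

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: List[Chunk],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        provider = self.routing_table.get_provider_impl(vector_db_id)
        return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryChunksResponse:
        provider = self.routing_table.get_provider_impl(vector_db_id)
        return await provider.query_chunks(vector_db_id, query, params)
```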


## Test Plan

```
pytest -s -v -k sentence test_vector_io.py --env EMBEDDING_DIMENSION=384
```
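
`EMBEDDING_DIMENSION=384` matches the output width of the small sentence-transformers embedding models (e.g. `all-MiniLM-L6-v2`) that the `-k sentence` filter presumably selects.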
Author: Ashwin Bharambe (committed via GitHub)
Date: 2025-01-22 10:02:15 -08:00
Commit: 78a481bb22 (parent: 3ae8585b65)
19 changed files with 343 additions and 353 deletions

Diff of the inline faiss provider:

```diff
@@ -17,35 +17,28 @@ import numpy as np
 from numpy.typing import NDArray
 
 from llama_stack.apis.inference import InterleavedContent
-from llama_stack.apis.memory import (
-    Chunk,
-    Memory,
-    MemoryBankDocument,
-    QueryDocumentsResponse,
-)
-from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
-from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
+from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.vector_store import (
-    BankWithIndex,
     EmbeddingIndex,
+    VectorDBWithIndex,
 )
 
 from .config import FaissImplConfig
 
 logger = logging.getLogger(__name__)
 
-MEMORY_BANKS_PREFIX = "memory_banks:v2::"
+VECTOR_DBS_PREFIX = "vector_dbs:v2::"
 FAISS_INDEX_PREFIX = "faiss_index:v2::"
 
 
 class FaissIndex(EmbeddingIndex):
-    id_by_index: Dict[int, str]
     chunk_by_index: Dict[int, str]
 
     def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
         self.index = faiss.IndexFlatL2(dimension)
-        self.id_by_index = {}
         self.chunk_by_index = {}
         self.kvstore = kvstore
         self.bank_id = bank_id
```
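
Note that the `id_by_index` bookkeeping is dropped entirely: `chunk_by_index` becomes the single map maintained by the index, which loses nothing because each stored `Chunk` already carries its own `document_id` (visible in the removed `chunk.document_id` line in the add path below).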
```diff
@@ -65,7 +58,6 @@ class FaissIndex(EmbeddingIndex):
 
         if stored_data:
             data = json.loads(stored_data)
-            self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()}
             self.chunk_by_index = {
                 int(k): Chunk.model_validate_json(v)
                 for k, v in data["chunk_by_index"].items()
@@ -82,7 +74,6 @@ class FaissIndex(EmbeddingIndex):
         buffer = io.BytesIO()
         np.savetxt(buffer, np_index)
         data = {
-            "id_by_index": self.id_by_index,
             "chunk_by_index": {
                 k: v.model_dump_json() for k, v in self.chunk_by_index.items()
             },
```
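
The save path above writes the index through a text buffer with `np.savetxt`. Assuming `np_index` comes from `faiss.serialize_index` (the standard way to obtain a faiss index as a NumPy `uint8` array; that call sits outside this hunk), the round trip looks like this:

```python
import io

import faiss  # pip install faiss-cpu
import numpy as np

index = faiss.IndexFlatL2(8)
index.add(np.random.rand(4, 8).astype(np.float32))

# Save: faiss index -> uint8 array -> text buffer, as in the hunk above.
np_index = faiss.serialize_index(index)
buffer = io.BytesIO()
np.savetxt(buffer, np_index)

# Load: text buffer -> uint8 array -> faiss index.
buffer.seek(0)
restored = faiss.deserialize_index(np.loadtxt(buffer, dtype=np.uint8))
assert restored.ntotal == index.ntotal  # both hold 4 vectors
```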
```diff
@@ -108,10 +99,9 @@ class FaissIndex(EmbeddingIndex):
             f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
         )
 
-        indexlen = len(self.id_by_index)
+        indexlen = len(self.chunk_by_index)
         for i, chunk in enumerate(chunks):
             self.chunk_by_index[indexlen + i] = chunk
-            self.id_by_index[indexlen + i] = chunk.document_id
 
         self.index.add(np.array(embeddings).astype(np.float32))
 
@@ -120,7 +110,7 @@ class FaissIndex(EmbeddingIndex):
 
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryDocumentsResponse:
+    ) -> QueryChunksResponse:
         distances, indices = self.index.search(
             embedding.reshape(1, -1).astype(np.float32), k
         )
@@ -133,10 +123,10 @@ class FaissIndex(EmbeddingIndex):
             chunks.append(self.chunk_by_index[int(i)])
             scores.append(1.0 / float(d))
 
-        return QueryDocumentsResponse(chunks=chunks, scores=scores)
+        return QueryChunksResponse(chunks=chunks, scores=scores)
 
 
-class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
+class FaissVectorIOImpl(VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.inference_api = inference_api
@@ -146,77 +136,74 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)
-        # Load existing banks from kvstore
-        start_key = MEMORY_BANKS_PREFIX
-        end_key = f"{MEMORY_BANKS_PREFIX}\xff"
-        stored_banks = await self.kvstore.range(start_key, end_key)
+        start_key = VECTOR_DBS_PREFIX
+        end_key = f"{VECTOR_DBS_PREFIX}\xff"
+        stored_vector_dbs = await self.kvstore.range(start_key, end_key)
 
-        for bank_data in stored_banks:
-            bank = VectorMemoryBank.model_validate_json(bank_data)
-            index = BankWithIndex(
-                bank,
+        for vector_db_data in stored_vector_dbs:
+            vector_db = VectorDB.model_validate_json(vector_db_data)
+            index = VectorDBWithIndex(
+                vector_db,
                 await FaissIndex.create(
-                    bank.embedding_dimension, self.kvstore, bank.identifier
+                    vector_db.embedding_dimension, self.kvstore, vector_db.identifier
                 ),
                 self.inference_api,
             )
-            self.cache[bank.identifier] = index
+            self.cache[vector_db.identifier] = index
 
     async def shutdown(self) -> None:
         # Cleanup if needed
         pass
 
-    async def register_memory_bank(
+    async def register_vector_db(
         self,
-        memory_bank: MemoryBank,
+        vector_db: VectorDB,
     ) -> None:
-        assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector.value
-        ), f"Only vector banks are supported {memory_bank.type}"
-
         # Store in kvstore
-        key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
+        key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
         await self.kvstore.set(
             key=key,
-            value=memory_bank.model_dump_json(),
+            value=vector_db.model_dump_json(),
         )
 
         # Store in cache
-        self.cache[memory_bank.identifier] = BankWithIndex(
-            memory_bank,
-            await FaissIndex.create(
-                memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
+        self.cache[vector_db.identifier] = VectorDBWithIndex(
+            vector_db=vector_db,
+            index=await FaissIndex.create(
+                vector_db.embedding_dimension, self.kvstore, vector_db.identifier
            ),
-            self.inference_api,
+            inference_api=self.inference_api,
        )
 
-    async def list_memory_banks(self) -> List[MemoryBank]:
-        return [i.bank for i in self.cache.values()]
+    async def list_vector_dbs(self) -> List[VectorDB]:
+        return [i.vector_db for i in self.cache.values()]
 
-    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
-        await self.cache[memory_bank_id].index.delete()
-        del self.cache[memory_bank_id]
-        await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        await self.cache[vector_db_id].index.delete()
+        del self.cache[vector_db_id]
+        await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
 
-    async def insert_documents(
+    async def insert_chunks(
         self,
-        bank_id: str,
-        documents: List[MemoryBankDocument],
+        vector_db_id: str,
+        chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        index = self.cache.get(bank_id)
+        index = self.cache.get(vector_db_id)
         if index is None:
-            raise ValueError(f"Bank {bank_id} not found. found: {self.cache.keys()}")
+            raise ValueError(
+                f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}"
+            )
 
-        await index.insert_documents(documents)
+        await index.insert_chunks(chunks)
 
-    async def query_documents(
+    async def query_chunks(
         self,
-        bank_id: str,
+        vector_db_id: str,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
-        index = self.cache.get(bank_id)
+    ) -> QueryChunksResponse:
+        index = self.cache.get(vector_db_id)
         if index is None:
-            raise ValueError(f"Bank {bank_id} not found")
+            raise ValueError(f"Vector DB {vector_db_id} not found")
 
-        return await index.query_documents(query, params)
+        return await index.query_chunks(query, params)
```
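
For intuition about what the provider wraps, the core faiss mechanics visible in the hunks above are: a flat L2 index over `float32` embeddings, batch-of-one searches, and inverse-distance scores. A self-contained toy version (no llama-stack imports; the dimension and data are made up):

```python
import faiss  # pip install faiss-cpu
import numpy as np

dimension = 384  # same EMBEDDING_DIMENSION as in the test plan
index = faiss.IndexFlatL2(dimension)

# Stand-ins for chunk embeddings produced by the inference API.
embeddings = np.random.rand(10, dimension).astype(np.float32)
index.add(embeddings)
chunk_by_index = {i: f"chunk-{i}" for i in range(10)}  # toy chunk map

# Query the way FaissIndex.query does: reshape to a batch of one.
query = np.random.rand(dimension).astype(np.float32)
distances, indices = index.search(query.reshape(1, -1).astype(np.float32), 3)

# Inverse-distance scoring, as in the provider: larger score = closer match.
for d, i in zip(distances[0], indices[0]):
    print(chunk_by_index[int(i)], "score:", 1.0 / float(d))
```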