forked from phoenix-oss/llama-stack-mirror
# What does this PR do? This PR introduces support for keyword based FTS5 search with BM25 relevance scoring. It makes changes to the existing EmbeddingIndex base class in order to support a search_mode and query_str parameter, that can be used for keyword based search implementations. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan run ``` pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto ``` Output: ``` pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto /Users/vnarsing/miniconda3/envs/stack-client/lib/python3.10/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. 
Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ====================================================== test session starts ======================================================= platform darwin -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /Users/vnarsing/miniconda3/envs/stack-client/bin/python cachedir: .pytest_cache metadata: {'Python': '3.10.16', 'Platform': 'macOS-14.7.4-arm64-arm-64bit', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0'}} rootdir: /Users/vnarsing/go/src/github/meta-llama/llama-stack configfile: pyproject.toml plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0 asyncio: mode=auto, asyncio_default_fixture_loop_scope=None collected 7 items llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_add_chunks PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_vector PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_fts PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_register_vector_db PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_unregister_vector_db PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED ``` For reference, with the implementation, the fts table looks like below: ``` Chunk ID: 9fbc39ce-c729-64a2-260f-c5ec9bb2a33e, Content: Sentence 0 from document 0 Chunk ID: 94062914-3e23-44cf-1e50-9e25821ba882, Content: Sentence 1 from document 0 Chunk ID: e6cfd559-4641-33ba-6ce1-7038226495eb, Content: Sentence 2 from document 0 Chunk ID: 1383af9b-f1f0-f417-4de5-65fe9456cc20, Content: Sentence 3 from document 0 Chunk ID: 2db19b1a-de14-353b-f4e1-085e8463361c, Content: Sentence 4 from document 0 
Chunk ID: 9faf986a-f028-7714-068a-1c795e8f2598, Content: Sentence 5 from document 0 Chunk ID: ef593ead-5a4a-392f-7ad8-471a50f033e8, Content: Sentence 6 from document 0 Chunk ID: e161950f-021f-7300-4d05-3166738b94cf, Content: Sentence 7 from document 0 Chunk ID: 90610fc4-67c1-e740-f043-709c5978867a, Content: Sentence 8 from document 0 Chunk ID: 97712879-6fff-98ad-0558-e9f42e6b81d3, Content: Sentence 9 from document 0 Chunk ID: aea70411-51df-61ba-d2f0-cb2b5972c210, Content: Sentence 0 from document 1 Chunk ID: b678a463-7b84-92b8-abb2-27e9a1977e3c, Content: Sentence 1 from document 1 Chunk ID: 27bd63da-909c-1606-a109-75bdb9479882, Content: Sentence 2 from document 1 Chunk ID: a2ad49ad-f9be-5372-e0c7-7b0221d0b53e, Content: Sentence 3 from document 1 Chunk ID: cac53bcd-1965-082a-c0f4-ceee7323fc70, Content: Sentence 4 from document 1 ``` Query results: Result 1: Sentence 5 from document 0 Result 2: Sentence 5 from document 1 Result 3: Sentence 5 from document 2 [//]: # (## Documentation) --------- Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
210 lines
7.1 KiB
Python
210 lines
7.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import base64
|
|
import io
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
import faiss
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
|
|
from llama_stack.apis.common.content_types import InterleavedContent
|
|
from llama_stack.apis.inference.inference import Inference
|
|
from llama_stack.apis.vector_dbs import VectorDB
|
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
|
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
|
from llama_stack.providers.utils.kvstore.api import KVStore
|
|
from llama_stack.providers.utils.memory.vector_store import (
|
|
EmbeddingIndex,
|
|
VectorDBWithIndex,
|
|
)
|
|
|
|
from .config import FaissVectorIOConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
VERSION = "v3"
|
|
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
|
FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
|
|
|
|
|
|
class FaissIndex(EmbeddingIndex):
    """FAISS-backed embedding index with optional KVStore persistence.

    Chunks are held in memory (``chunk_by_index``, keyed by FAISS row
    position) and mirrored — together with the serialized FAISS index —
    into the kvstore under a bank-scoped key so the index survives restarts.
    """

    def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
        # IndexFlatL2 performs exact L2-distance nearest-neighbor search.
        self.index = faiss.IndexFlatL2(dimension)
        # Maps FAISS row position -> Chunk; rows are appended in insertion order.
        self.chunk_by_index: dict[int, Chunk] = {}
        self.kvstore = kvstore
        self.bank_id = bank_id

    @classmethod
    async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
        """Async factory: construct the index and load any persisted state."""
        instance = cls(dimension, kvstore, bank_id)
        await instance.initialize()
        return instance

    async def initialize(self) -> None:
        """Restore chunks and the FAISS index from the kvstore, if present."""
        if not self.kvstore:
            return

        index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
        stored_data = await self.kvstore.get(index_key)

        if stored_data:
            data = json.loads(stored_data)
            self.chunk_by_index = {int(k): Chunk.model_validate_json(v) for k, v in data["chunk_by_index"].items()}

            # Reverse of _save_index: base64 -> text buffer -> uint8 array -> index.
            buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
            self.index = faiss.deserialize_index(np.loadtxt(buffer, dtype=np.uint8))

    async def _save_index(self):
        """Persist chunks and the serialized FAISS index to the kvstore."""
        if not self.kvstore or not self.bank_id:
            return

        np_index = faiss.serialize_index(self.index)
        buffer = io.BytesIO()
        # NOTE(review): savetxt writes the uint8 array as text; this pairs with
        # the np.loadtxt call in initialize(). A binary format (np.save) would
        # be smaller/faster but would break already-persisted indexes.
        np.savetxt(buffer, np_index)
        data = {
            "chunk_by_index": {k: v.model_dump_json() for k, v in self.chunk_by_index.items()},
            "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
        }

        index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
        await self.kvstore.set(key=index_key, value=json.dumps(data))

    async def delete(self):
        """Remove the persisted index for this bank from the kvstore."""
        if not self.kvstore or not self.bank_id:
            return

        await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")

    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
        """Append chunks and their embeddings, then persist the updated index.

        Raises:
            ValueError: if the embedding dimension does not match the index.
        """
        # Add dimension check
        embedding_dim = embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
        if embedding_dim != self.index.d:
            raise ValueError(f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}")

        indexlen = len(self.chunk_by_index)
        for i, chunk in enumerate(chunks):
            self.chunk_by_index[indexlen + i] = chunk

        self.index.add(np.array(embeddings).astype(np.float32))

        # Save updated index
        await self._save_index()

    async def query_vector(
        self,
        embedding: NDArray,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
        """Return the k nearest chunks to `embedding` with inverse-distance scores.

        NOTE(review): `score_threshold` is currently not applied here —
        confirm whether filtering is expected to happen in the caller.
        """
        # Run the CPU-bound FAISS search off the event loop.
        distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
        chunks = []
        scores = []
        for d, i in zip(distances[0], indices[0], strict=False):
            if i < 0:
                # FAISS pads results with -1 indices when fewer than k vectors exist.
                continue
            chunks.append(self.chunk_by_index[int(i)])
            # Convert L2 distance to a similarity-style score. Guard against
            # ZeroDivisionError when the query exactly matches a stored vector
            # (distance == 0); score such exact matches as infinity.
            scores.append(1.0 / float(d) if d != 0 else float("inf"))

        return QueryChunksResponse(chunks=chunks, scores=scores)

    async def query_keyword(
        self,
        query_string: str,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
        # FAISS is a pure vector store; keyword/FTS search lives in other providers.
        raise NotImplementedError("Keyword search is not supported in FAISS")
|
|
|
|
|
|
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    """VectorIO provider backed by one FAISS index per vector DB ("bank"),
    with registrations persisted through the configured kvstore."""

    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
        self.config = config
        self.inference_api = inference_api
        self.cache: dict[str, VectorDBWithIndex] = {}
        self.kvstore: KVStore | None = None

    async def initialize(self) -> None:
        """Open the kvstore and rebuild an in-memory index for every persisted bank."""
        self.kvstore = await kvstore_impl(self.config.kvstore)

        # Scan every key under the vector-dbs prefix ("\xff" is past any suffix).
        lo = VECTOR_DBS_PREFIX
        hi = f"{VECTOR_DBS_PREFIX}\xff"
        stored_vector_dbs = await self.kvstore.values_in_range(lo, hi)

        for raw in stored_vector_dbs:
            vector_db = VectorDB.model_validate_json(raw)
            faiss_index = await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier)
            self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, faiss_index, self.inference_api)

    async def shutdown(self) -> None:
        # Nothing to release; the kvstore lifecycle is managed elsewhere.
        pass

    async def register_vector_db(
        self,
        vector_db: VectorDB,
    ) -> None:
        """Persist the vector DB record and create its in-memory index."""
        assert self.kvstore is not None

        await self.kvstore.set(
            key=f"{VECTOR_DBS_PREFIX}{vector_db.identifier}",
            value=vector_db.model_dump_json(),
        )

        # Store in cache
        faiss_index = await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier)
        self.cache[vector_db.identifier] = VectorDBWithIndex(
            vector_db=vector_db,
            index=faiss_index,
            inference_api=self.inference_api,
        )

    async def list_vector_dbs(self) -> list[VectorDB]:
        """All vector DBs currently registered with this adapter."""
        return [entry.vector_db for entry in self.cache.values()]

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        """Delete the bank's index and drop its record from cache and kvstore."""
        assert self.kvstore is not None

        entry = self.cache.get(vector_db_id)
        if entry is None:
            logger.warning(f"Vector DB {vector_db_id} not found")
            return

        await entry.index.delete()
        del self.cache[vector_db_id]
        await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        """Embed and insert `chunks` into the named vector DB."""
        index = self.cache.get(vector_db_id)
        if index is None:
            raise ValueError(f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}")

        await index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """Query the named vector DB, returning matching chunks with scores."""
        index = self.cache.get(vector_db_id)
        if index is None:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        return await index.query_chunks(query, params)
|