forked from phoenix-oss/llama-stack-mirror
# What does this PR do? This PR introduces support for keyword-based FTS5 search with BM25 relevance scoring. It extends the existing EmbeddingIndex base class with search_mode and query_str parameters that keyword-based search implementations can use. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan run ``` pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto ``` Output: ``` pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto /Users/vnarsing/miniconda3/envs/stack-client/lib/python3.10/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. 
Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ====================================================== test session starts ======================================================= platform darwin -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /Users/vnarsing/miniconda3/envs/stack-client/bin/python cachedir: .pytest_cache metadata: {'Python': '3.10.16', 'Platform': 'macOS-14.7.4-arm64-arm-64bit', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0'}} rootdir: /Users/vnarsing/go/src/github/meta-llama/llama-stack configfile: pyproject.toml plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0 asyncio: mode=auto, asyncio_default_fixture_loop_scope=None collected 7 items llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_add_chunks PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_vector PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_fts PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_register_vector_db PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_unregister_vector_db PASSED llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED ``` For reference, with the implementation, the fts table looks like below: ``` Chunk ID: 9fbc39ce-c729-64a2-260f-c5ec9bb2a33e, Content: Sentence 0 from document 0 Chunk ID: 94062914-3e23-44cf-1e50-9e25821ba882, Content: Sentence 1 from document 0 Chunk ID: e6cfd559-4641-33ba-6ce1-7038226495eb, Content: Sentence 2 from document 0 Chunk ID: 1383af9b-f1f0-f417-4de5-65fe9456cc20, Content: Sentence 3 from document 0 Chunk ID: 2db19b1a-de14-353b-f4e1-085e8463361c, Content: Sentence 4 from document 0 
Chunk ID: 9faf986a-f028-7714-068a-1c795e8f2598, Content: Sentence 5 from document 0 Chunk ID: ef593ead-5a4a-392f-7ad8-471a50f033e8, Content: Sentence 6 from document 0 Chunk ID: e161950f-021f-7300-4d05-3166738b94cf, Content: Sentence 7 from document 0 Chunk ID: 90610fc4-67c1-e740-f043-709c5978867a, Content: Sentence 8 from document 0 Chunk ID: 97712879-6fff-98ad-0558-e9f42e6b81d3, Content: Sentence 9 from document 0 Chunk ID: aea70411-51df-61ba-d2f0-cb2b5972c210, Content: Sentence 0 from document 1 Chunk ID: b678a463-7b84-92b8-abb2-27e9a1977e3c, Content: Sentence 1 from document 1 Chunk ID: 27bd63da-909c-1606-a109-75bdb9479882, Content: Sentence 2 from document 1 Chunk ID: a2ad49ad-f9be-5372-e0c7-7b0221d0b53e, Content: Sentence 3 from document 1 Chunk ID: cac53bcd-1965-082a-c0f4-ceee7323fc70, Content: Sentence 4 from document 1 ``` Query results: Result 1: Sentence 5 from document 0 Result 2: Sentence 5 from document 1 Result 3: Sentence 5 from document 2 [//]: # (## Documentation) --------- Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
180 lines
6.2 KiB
Python
180 lines
6.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import logging
|
|
import uuid
|
|
from typing import Any
|
|
|
|
from numpy.typing import NDArray
|
|
from qdrant_client import AsyncQdrantClient, models
|
|
from qdrant_client.models import PointStruct
|
|
|
|
from llama_stack.apis.inference import InterleavedContent
|
|
from llama_stack.apis.vector_dbs import VectorDB
|
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
|
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
|
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
|
from llama_stack.providers.utils.memory.vector_store import (
|
|
EmbeddingIndex,
|
|
VectorDBWithIndex,
|
|
)
|
|
|
|
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
|
|
|
# Payload key under which the human-readable chunk id is stored alongside the
# serialized chunk; the Qdrant point id itself is the UUID form of that id.
CHUNK_ID_KEY = "_chunk_id"

# Module logger; used to report stored payloads that fail to parse back into a Chunk.
log = logging.getLogger(__name__)
|
|
|
|
|
|
def convert_id(_id: str) -> str:
    """Map an arbitrary string to a deterministic UUID string.

    Qdrant accepts only unsigned integers or UUID strings as point IDs, so a
    free-form identifier is turned into a name-based (version 5) UUID. Because
    the mapping is deterministic, inserting the same original ID again targets
    the same point and therefore overwrites it.

    Args:
        _id: Any string identifier.

    Returns:
        The canonical hyphenated string form of ``uuid5(NAMESPACE_DNS, _id)``.
    """
    derived = uuid.uuid5(namespace=uuid.NAMESPACE_DNS, name=_id)
    return str(derived)
|
|
|
|
|
|
class QdrantIndex(EmbeddingIndex):
    """EmbeddingIndex backed by a single Qdrant collection.

    Each chunk is stored as one Qdrant point: the vector is the chunk's
    embedding, and the payload holds the serialized chunk (``chunk_content``)
    plus its human-readable id under ``CHUNK_ID_KEY``. Only vector search is
    supported; keyword search raises ``NotImplementedError``.
    """

    def __init__(self, client: AsyncQdrantClient, collection_name: str):
        # The client is supplied by the caller; this index never closes it.
        self.client = client
        self.collection_name = collection_name

    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
        """Upsert ``chunks`` with their ``embeddings`` into the collection.

        The collection is created on the first non-empty insert, sized from the
        first embedding and using cosine distance.

        Args:
            chunks: Chunks to store; each must carry ``metadata["document_id"]``.
            embeddings: One embedding per chunk, in the same order as ``chunks``.

        Raises:
            AssertionError: if ``chunks`` and ``embeddings`` differ in length.
            ValueError: same mismatch when assertions are disabled (``python -O``).
            KeyError: if a chunk's ``metadata`` (presumably a dict) lacks ``"document_id"``.
        """
        assert len(chunks) == len(embeddings), (
            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
        )

        # Nothing to store. Returning early also avoids ``embeddings[0]``
        # raising IndexError below when the collection does not exist yet.
        if not chunks:
            return

        if not await self.client.collection_exists(self.collection_name):
            await self.client.create_collection(
                self.collection_name,
                vectors_config=models.VectorParams(size=len(embeddings[0]), distance=models.Distance.COSINE),
            )

        points = []
        # strict=True keeps the length check effective even when asserts are stripped.
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=True)):
            # NOTE(review): ``i`` is the position within this call only, so two
            # separate add_chunks calls for the same document_id produce the same
            # ids and the later call overwrites the earlier points. Confirm this is
            # intended (it does make re-ingesting a whole document idempotent).
            chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
            points.append(
                PointStruct(
                    id=convert_id(chunk_id),
                    vector=embedding,
                    payload={"chunk_content": chunk.model_dump()} | {CHUNK_ID_KEY: chunk_id},
                )
            )

        await self.client.upsert(collection_name=self.collection_name, points=points)

    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        """Return up to ``k`` stored chunks most similar to ``embedding``.

        Args:
            embedding: Query vector; converted with ``tolist()`` for the client.
            k: Maximum number of points to request.
            score_threshold: Forwarded as Qdrant's ``score_threshold``.

        Returns:
            A QueryChunksResponse with parallel ``chunks`` and ``scores`` lists.
            Points whose payload cannot be rebuilt into a Chunk are logged and skipped.
        """
        results = (
            await self.client.query_points(
                collection_name=self.collection_name,
                query=embedding.tolist(),
                limit=k,
                with_payload=True,
                score_threshold=score_threshold,
            )
        ).points

        chunks, scores = [], []
        for point in results:
            assert isinstance(point, models.ScoredPoint)
            assert point.payload is not None

            try:
                chunk = Chunk(**point.payload["chunk_content"])
            except Exception:
                # Best-effort: one malformed payload should not fail the whole query.
                log.exception("Failed to parse chunk")
                continue

            chunks.append(chunk)
            scores.append(point.score)

        return QueryChunksResponse(chunks=chunks, scores=scores)

    async def query_keyword(
        self,
        query_string: str,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
        """Keyword search is not implemented for this backend.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Keyword search is not supported in Qdrant")

    async def delete(self):
        """Delete the entire Qdrant collection backing this index."""
        await self.client.delete_collection(collection_name=self.collection_name)
|
|
|
|
|
|
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    """VectorIO provider that keeps each vector DB in its own Qdrant collection.

    Accepts either the remote or the inline Qdrant config; the config's
    non-None fields are passed directly to ``AsyncQdrantClient``. Index wrappers
    are cached in memory, keyed by vector DB identifier.
    """

    def __init__(
        self, config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, inference_api: Api.inference
    ) -> None:
        self.config = config
        # Created in initialize(); None until then.
        self.client: AsyncQdrantClient | None = None
        # vector DB identifier -> VectorDBWithIndex wrapping a QdrantIndex.
        self.cache: dict[str, VectorDBWithIndex] = {}
        self.inference_api = inference_api

    async def initialize(self) -> None:
        """Create the Qdrant client from the provider config."""
        self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))

    async def shutdown(self) -> None:
        """Close the Qdrant client; a no-op if initialize() was never called."""
        if self.client is not None:
            await self.client.close()

    def _make_index(self, vector_db: VectorDB) -> VectorDBWithIndex:
        """Wrap ``vector_db`` in a VectorDBWithIndex backed by its Qdrant collection."""
        return VectorDBWithIndex(
            vector_db=vector_db,
            index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
            inference_api=self.inference_api,
        )

    async def register_vector_db(
        self,
        vector_db: VectorDB,
    ) -> None:
        """Cache an index for ``vector_db`` under its identifier.

        No Qdrant call is made here; the collection is created on first insert.
        """
        self.cache[vector_db.identifier] = self._make_index(vector_db)

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        """Delete the collection for ``vector_db_id`` and drop it from the cache.

        Only vector DBs currently in the cache are deleted.
        """
        # NOTE(review): a vector DB registered by an earlier process and not yet
        # looked up in this one is absent from the cache, so its Qdrant collection
        # is left in place. Confirm whether that is acceptable.
        if vector_db_id in self.cache:
            await self.cache[vector_db_id].index.delete()
            del self.cache[vector_db_id]

    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
        """Return the cached index for ``vector_db_id``, building and caching it on a miss.

        Raises:
            ValueError: if the vector DB store has no entry for ``vector_db_id``.
        """
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]

        # ``vector_db_store`` is not set in __init__; presumably attached by the
        # provider framework before use — TODO confirm.
        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
        if not vector_db:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        index = self._make_index(vector_db)
        self.cache[vector_db_id] = index
        return index

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        """Embed and store ``chunks`` in the vector DB ``vector_db_id``.

        ``ttl_seconds`` is accepted for interface compatibility but not used here.

        Raises:
            ValueError: if ``vector_db_id`` is unknown.
        """
        # The helper raises ValueError for unknown ids, so no None check is needed.
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        await index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """Query the vector DB ``vector_db_id`` with ``query`` and optional ``params``.

        Raises:
            ValueError: if ``vector_db_id`` is unknown.
        """
        # The helper raises ValueError for unknown ids, so no None check is needed.
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        return await index.query_chunks(query, params)
|