feat(sqlite-vec): enable keyword search for sqlite-vec (#1439)
# What does this PR do?

This PR introduces support for keyword-based FTS5 search with BM25 relevance scoring. It changes the existing `EmbeddingIndex` base class to support `search_mode` and `query_str` parameters that keyword-based search implementations can use.

## Test Plan

Run:

```
pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto
```

Output:

```
/Users/vnarsing/miniconda3/envs/stack-client/lib/python3.10/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"
  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
====================================================== test session starts =======================================================
platform darwin -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /Users/vnarsing/miniconda3/envs/stack-client/bin/python
cachedir: .pytest_cache
metadata: {'Python': '3.10.16', 'Platform': 'macOS-14.7.4-arm64-arm-64bit', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0'}}
rootdir: /Users/vnarsing/go/src/github/meta-llama/llama-stack
configfile: pyproject.toml
plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0
asyncio: mode=auto, asyncio_default_fixture_loop_scope=None
collected 7 items

llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_add_chunks PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_vector PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_fts PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_register_vector_db PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_unregister_vector_db PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED
```

For reference, with this implementation the FTS table looks like this:

```
Chunk ID: 9fbc39ce-c729-64a2-260f-c5ec9bb2a33e, Content: Sentence 0 from document 0
Chunk ID: 94062914-3e23-44cf-1e50-9e25821ba882, Content: Sentence 1 from document 0
Chunk ID: e6cfd559-4641-33ba-6ce1-7038226495eb, Content: Sentence 2 from document 0
Chunk ID: 1383af9b-f1f0-f417-4de5-65fe9456cc20, Content: Sentence 3 from document 0
Chunk ID: 2db19b1a-de14-353b-f4e1-085e8463361c, Content: Sentence 4 from document 0
Chunk ID: 9faf986a-f028-7714-068a-1c795e8f2598, Content: Sentence 5 from document 0
Chunk ID: ef593ead-5a4a-392f-7ad8-471a50f033e8, Content: Sentence 6 from document 0
Chunk ID: e161950f-021f-7300-4d05-3166738b94cf, Content: Sentence 7 from document 0
Chunk ID: 90610fc4-67c1-e740-f043-709c5978867a, Content: Sentence 8 from document 0
Chunk ID: 97712879-6fff-98ad-0558-e9f42e6b81d3, Content: Sentence 9 from document 0
Chunk ID: aea70411-51df-61ba-d2f0-cb2b5972c210, Content: Sentence 0 from document 1
Chunk ID: b678a463-7b84-92b8-abb2-27e9a1977e3c, Content: Sentence 1 from document 1
Chunk ID: 27bd63da-909c-1606-a109-75bdb9479882, Content: Sentence 2 from document 1
Chunk ID: a2ad49ad-f9be-5372-e0c7-7b0221d0b53e, Content: Sentence 3 from document 1
Chunk ID: cac53bcd-1965-082a-c0f4-ceee7323fc70, Content: Sentence 4 from document 1
```

Query results:

```
Result 1: Sentence 5 from document 0
Result 2: Sentence 5 from document 1
Result 3: Sentence 5 from document 2
```

---------

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
Parent: 85b5f3172b
Commit: e92301f2d7

15 changed files with 247 additions and 37 deletions
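
For orientation before the per-file diff: end to end, the new mode is selected through `RAGQueryConfig`. A minimal sketch, assuming a configured `client` and a registered `vector_db_id` (this mirrors the docs example added in this PR, switched to `mode="keyword"`):

```python
from llama_stack.apis.tool_runtime.rag import RAGQueryConfig

# mode="keyword" routes the query to the FTS5/BM25 index instead of the
# embedding index; leaving mode unset keeps the default "vector" behavior.
query_config = RAGQueryConfig(max_chunks=6, mode="keyword")

results = client.tool_runtime.rag_tool.query(
    vector_db_ids=[vector_db_id],
    content="what is torchtune",
    query_config=query_config,
)
```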
**docs/_static/llama-stack-spec.html** (vendored, 4 changes)

```diff
@@ -11608,6 +11608,10 @@
           "type": "string",
           "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
           "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
+        },
+        "mode": {
+          "type": "string",
+          "description": "Search mode for retrieval—either \"vector\" or \"keyword\". Default \"vector\"."
         }
       },
       "additionalProperties": false,
```
**docs/_static/llama-stack-spec.yaml** (vendored, 4 changes)

```diff
@@ -8086,6 +8086,10 @@ components:
             placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
             content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
             {chunk.content}\nMetadata: {metadata}\n"
+        mode:
+          type: string
+          description: >-
+            Search mode for retrieval—either "vector" or "keyword". Default "vector".
       additionalProperties: false
       required:
         - query_generator_config
```
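Concretely, the new `mode` field rides along in the same request body as the existing config keys. A hand-written payload sketch against the updated schema (the `query_generator_config` value here is illustrative, taken from the schema's required list rather than from this diff):

```python
# Illustrative JSON body for a RAG query config under the updated spec.
# Only "query_generator_config" is required; "mode" is a plain optional string.
query_config = {
    "query_generator_config": {"type": "default", "separator": " "},
    "max_chunks": 5,
    "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
    "mode": "keyword",  # omit (or use "vector") for embedding-based search
}
```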
**sqlite-vec provider documentation:**

````diff
@@ -66,6 +66,25 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
 2. Configure your Llama Stack project to use SQLite-Vec.
 3. Start storing and querying vectors.
+
+## Supported Search Modes
+
+The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
+
+When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in
+`RAGQueryConfig`. For example:
+
+```python
+from llama_stack.apis.tool_runtime.rag import RAGQueryConfig
+
+query_config = RAGQueryConfig(max_chunks=6, mode="vector")
+
+results = client.tool_runtime.rag_tool.query(
+    vector_db_ids=[vector_db_id],
+    content="what is torchtune",
+    query_config=query_config,
+)
+```
+
 ## Installation
 
 You can install SQLite-Vec using pip:
````
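The same switch also works below the RAG tool: `VectorDBWithIndex.query_chunks` reads `mode` out of `params` (see the shared `EmbeddingIndex` changes later in this diff), so a direct vector-IO query can opt into keyword search as well. A sketch, with `vector_io` standing in for whatever adapter handle you hold:

```python
response = await vector_io.query_chunks(
    vector_db_id=vector_db_id,
    query="what is torchtune",
    # "mode": "keyword" selects the FTS5/BM25 path; omitting it (or passing
    # "vector") keeps the embedding-search path.
    params={"max_chunks": 3, "mode": "keyword"},
)
```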
**RAG tool API (`RAGQueryConfig`):**

```diff
@@ -76,6 +76,7 @@ class RAGQueryConfig(BaseModel):
     :param chunk_template: Template for formatting each retrieved chunk in the context.
         Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
         Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
+    :param mode: Search mode for retrieval—either "vector" or "keyword". Default "vector".
     """

     # This config defines how a query is generated using the messages
@@ -84,6 +85,7 @@ class RAGQueryConfig(BaseModel):
     max_tokens_in_context: int = 4096
     max_chunks: int = 5
     chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
+    mode: str | None = None

     @field_validator("chunk_template")
     def validate_chunk_template(cls, v: str) -> str:
```
**RAG tool runtime (`MemoryToolRuntimeImpl`):**

```diff
@@ -122,6 +122,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                 query=query,
                 params={
                     "max_chunks": query_config.max_chunks,
+                    "mode": query_config.mode,
                 },
             )
             for vector_db_id in vector_db_ids
```
**Faiss provider (`FaissIndex`):**

```diff
@@ -99,9 +99,13 @@ class FaissIndex(EmbeddingIndex):
         # Save updated index
         await self._save_index()

-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query_vector(
+        self,
+        embedding: NDArray,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
         distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)

         chunks = []
         scores = []
         for d, i in zip(distances[0], indices[0], strict=False):
@@ -112,6 +116,14 @@ class FaissIndex(EmbeddingIndex):

         return QueryChunksResponse(chunks=chunks, scores=scores)

+    async def query_keyword(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in FAISS")
+

 class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
```
**sqlite-vec provider (`SQLiteVecIndex`):**

```diff
@@ -24,6 +24,11 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect

 logger = logging.getLogger(__name__)

+# Specifying search mode is dependent on the VectorIO provider.
+VECTOR_SEARCH = "vector"
+KEYWORD_SEARCH = "keyword"
+SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
+

 def serialize_vector(vector: list[float]) -> bytes:
     """Serialize a list of floats into a compact binary representation."""
@@ -45,6 +50,7 @@ class SQLiteVecIndex(EmbeddingIndex):
     Two tables are used:
       - A metadata table (chunks_{bank_id}) that holds the chunk JSON.
      - A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
+      - An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search.
     """

     def __init__(self, dimension: int, db_path: str, bank_id: str):
@@ -53,6 +59,7 @@ class SQLiteVecIndex(EmbeddingIndex):
         self.bank_id = bank_id
         self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
         self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
+        self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")

     @classmethod
     async def create(cls, dimension: int, db_path: str, bank_id: str):
@@ -78,6 +85,14 @@ class SQLiteVecIndex(EmbeddingIndex):
                     USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
                 """)
                 connection.commit()
+                # FTS5 table (for keyword search) - creating both the tables by default. Will use the relevant one
+                # based on query. Implementation of the change on client side will allow passing the search_mode option
+                # during initialization to make it easier to create the table that is required.
+                cur.execute(f"""
+                    CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table}
+                    USING fts5(id, content);
+                """)
+                connection.commit()
             finally:
                 cur.close()
                 connection.close()
```
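The FTS5 setup created above is plain SQLite, so its MATCH/BM25 behavior can be sanity-checked with nothing but the standard library (assuming your SQLite build ships with FTS5, as CPython's usually does). A minimal sketch with illustrative table and row contents:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()

# Same shape as fts_chunks_{bank_id}: the chunk id plus the raw chunk text.
cur.execute("CREATE VIRTUAL TABLE fts_chunks USING fts5(id, content);")
cur.executemany(
    "INSERT INTO fts_chunks (id, content) VALUES (?, ?);",
    [
        ("chunk-0", "Sentence 0 from document 0"),
        ("chunk-1", "Sentence 5 from document 0"),
    ],
)

# bm25() returns negated relevance, so ascending order puts best matches first.
cur.execute(
    """
    SELECT id, bm25(fts_chunks) AS score
    FROM fts_chunks
    WHERE content MATCH ?
    ORDER BY score ASC
    LIMIT 3;
    """,
    ("Sentence 5",),
)
print(cur.fetchall())  # only chunk-1 matches both tokens, e.g. [('chunk-1', -0.99...)]
conn.close()
```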
```diff
@@ -91,6 +106,7 @@ class SQLiteVecIndex(EmbeddingIndex):
             try:
                 cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
                 cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
+                cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};")
                 connection.commit()
             finally:
                 cur.close()
@@ -104,6 +120,7 @@ class SQLiteVecIndex(EmbeddingIndex):
         For each chunk, we insert its JSON into the metadata table and then insert its
         embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
         If any insert fails, the transaction is rolled back to maintain consistency.
+        Also inserts chunk content into FTS table for keyword search support.
         """
         assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"

@@ -112,18 +129,16 @@ class SQLiteVecIndex(EmbeddingIndex):
             cur = connection.cursor()

             try:
-                # Start transaction a single transcation for all batches
                 cur.execute("BEGIN TRANSACTION")
                 for i in range(0, len(chunks), batch_size):
                     batch_chunks = chunks[i : i + batch_size]
                     batch_embeddings = embeddings[i : i + batch_size]
-                    # Prepare metadata inserts
+                    # Insert metadata
                     metadata_data = [
                         (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
                         for chunk in batch_chunks
-                        if isinstance(chunk.content, str)
                     ]
-                    # Insert metadata (ON CONFLICT to avoid duplicates)
                     cur.executemany(
                         f"""
                         INSERT INTO {self.metadata_table} (id, chunk)
@@ -132,21 +147,43 @@ class SQLiteVecIndex(EmbeddingIndex):
                         """,
                         metadata_data,
                     )
-                    # Prepare embeddings inserts
+                    # Insert vector embeddings
                     embedding_data = [
                         (
-                            generate_chunk_id(chunk.metadata["document_id"], chunk.content),
-                            serialize_vector(emb.tolist()),
+                            (
+                                generate_chunk_id(chunk.metadata["document_id"], chunk.content),
+                                serialize_vector(emb.tolist()),
+                            )
                         )
                         for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
-                        if isinstance(chunk.content, str)
                     ]
-                    # Insert embeddings in batch
-                    cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
+                    cur.executemany(
+                        f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);",
+                        embedding_data,
+                    )
+
+                    # Insert FTS content
+                    fts_data = [
+                        (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.content)
+                        for chunk in batch_chunks
+                    ]
+                    # DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
+                    cur.executemany(
+                        f"DELETE FROM {self.fts_table} WHERE id = ?;",
+                        [(row[0],) for row in fts_data],
+                    )
+
+                    # INSERT new entries
+                    cur.executemany(
+                        f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);",
+                        fts_data,
+                    )

                 connection.commit()

             except sqlite3.Error as e:
-                connection.rollback()  # Rollback on failure
+                connection.rollback()
                 logger.error(f"Error inserting into {self.vector_table}: {e}")
                 raise

```
```diff
@@ -154,22 +191,25 @@ class SQLiteVecIndex(EmbeddingIndex):
                 cur.close()
                 connection.close()

-        # Process all batches in a single thread
+        # Run batch insertion in a background thread
         await asyncio.to_thread(_execute_all_batch_inserts)

-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query_vector(
+        self,
+        embedding: NDArray,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
         """
-        Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query
-        against the virtual table. The SQL joins the metadata table to recover the chunk JSON.
+        Performs vector-based search using a virtual table for vector similarity.
         """
-        emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
-        emb_blob = serialize_vector(emb_list)

         def _execute_query():
             connection = _create_sqlite_connection(self.db_path)
             cur = connection.cursor()

             try:
+                emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
+                emb_blob = serialize_vector(emb_list)
                 query_sql = f"""
                     SELECT m.id, m.chunk, v.distance
                     FROM {self.vector_table} AS v
@@ -184,17 +224,66 @@ class SQLiteVecIndex(EmbeddingIndex):
                 connection.close()

         rows = await asyncio.to_thread(_execute_query)

         chunks, scores = [], []
-        for _id, chunk_json, distance in rows:
+        for row in rows:
+            _id, chunk_json, distance = row
+            score = 1.0 / distance if distance != 0 else float("inf")
+            if score < score_threshold:
+                continue
+            try:
+                chunk = Chunk.model_validate_json(chunk_json)
+            except Exception as e:
+                logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
+                continue
+            chunks.append(chunk)
+            scores.append(score)
+        return QueryChunksResponse(chunks=chunks, scores=scores)
+
+    async def query_keyword(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        """
+        Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
+        """
+        if query_string is None:
+            raise ValueError("query_string is required for keyword search.")
+
+        def _execute_query():
+            connection = _create_sqlite_connection(self.db_path)
+            cur = connection.cursor()
+            try:
+                query_sql = f"""
+                    SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
+                    FROM {self.fts_table} AS f
+                    JOIN {self.metadata_table} AS m ON m.id = f.id
+                    WHERE f.content MATCH ?
+                    ORDER BY score ASC
+                    LIMIT ?;
+                """
+                cur.execute(query_sql, (query_string, k))
+                return cur.fetchall()
+            finally:
+                cur.close()
+                connection.close()
+
+        rows = await asyncio.to_thread(_execute_query)
+        chunks, scores = [], []
+        for row in rows:
+            _id, chunk_json, score = row
+            # BM25 scores returned by sqlite-vec are NEGATED (i.e., more relevant = more negative).
+            # This design is intentional to simplify sorting by ascending score.
+            # Reference: https://alexgarcia.xyz/blog/2024/sqlite-vec-hybrid-search/index.html
+            if score > -score_threshold:
+                continue
             try:
                 chunk = Chunk.model_validate_json(chunk_json)
             except Exception as e:
                 logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
                 continue
             chunks.append(chunk)
-            # Mimic the Faiss scoring: score = 1/distance (avoid division by zero)
-            score = 1.0 / distance if distance != 0 else float("inf")
             scores.append(score)
         return QueryChunksResponse(chunks=chunks, scores=scores)
```
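One subtlety worth calling out in `query_keyword`: because bm25() scores are negated, the `score_threshold` comparison is inverted relative to the vector path. A small worked sketch of the filter above (numbers illustrative):

```python
# bm25() emits more-negative values for better matches, so a caller's
# score_threshold of 1.5 ("keep matches at least this relevant") becomes
# a cutoff of -1.5 on the raw score.
score_threshold = 1.5
raw_scores = [-3.2, -1.6, -0.4]  # already sorted best-first by ORDER BY score ASC

kept = [s for s in raw_scores if not (s > -score_threshold)]
print(kept)  # [-3.2, -1.6] -- the weak match at -0.4 is dropped
```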
**Chroma provider (`ChromaIndex`):**

```diff
@@ -84,6 +84,14 @@ class ChromaIndex(EmbeddingIndex):
     async def delete(self):
         await maybe_await(self.client.delete_collection(self.collection.name))

+    async def query_keyword(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in Chroma")
+

 class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(
```
**Milvus provider (`MilvusIndex`):**

```diff
@@ -73,7 +73,7 @@ class MilvusIndex(EmbeddingIndex):
             logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
             raise e

-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         search_res = await asyncio.to_thread(
             self.client.search,
             collection_name=self.collection_name,
@@ -86,6 +86,14 @@ class MilvusIndex(EmbeddingIndex):
         scores = [res["distance"] for res in search_res[0]]
         return QueryChunksResponse(chunks=chunks, scores=scores)

+    async def query_keyword(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in Milvus")
+

 class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(
```
**PGVector provider (`PGVectorIndex`):**

```diff
@@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex):
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             execute_values(cur, query, values, template="(%s, %s, %s::vector)")

-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             cur.execute(
                 f"""
@@ -120,6 +120,14 @@ class PGVectorIndex(EmbeddingIndex):

         return QueryChunksResponse(chunks=chunks, scores=scores)

+    async def query_keyword(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in PGVector")
+
     async def delete(self):
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
```
**Qdrant provider (`QdrantIndex`):**

```diff
@@ -68,7 +68,7 @@ class QdrantIndex(EmbeddingIndex):

         await self.client.upsert(collection_name=self.collection_name, points=points)

-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         results = (
             await self.client.query_points(
                 collection_name=self.collection_name,
@@ -95,6 +95,14 @@ class QdrantIndex(EmbeddingIndex):

         return QueryChunksResponse(chunks=chunks, scores=scores)

+    async def query_keyword(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in Qdrant")
+
     async def delete(self):
         await self.client.delete_collection(collection_name=self.collection_name)
```
**Weaviate provider (`WeaviateIndex`):**

```diff
@@ -55,7 +55,7 @@ class WeaviateIndex(EmbeddingIndex):
         # TODO: make this async friendly
         collection.data.insert_many(data_objects)

-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         collection = self.client.collections.get(self.collection_name)

         results = collection.query.near_vector(
@@ -84,6 +84,14 @@ class WeaviateIndex(EmbeddingIndex):
         collection = self.client.collections.get(self.collection_name)
         collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))

+    async def query_keyword(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in Weaviate")
+

 class WeaviateVectorIOAdapter(
     VectorIO,
```
**Shared vector-store utilities (`EmbeddingIndex`, `VectorDBWithIndex`):**

```diff
@@ -177,7 +177,11 @@ class EmbeddingIndex(ABC):
         raise NotImplementedError()

     @abstractmethod
-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+        raise NotImplementedError()
+
+    @abstractmethod
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         raise NotImplementedError()

     @abstractmethod
@@ -210,9 +214,12 @@ class VectorDBWithIndex:
         if params is None:
             params = {}
         k = params.get("max_chunks", 3)
+        mode = params.get("mode")
         score_threshold = params.get("score_threshold", 0.0)
-        query_str = interleaved_content_as_str(query)
-        embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_str])
-        query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
-        return await self.index.query(query_vector, k, score_threshold)
+        query_string = interleaved_content_as_str(query)
+        if mode == "keyword":
+            return await self.index.query_keyword(query_string, k, score_threshold)
+        else:
+            embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
+            query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
+            return await self.index.query_vector(query_vector, k, score_threshold)
```
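With `query_vector` and `query_keyword` both abstract, every `EmbeddingIndex` implementation now has to provide (or explicitly refuse) both entry points, and `VectorDBWithIndex.query_chunks` dispatches between them on `params["mode"]`. A minimal conforming sketch, using a stand-in response type rather than the real llama_stack imports:

```python
from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray


@dataclass
class ToyResponse:  # stand-in for QueryChunksResponse
    chunks: list[str]
    scores: list[float]


class ToyIndex:
    """In-memory index matching the new EmbeddingIndex surface."""

    def __init__(self) -> None:
        self.texts: list[str] = []
        self.vectors: list[NDArray] = []

    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> ToyResponse:
        # Cosine similarity against every stored vector; keep top-k above threshold.
        sims = [
            float(np.dot(embedding, v) / (np.linalg.norm(embedding) * np.linalg.norm(v) + 1e-9))
            for v in self.vectors
        ]
        ranked = sorted(zip(self.texts, sims), key=lambda p: p[1], reverse=True)[:k]
        kept = [(t, s) for t, s in ranked if s >= score_threshold]
        return ToyResponse([t for t, _ in kept], [s for _, s in kept])

    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> ToyResponse:
        # Crude relevance: count of query tokens that appear in the text.
        tokens = query_string.lower().split()
        scored = [(t, float(sum(tok in t.lower() for tok in tokens))) for t in self.texts]
        kept = sorted((p for p in scored if p[1] > 0), key=lambda p: p[1], reverse=True)[:k]
        return ToyResponse([t for t, _ in kept], [s for _, s in kept])
```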
**Qdrant adapter test:**

```diff
@@ -98,7 +98,7 @@ async def test_qdrant_adapter_returns_expected_chunks(
     response = await qdrant_adapter.query_chunks(
         query=__QUERY,
         vector_db_id=vector_db_id,
-        params={"max_chunks": max_query_chunks},
+        params={"max_chunks": max_query_chunks, "mode": "vector"},
     )
     assert isinstance(response, QueryChunksResponse)
     assert len(response.chunks) == expected_chunks
```
**sqlite-vec tests:**

```diff
@@ -57,14 +57,46 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):


 @pytest.mark.asyncio
-async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
+async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
     await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
     query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
-    response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
+    response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0)
     assert isinstance(response, QueryChunksResponse)
     assert len(response.chunks) == 2


+@pytest.mark.asyncio
+async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+
+    query_string = "Sentence 5"
+    response = await sqlite_vec_index.query_keyword(k=3, score_threshold=0.0, query_string=query_string)
+
+    assert isinstance(response, QueryChunksResponse)
+    assert len(response.chunks) == 3, f"Expected three chunks, but got {len(response.chunks)}"
+
+    non_existent_query_str = "blablabla"
+    response_no_results = await sqlite_vec_index.query_keyword(
+        query_string=non_existent_query_str, k=1, score_threshold=0.0
+    )
+
+    assert isinstance(response_no_results, QueryChunksResponse)
+    assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
+    # Re-initialize with a clean index
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+
+    query_str = "Sentence 1 from document 0"  # Should match only one chunk
+    response = await sqlite_vec_index.query_keyword(k=5, score_threshold=0.0, query_string=query_str)
+
+    assert isinstance(response, QueryChunksResponse)
+    assert 0 < len(response.chunks) < 5, f"Expected results between [1, 4], got {len(response.chunks)}"
+    assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found"
+
+
 @pytest.mark.asyncio
 async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
     """Test that chunk IDs do not conflict across batches when inserting chunks."""
```