mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-28 06:51:59 +00:00)
fix: sqlite_vec keyword implementation
Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
parent e2a7022d3c · commit 2060fdba7f
14 changed files with 146 additions and 101 deletions

@@ -99,13 +99,11 @@ class FaissIndex(EmbeddingIndex):
         # Save updated index
         await self._save_index()
 
-    async def query(
+    async def query_vector(
         self,
         embedding: NDArray,
-        query_string: Optional[str],
         k: int,
         score_threshold: float,
-        mode: Optional[str],
     ) -> QueryChunksResponse:
         distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
         chunks = []

@@ -118,6 +116,14 @@ class FaissIndex(EmbeddingIndex):
 
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
+    async def query_keyword(
+        self,
+        query_string: str | None,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in FAISS")
+
 
 class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:

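For reference, a minimal standalone sketch of the FAISS k-NN call pattern that the renamed query_vector relies on; the index type, dimensions, and data below are arbitrary choices for illustration, not llama-stack code:

    import faiss  # pip install faiss-cpu
    import numpy as np

    dim, n, k = 8, 100, 3
    rng = np.random.default_rng(0)
    corpus = rng.random((n, dim), dtype=np.float32)

    index = faiss.IndexFlatL2(dim)  # exact L2 index, chosen only for the example
    index.add(corpus)               # FAISS expects float32 arrays of shape (n, dim)

    query = rng.random((1, dim), dtype=np.float32)
    distances, indices = index.search(query, k)  # both results have shape (1, k)
    print(indices[0], distances[0])

This mirrors the asyncio.to_thread(self.index.search, ...) call in the first hunk above; keyword search, by contrast, is simply not supported by this provider.
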
@@ -184,7 +184,7 @@ class SQLiteVecIndex(EmbeddingIndex):
 
             except sqlite3.Error as e:
                 connection.rollback()
-                logger.error(f"Error inserting chunk batch: {e}")
+                logger.error(f"Error inserting into {self.vector_table}: {e}")
                 raise
 
             finally:

@@ -194,88 +194,99 @@ class SQLiteVecIndex(EmbeddingIndex):
         # Run batch insertion in a background thread
         await asyncio.to_thread(_execute_all_batch_inserts)
 
-    async def query(
+    async def query_vector(
         self,
-        embedding: Optional[NDArray],
-        query_string: Optional[str],
+        embedding: NDArray,
         k: int,
         score_threshold: float,
-        mode: Optional[str],
     ) -> QueryChunksResponse:
         """
-        Supports both vector-based and keyword-based searches.
-
-        1. Vector Search (`mode=VECTOR_SEARCH`):
-           Uses a virtual table for vector similarity, joined with metadata.
-
-        2. Keyword Search (`mode=KEYWORD_SEARCH`):
-           Uses SQLite FTS5 for relevance-ranked full-text search.
+        Performs vector-based search using a virtual table for vector similarity.
         """
+        if embedding is None:
+            raise ValueError("embedding is required for vector search.")
 
         def _execute_query():
             connection = _create_sqlite_connection(self.db_path)
             cur = connection.cursor()
 
             try:
-                if mode == VECTOR_SEARCH:
-                    if embedding is None:
-                        raise ValueError("embedding is required for vector search.")
-                    emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
-                    emb_blob = serialize_vector(emb_list)
-
-                    query_sql = f"""
-                        SELECT m.id, m.chunk, v.distance
-                        FROM {self.vector_table} AS v
-                        JOIN {self.metadata_table} AS m ON m.id = v.id
-                        WHERE v.embedding MATCH ? AND k = ?
-                        ORDER BY v.distance;
-                    """
-                    cur.execute(query_sql, (emb_blob, k))
-
-                elif mode == KEYWORD_SEARCH:
-                    if query_string is None:
-                        raise ValueError("query_string is required for keyword search.")
-
-                    query_sql = f"""
-                        SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
-                        FROM {self.fts_table} AS f
-                        JOIN {self.metadata_table} AS m ON m.id = f.id
-                        WHERE f.content MATCH ?
-                        ORDER BY score ASC
-                        LIMIT ?;
-                    """
-                    cur.execute(query_sql, (query_string, k))
-
-                else:
-                    raise ValueError(f"Invalid search_mode: {mode} please select from {SEARCH_MODES}")
-
+                emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
+                emb_blob = serialize_vector(emb_list)
+                query_sql = f"""
+                    SELECT m.id, m.chunk, v.distance
+                    FROM {self.vector_table} AS v
+                    JOIN {self.metadata_table} AS m ON m.id = v.id
+                    WHERE v.embedding MATCH ? AND k = ?
+                    ORDER BY v.distance;
+                """
+                cur.execute(query_sql, (emb_blob, k))
                 return cur.fetchall()
             finally:
                 cur.close()
                 connection.close()
 
         rows = await asyncio.to_thread(_execute_query)
 
         chunks, scores = [], []
         for row in rows:
-            if mode == VECTOR_SEARCH:
-                _id, chunk_json, distance = row
-                score = 1.0 / distance if distance != 0 else float("inf")
-
-                if score < score_threshold:
-                    continue
-            else:
-                _id, chunk_json, score = row
-
+            _id, chunk_json, distance = row
+            score = 1.0 / distance if distance != 0 else float("inf")
+            if score < score_threshold:
+                continue
             try:
                 chunk = Chunk.model_validate_json(chunk_json)
             except Exception as e:
                 logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
                 continue
 
             chunks.append(chunk)
             scores.append(score)
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
+    async def query_keyword(
+        self,
+        query_string: str | None,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        """
+        Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
+        """
+        if query_string is None:
+            raise ValueError("query_string is required for keyword search.")
+
+        def _execute_query():
+            connection = _create_sqlite_connection(self.db_path)
+            cur = connection.cursor()
+            try:
+                query_sql = f"""
+                    SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
+                    FROM {self.fts_table} AS f
+                    JOIN {self.metadata_table} AS m ON m.id = f.id
+                    WHERE f.content MATCH ?
+                    ORDER BY score ASC
+                    LIMIT ?;
+                """
+                cur.execute(query_sql, (query_string, k))
+                return cur.fetchall()
+            finally:
+                cur.close()
+                connection.close()
+
+        rows = await asyncio.to_thread(_execute_query)
+        chunks, scores = [], []
+        for row in rows:
+            _id, chunk_json, score = row
+            # BM25 scores returned by sqlite-vec are NEGATED (i.e., more relevant = more negative).
+            # This design is intentional to simplify sorting by ascending score.
+            # Reference: https://alexgarcia.xyz/blog/2024/sqlite-vec-hybrid-search/index.html
+            if score > -score_threshold:
+                continue
+            try:
+                chunk = Chunk.model_validate_json(chunk_json)
+            except Exception as e:
+                logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
+                continue
+            chunks.append(chunk)
+            scores.append(score)
+        return QueryChunksResponse(chunks=chunks, scores=scores)
 

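For context on the vector path above, here is a minimal sketch of the sqlite-vec KNN query shape (a vec0 virtual table queried with MATCH plus a k constraint). The table name, dimension, and rows are illustrative, not the adapter's; it assumes the sqlite-vec Python package:

    import sqlite3

    import sqlite_vec  # pip install sqlite-vec

    db = sqlite3.connect(":memory:")
    db.enable_load_extension(True)
    sqlite_vec.load(db)
    db.enable_load_extension(False)

    db.execute("CREATE VIRTUAL TABLE vec_demo USING vec0(embedding float[4])")
    for rowid, vec in [(1, [0.1, 0.1, 0.1, 0.1]), (2, [0.9, 0.8, 0.7, 0.6])]:
        db.execute(
            "INSERT INTO vec_demo(rowid, embedding) VALUES (?, ?)",
            (rowid, sqlite_vec.serialize_float32(vec)),
        )

    # KNN query: MATCH takes the serialized query vector, k bounds the result count,
    # and rows come back ordered by ascending distance, as in query_vector above.
    rows = db.execute(
        "SELECT rowid, distance FROM vec_demo WHERE embedding MATCH ? AND k = ? ORDER BY distance",
        (sqlite_vec.serialize_float32([0.1, 0.1, 0.1, 0.1]), 2),
    ).fetchall()
    print(rows)

The adapter additionally joins the virtual table against a metadata table to recover the chunk JSON, and converts each distance to a score with 1.0 / distance.
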
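Similarly, for the keyword path, a small sketch of FTS5 bm25() ranking and the sign convention behind the `if score > -score_threshold` check; it assumes a SQLite build with FTS5, and the table name and rows are made up:

    import sqlite3

    db = sqlite3.connect(":memory:")
    db.execute("CREATE VIRTUAL TABLE docs USING fts5(content)")
    db.executemany(
        "INSERT INTO docs(rowid, content) VALUES (?, ?)",
        [(1, "keyword search with sqlite"), (2, "vector similarity search"), (3, "unrelated text")],
    )

    # bm25() returns NEGATIVE scores: more relevant rows are more negative,
    # so ORDER BY score ASC puts the best matches first.
    rows = db.execute(
        "SELECT rowid, bm25(docs) AS score FROM docs WHERE docs MATCH ? ORDER BY score LIMIT ?",
        ("search", 3),
    ).fetchall()
    print(rows)

    # Mirroring query_keyword's filter: a row is kept only if score <= -score_threshold.
    score_threshold = 0.0
    kept = [(rowid, score) for rowid, score in rows if not score > -score_threshold]
    print(kept)

With the default score_threshold of 0.0, every matched row passes (bm25 scores are negative); a positive threshold t keeps only rows scoring at or below -t.
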
@@ -55,9 +55,7 @@ class ChromaIndex(EmbeddingIndex):
             )
         )
 
-    async def query(
-        self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
-    ) -> QueryChunksResponse:
+    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         results = await maybe_await(
             self.collection.query(
                 query_embeddings=[embedding.tolist()],

@@ -86,6 +84,14 @@ class ChromaIndex(EmbeddingIndex):
     async def delete(self):
         await maybe_await(self.client.delete_collection(self.collection.name))
 
+    async def query_keyword(
+        self,
+        query_string: str | None,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in Chroma")
+
 
 class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(

@@ -73,9 +73,7 @@ class MilvusIndex(EmbeddingIndex):
             logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
             raise e
 
-    async def query(
-        self, embedding: NDArray, query_str: Optional[str], k: int, score_threshold: float, mode: str
-    ) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         search_res = await asyncio.to_thread(
             self.client.search,
             collection_name=self.collection_name,

@@ -88,6 +86,14 @@ class MilvusIndex(EmbeddingIndex):
         scores = [res["distance"] for res in search_res[0]]
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
+    async def query_keyword(
+        self,
+        query_string: str | None,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in Milvus")
+
 
 class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(

@@ -99,9 +99,7 @@ class PGVectorIndex(EmbeddingIndex):
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             execute_values(cur, query, values, template="(%s, %s, %s::vector)")
 
-    async def query(
-        self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
-    ) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             cur.execute(
                 f"""

@@ -122,6 +120,14 @@ class PGVectorIndex(EmbeddingIndex):
 
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
+    async def query_keyword(
+        self,
+        query_string: str | None,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in PGVector")
+
     async def delete(self):
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

@@ -68,9 +68,7 @@ class QdrantIndex(EmbeddingIndex):
 
         await self.client.upsert(collection_name=self.collection_name, points=points)
 
-    async def query(
-        self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
-    ) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         results = (
             await self.client.query_points(
                 collection_name=self.collection_name,

@@ -97,6 +95,14 @@ class QdrantIndex(EmbeddingIndex):
 
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
+    async def query_keyword(
+        self,
+        query_string: str | None,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in Qdrant")
+
     async def delete(self):
         await self.client.delete_collection(collection_name=self.collection_name)
 

@@ -55,9 +55,7 @@ class WeaviateIndex(EmbeddingIndex):
         # TODO: make this async friendly
         collection.data.insert_many(data_objects)
 
-    async def query(
-        self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
-    ) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         collection = self.client.collections.get(self.collection_name)
 
         results = collection.query.near_vector(

@@ -86,6 +84,14 @@ class WeaviateIndex(EmbeddingIndex):
         collection = self.client.collections.get(self.collection_name)
         collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
 
+    async def query_keyword(
+        self,
+        query_string: str | None,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in Weaviate")
+
 
 class WeaviateVectorIOAdapter(
     VectorIO,

@@ -177,9 +177,11 @@ class EmbeddingIndex(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    async def query(
-        self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: Optional[str]
-    ) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         raise NotImplementedError()
 
+    @abstractmethod
+    async def query_keyword(self, query_string: str | None, k: int, score_threshold: float) -> QueryChunksResponse:
+        raise NotImplementedError()
+
     @abstractmethod

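To illustrate the split interface above, here is a toy in-memory index that implements both halves. The class and response types are simplified stand-ins for llama-stack's EmbeddingIndex, Chunk, and QueryChunksResponse, and the cosine scoring is an arbitrary choice for the example:

    import asyncio
    from dataclasses import dataclass, field

    import numpy as np
    from numpy.typing import NDArray


    @dataclass
    class DemoChunk:  # stand-in for Chunk
        content: str


    @dataclass
    class DemoQueryResponse:  # stand-in for QueryChunksResponse
        chunks: list = field(default_factory=list)
        scores: list = field(default_factory=list)


    class DemoIndex:
        """Toy index following the query_vector / query_keyword split."""

        def __init__(self, chunks: list[DemoChunk], embeddings: NDArray) -> None:
            self.chunks = chunks
            self.embeddings = embeddings  # shape (num_chunks, dim)

        async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> DemoQueryResponse:
            # Cosine similarity against every stored embedding; keep the top-k above the threshold.
            norms = np.linalg.norm(self.embeddings, axis=1) * np.linalg.norm(embedding)
            sims = (self.embeddings @ embedding) / np.where(norms == 0, 1e-9, norms)
            resp = DemoQueryResponse()
            for i in np.argsort(-sims)[:k]:
                if sims[i] < score_threshold:
                    continue
                resp.chunks.append(self.chunks[i])
                resp.scores.append(float(sims[i]))
            return resp

        async def query_keyword(self, query_string: str | None, k: int, score_threshold: float) -> DemoQueryResponse:
            # Providers without full-text support raise, as FAISS/Chroma/Milvus/PGVector/Qdrant/Weaviate do above.
            raise NotImplementedError("Keyword search is not supported in this toy index")


    index = DemoIndex([DemoChunk("alpha"), DemoChunk("beta")], np.eye(2, dtype=np.float32))
    print(asyncio.run(index.query_vector(np.array([1.0, 0.0], dtype=np.float32), k=1, score_threshold=0.0)))
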
@@ -215,6 +217,9 @@ class VectorDBWithIndex:
         mode = params.get("mode")
         score_threshold = params.get("score_threshold", 0.0)
         query_string = interleaved_content_as_str(query)
-        embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
-        query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
-        return await self.index.query(query_vector, query_string, k, score_threshold, mode)
+        if mode == "keyword":
+            return await self.index.query_keyword(query_string, k, score_threshold)
+        else:
+            embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
+            query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
+            return await self.index.query_vector(query_vector, k, score_threshold)

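The hunk above routes on params["mode"]: "keyword" goes straight to query_keyword with the raw query string, while anything else embeds the query and calls query_vector. Here is a small self-contained sketch of that dispatch; the function and parameter names are illustrative, the real code obtains the embedding from the inference API, and how k is derived is outside the lines shown here:

    from typing import Any, Awaitable, Callable


    async def route_query(
        query_string: str,
        params: dict[str, Any] | None,
        embed: Callable[[str], Awaitable[list[float]]],
        query_vector: Callable[[list[float], int, float], Awaitable[Any]],
        query_keyword: Callable[[str, int, float], Awaitable[Any]],
        k: int = 3,
    ) -> Any:
        params = params or {}
        mode = params.get("mode")                             # "keyword" selects full-text search
        score_threshold = params.get("score_threshold", 0.0)  # same default as above
        if mode == "keyword":
            return await query_keyword(query_string, k, score_threshold)
        # Default path: embed the query text, then run vector similarity search.
        query_embedding = await embed(query_string)
        return await query_vector(query_embedding, k, score_threshold)

A caller opts in to keyword search by passing a params dict such as {"mode": "keyword", "score_threshold": 0.0}; any other mode value (including None) falls through to the vector path, and of the providers touched here only sqlite-vec actually implements query_keyword.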