diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 3f2599a57..0735aa8b0 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -11601,7 +11601,6 @@
},
"max_chunks": {
"type": "integer",
-<<<<<<< HEAD
"default": 5,
"description": "Maximum number of chunks to retrieve."
},
@@ -11609,12 +11608,10 @@
"type": "string",
"default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
"description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
-=======
- "default": 5
},
"mode": {
- "type": "string"
->>>>>>> 1a0433d2 (feat (RAG): Implement configurable search mode in RAGQueryConfig)
+ "type": "string",
+ "description": "Search mode for retrieval—either \"vector\" or \"keyword\"."
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index b7da34c23..8cd7bc5d8 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -8072,7 +8072,6 @@ components:
max_chunks:
type: integer
default: 5
-<<<<<<< HEAD
description: Maximum number of chunks to retrieve.
chunk_template:
type: string
@@ -8087,10 +8086,10 @@ components:
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
{chunk.content}\nMetadata: {metadata}\n"
-=======
mode:
type: string
->>>>>>> 1a0433d2 (feat (RAG): Implement configurable search mode in RAGQueryConfig)
+ description: >-
+ Search mode for retrieval—either "vector" or "keyword".
additionalProperties: false
required:
- query_generator_config
diff --git a/docs/source/providers/vector_io/sqlite-vec.md b/docs/source/providers/vector_io/sqlite-vec.md
index f2b0baf4e..49ba659f7 100644
--- a/docs/source/providers/vector_io/sqlite-vec.md
+++ b/docs/source/providers/vector_io/sqlite-vec.md
@@ -70,7 +70,7 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
-When using the RAGTool interface, you can specify the desired search behavior via the search_mode parameter in
+When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in
`RAGQueryConfig`. For example:
```python
diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py
index 5d5280205..0cc521baf 100644
--- a/llama_stack/apis/tools/rag_tool.py
+++ b/llama_stack/apis/tools/rag_tool.py
@@ -76,6 +76,7 @@ class RAGQueryConfig(BaseModel):
:param chunk_template: Template for formatting each retrieved chunk in the context.
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
+ :param mode: Search mode for retrieval—either "vector" or "keyword".
"""
# This config defines how a query is generated using the messages
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index ef9ca2855..050605464 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -99,13 +99,11 @@ class FaissIndex(EmbeddingIndex):
# Save updated index
await self._save_index()
- async def query(
+ async def query_vector(
self,
embedding: NDArray,
- query_string: Optional[str],
k: int,
score_threshold: float,
- mode: Optional[str],
) -> QueryChunksResponse:
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
chunks = []
@@ -118,6 +116,14 @@ class FaissIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
+ async def query_keyword(
+ self,
+ query_string: str | None,
+ k: int,
+ score_threshold: float,
+ ) -> QueryChunksResponse:
+ raise NotImplementedError("Keyword search is not supported in FAISS")
+
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index c351d7e2e..402cd5ffd 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -184,7 +184,7 @@ class SQLiteVecIndex(EmbeddingIndex):
except sqlite3.Error as e:
connection.rollback()
- logger.error(f"Error inserting chunk batch: {e}")
+ logger.error(f"Error inserting into {self.vector_table}: {e}")
raise
finally:
@@ -194,88 +194,99 @@ class SQLiteVecIndex(EmbeddingIndex):
# Run batch insertion in a background thread
await asyncio.to_thread(_execute_all_batch_inserts)
- async def query(
+ async def query_vector(
self,
- embedding: Optional[NDArray],
- query_string: Optional[str],
+ embedding: NDArray,
k: int,
score_threshold: float,
- mode: Optional[str],
) -> QueryChunksResponse:
"""
- Supports both vector-based and keyword-based searches.
-
- 1. Vector Search (`mode=VECTOR_SEARCH`):
- Uses a virtual table for vector similarity, joined with metadata.
-
- 2. Keyword Search (`mode=KEYWORD_SEARCH`):
- Uses SQLite FTS5 for relevance-ranked full-text search.
+ Performs vector-based search using a virtual table for vector similarity.
"""
+ if embedding is None:
+ raise ValueError("embedding is required for vector search.")
def _execute_query():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
-
try:
- if mode == VECTOR_SEARCH:
- if embedding is None:
- raise ValueError("embedding is required for vector search.")
- emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
- emb_blob = serialize_vector(emb_list)
-
- query_sql = f"""
- SELECT m.id, m.chunk, v.distance
- FROM {self.vector_table} AS v
- JOIN {self.metadata_table} AS m ON m.id = v.id
- WHERE v.embedding MATCH ? AND k = ?
- ORDER BY v.distance;
- """
- cur.execute(query_sql, (emb_blob, k))
-
- elif mode == KEYWORD_SEARCH:
- if query_string is None:
- raise ValueError("query_string is required for keyword search.")
-
- query_sql = f"""
- SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
- FROM {self.fts_table} AS f
- JOIN {self.metadata_table} AS m ON m.id = f.id
- WHERE f.content MATCH ?
- ORDER BY score ASC
- LIMIT ?;
- """
- cur.execute(query_sql, (query_string, k))
-
- else:
- raise ValueError(f"Invalid search_mode: {mode} please select from {SEARCH_MODES}")
-
+ emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
+ emb_blob = serialize_vector(emb_list)
+ query_sql = f"""
+ SELECT m.id, m.chunk, v.distance
+ FROM {self.vector_table} AS v
+ JOIN {self.metadata_table} AS m ON m.id = v.id
+ WHERE v.embedding MATCH ? AND k = ?
+ ORDER BY v.distance;
+ """
+ cur.execute(query_sql, (emb_blob, k))
return cur.fetchall()
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_execute_query)
-
chunks, scores = [], []
for row in rows:
- if mode == VECTOR_SEARCH:
- _id, chunk_json, distance = row
- score = 1.0 / distance if distance != 0 else float("inf")
-
- if score < score_threshold:
- continue
- else:
- _id, chunk_json, score = row
-
+ _id, chunk_json, distance = row
+ score = 1.0 / distance if distance != 0 else float("inf")
+ if score < score_threshold:
+ continue
try:
chunk = Chunk.model_validate_json(chunk_json)
except Exception as e:
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
continue
-
chunks.append(chunk)
scores.append(score)
+ return QueryChunksResponse(chunks=chunks, scores=scores)
+ async def query_keyword(
+ self,
+ query_string: str | None,
+ k: int,
+ score_threshold: float,
+ ) -> QueryChunksResponse:
+ """
+ Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
+ """
+ if query_string is None:
+ raise ValueError("query_string is required for keyword search.")
+
+ def _execute_query():
+ connection = _create_sqlite_connection(self.db_path)
+ cur = connection.cursor()
+ try:
+ query_sql = f"""
+ SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
+ FROM {self.fts_table} AS f
+ JOIN {self.metadata_table} AS m ON m.id = f.id
+ WHERE f.content MATCH ?
+ ORDER BY score ASC
+ LIMIT ?;
+ """
+ cur.execute(query_sql, (query_string, k))
+ return cur.fetchall()
+ finally:
+ cur.close()
+ connection.close()
+
+ rows = await asyncio.to_thread(_execute_query)
+ chunks, scores = [], []
+ for row in rows:
+ _id, chunk_json, score = row
+ # BM25 scores returned by SQLite FTS5's bm25() are NEGATED (i.e., more relevant = more negative).
+ # This design is intentional to simplify sorting by ascending score.
+ # Reference: https://alexgarcia.xyz/blog/2024/sqlite-vec-hybrid-search/index.html
+ if score > -score_threshold:
+ continue
+ try:
+ chunk = Chunk.model_validate_json(chunk_json)
+ except Exception as e:
+ logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
+ continue
+ chunks.append(chunk)
+ scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index 669adc8ca..52aacbe59 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -55,9 +55,7 @@ class ChromaIndex(EmbeddingIndex):
)
)
- async def query(
- self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
- ) -> QueryChunksResponse:
+ async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
results = await maybe_await(
self.collection.query(
query_embeddings=[embedding.tolist()],
@@ -86,6 +84,14 @@ class ChromaIndex(EmbeddingIndex):
async def delete(self):
await maybe_await(self.client.delete_collection(self.collection.name))
+ async def query_keyword(
+ self,
+ query_string: str | None,
+ k: int,
+ score_threshold: float,
+ ) -> QueryChunksResponse:
+ raise NotImplementedError("Keyword search is not supported in Chroma")
+
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index 8f186611d..67c5d4474 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -73,9 +73,7 @@ class MilvusIndex(EmbeddingIndex):
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
raise e
- async def query(
- self, embedding: NDArray, query_str: Optional[str], k: int, score_threshold: float, mode: str
- ) -> QueryChunksResponse:
+ async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = await asyncio.to_thread(
self.client.search,
collection_name=self.collection_name,
@@ -88,6 +86,14 @@ class MilvusIndex(EmbeddingIndex):
scores = [res["distance"] for res in search_res[0]]
return QueryChunksResponse(chunks=chunks, scores=scores)
+ async def query_keyword(
+ self,
+ query_string: str | None,
+ k: int,
+ score_threshold: float,
+ ) -> QueryChunksResponse:
+ raise NotImplementedError("Keyword search is not supported in Milvus")
+
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(
diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
index cadf768e2..150129c5c 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
@@ -99,9 +99,7 @@ class PGVectorIndex(EmbeddingIndex):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
execute_values(cur, query, values, template="(%s, %s, %s::vector)")
- async def query(
- self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
- ) -> QueryChunksResponse:
+ async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
f"""
@@ -122,6 +120,14 @@ class PGVectorIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
+ async def query_keyword(
+ self,
+ query_string: str | None,
+ k: int,
+ score_threshold: float,
+ ) -> QueryChunksResponse:
+ raise NotImplementedError("Keyword search is not supported in PGVector")
+
async def delete(self):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
index 5638a2831..4357ec03a 100644
--- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
+++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
@@ -68,9 +68,7 @@ class QdrantIndex(EmbeddingIndex):
await self.client.upsert(collection_name=self.collection_name, points=points)
- async def query(
- self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
- ) -> QueryChunksResponse:
+ async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
results = (
await self.client.query_points(
collection_name=self.collection_name,
@@ -97,6 +95,14 @@ class QdrantIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
+ async def query_keyword(
+ self,
+ query_string: str | None,
+ k: int,
+ score_threshold: float,
+ ) -> QueryChunksResponse:
+ raise NotImplementedError("Keyword search is not supported in Qdrant")
+
async def delete(self):
await self.client.delete_collection(collection_name=self.collection_name)
diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
index a633f362b..f0d154b09 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
@@ -55,9 +55,7 @@ class WeaviateIndex(EmbeddingIndex):
# TODO: make this async friendly
collection.data.insert_many(data_objects)
- async def query(
- self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
- ) -> QueryChunksResponse:
+ async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
collection = self.client.collections.get(self.collection_name)
results = collection.query.near_vector(
@@ -86,6 +84,14 @@ class WeaviateIndex(EmbeddingIndex):
collection = self.client.collections.get(self.collection_name)
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
+ async def query_keyword(
+ self,
+ query_string: str | None,
+ k: int,
+ score_threshold: float,
+ ) -> QueryChunksResponse:
+ raise NotImplementedError("Keyword search is not supported in Weaviate")
+
class WeaviateVectorIOAdapter(
VectorIO,
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index 89b2aac57..d915942be 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -177,9 +177,11 @@ class EmbeddingIndex(ABC):
raise NotImplementedError()
@abstractmethod
- async def query(
- self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: Optional[str]
- ) -> QueryChunksResponse:
+ async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+ raise NotImplementedError()
+
+ @abstractmethod
+ async def query_keyword(self, query_string: str | None, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError()
@abstractmethod
@@ -215,6 +217,9 @@ class VectorDBWithIndex:
mode = params.get("mode")
score_threshold = params.get("score_threshold", 0.0)
query_string = interleaved_content_as_str(query)
- embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
- query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
- return await self.index.query(query_vector, query_string, k, score_threshold, mode)
+ if mode == "keyword":
+ return await self.index.query_keyword(query_string, k, score_threshold)
+ else:
+ embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
+ query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
+ return await self.index.query_vector(query_vector, k, score_threshold)
diff --git a/tests/unit/providers/vector_io/test_qdrant.py b/tests/unit/providers/vector_io/test_qdrant.py
index bc97719c0..34df9b52f 100644
--- a/tests/unit/providers/vector_io/test_qdrant.py
+++ b/tests/unit/providers/vector_io/test_qdrant.py
@@ -98,7 +98,7 @@ async def test_qdrant_adapter_returns_expected_chunks(
response = await qdrant_adapter.query_chunks(
query=__QUERY,
vector_db_id=vector_db_id,
- params={"max_chunks": max_query_chunks},
+ params={"max_chunks": max_query_chunks, "mode": "vector"},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == expected_chunks
diff --git a/tests/unit/providers/vector_io/test_sqlite_vec.py b/tests/unit/providers/vector_io/test_sqlite_vec.py
index 282ab6cd0..010a0ca42 100644
--- a/tests/unit/providers/vector_io/test_sqlite_vec.py
+++ b/tests/unit/providers/vector_io/test_sqlite_vec.py
@@ -60,7 +60,7 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
- response = await sqlite_vec_index.query(query_embedding, query_string="", k=2, score_threshold=0.0, mode="vector")
+ response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
@@ -70,16 +70,14 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_string = "Sentence 5"
- response = await sqlite_vec_index.query(
- embedding=None, k=3, score_threshold=0.0, query_string=query_string, mode="keyword"
- )
+ response = await sqlite_vec_index.query_keyword(k=3, score_threshold=0.0, query_string=query_string)
assert isinstance(response, QueryChunksResponse)
- assert len(response.chunks) == 3, f"Expected at least one result, but got {len(response.chunks)}"
+ assert len(response.chunks) == 3, f"Expected three chunks, but got {len(response.chunks)}"
non_existent_query_str = "blablabla"
- response_no_results = await sqlite_vec_index.query(
- embedding=None, query_string=non_existent_query_str, k=1, score_threshold=0.0, mode="keyword"
+ response_no_results = await sqlite_vec_index.query_keyword(
+ query_string=non_existent_query_str, k=1, score_threshold=0.0
)
assert isinstance(response_no_results, QueryChunksResponse)
@@ -92,12 +90,10 @@ async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_i
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_str = "Sentence 1 from document 0" # Should match only one chunk
- response = await sqlite_vec_index.query(
- embedding=None, k=5, score_threshold=0.0, query_string=query_str, mode="keyword"
- )
+ response = await sqlite_vec_index.query_keyword(k=5, score_threshold=0.0, query_string=query_str)
assert isinstance(response, QueryChunksResponse)
- assert 0 < len(response.chunks) < 5, f"Expected <5 results but >0, got {len(response.chunks)}"
+ assert 0 < len(response.chunks) < 5, f"Expected results between [1, 4], got {len(response.chunks)}"
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found"