diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 3f2599a57..0735aa8b0 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -11601,7 +11601,6 @@ }, "max_chunks": { "type": "integer", -<<<<<<< HEAD "default": 5, "description": "Maximum number of chunks to retrieve." }, @@ -11609,12 +11608,10 @@ "type": "string", "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n", "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\"" -======= - "default": 5 }, "mode": { - "type": "string" ->>>>>>> 1a0433d2 (feat (RAG): Implement configurable search mode in RAGQueryConfig) + "type": "string", + "description": "Search mode for retrieval—either \"vector\" or \"keyword\"." } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index b7da34c23..8cd7bc5d8 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -8072,7 +8072,6 @@ components: max_chunks: type: integer default: 5 -<<<<<<< HEAD description: Maximum number of chunks to retrieve. chunk_template: type: string @@ -8087,10 +8086,10 @@ components: placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" -======= mode: type: string ->>>>>>> 1a0433d2 (feat (RAG): Implement configurable search mode in RAGQueryConfig) + description: >- + Search mode for retrieval—either "vector" or "keyword". additionalProperties: false required: - query_generator_config diff --git a/docs/source/providers/vector_io/sqlite-vec.md b/docs/source/providers/vector_io/sqlite-vec.md index f2b0baf4e..49ba659f7 100644 --- a/docs/source/providers/vector_io/sqlite-vec.md +++ b/docs/source/providers/vector_io/sqlite-vec.md @@ -70,7 +70,7 @@ To use sqlite-vec in your Llama Stack project, follow these steps: The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes. -When using the RAGTool interface, you can specify the desired search behavior via the search_mode parameter in +When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in `RAGQueryConfig`. For example: ```python diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index 5d5280205..0cc521baf 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -76,6 +76,7 @@ class RAGQueryConfig(BaseModel): :param chunk_template: Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n" + :param mode: Search mode for retrieval—either "vector" or "keyword". """ # This config defines how a query is generated using the messages diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index ef9ca2855..050605464 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -99,13 +99,11 @@ class FaissIndex(EmbeddingIndex): # Save updated index await self._save_index() - async def query( + async def query_vector( self, embedding: NDArray, - query_string: Optional[str], k: int, score_threshold: float, - mode: Optional[str], ) -> QueryChunksResponse: distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k) chunks = [] @@ -118,6 +116,14 @@ class FaissIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_keyword( + self, + query_string: str | None, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in FAISS") + class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None: diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index c351d7e2e..402cd5ffd 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -184,7 +184,7 @@ class SQLiteVecIndex(EmbeddingIndex): except sqlite3.Error as e: connection.rollback() - logger.error(f"Error inserting chunk batch: {e}") + logger.error(f"Error inserting into {self.vector_table}: {e}") raise finally: @@ -194,88 +194,99 @@ class SQLiteVecIndex(EmbeddingIndex): # Run batch insertion in a background thread await asyncio.to_thread(_execute_all_batch_inserts) - async def query( + async def query_vector( self, - embedding: Optional[NDArray], - query_string: Optional[str], + embedding: NDArray, k: int, score_threshold: float, - mode: Optional[str], ) -> QueryChunksResponse: """ - Supports both vector-based and keyword-based searches. - - 1. Vector Search (`mode=VECTOR_SEARCH`): - Uses a virtual table for vector similarity, joined with metadata. - - 2. Keyword Search (`mode=KEYWORD_SEARCH`): - Uses SQLite FTS5 for relevance-ranked full-text search. + Performs vector-based search using a virtual table for vector similarity. """ + if embedding is None: + raise ValueError("embedding is required for vector search.") def _execute_query(): connection = _create_sqlite_connection(self.db_path) cur = connection.cursor() - try: - if mode == VECTOR_SEARCH: - if embedding is None: - raise ValueError("embedding is required for vector search.") - emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding) - emb_blob = serialize_vector(emb_list) - - query_sql = f""" - SELECT m.id, m.chunk, v.distance - FROM {self.vector_table} AS v - JOIN {self.metadata_table} AS m ON m.id = v.id - WHERE v.embedding MATCH ? AND k = ? - ORDER BY v.distance; - """ - cur.execute(query_sql, (emb_blob, k)) - - elif mode == KEYWORD_SEARCH: - if query_string is None: - raise ValueError("query_string is required for keyword search.") - - query_sql = f""" - SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score - FROM {self.fts_table} AS f - JOIN {self.metadata_table} AS m ON m.id = f.id - WHERE f.content MATCH ? - ORDER BY score ASC - LIMIT ?; - """ - cur.execute(query_sql, (query_string, k)) - - else: - raise ValueError(f"Invalid search_mode: {mode} please select from {SEARCH_MODES}") - + emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding) + emb_blob = serialize_vector(emb_list) + query_sql = f""" + SELECT m.id, m.chunk, v.distance + FROM {self.vector_table} AS v + JOIN {self.metadata_table} AS m ON m.id = v.id + WHERE v.embedding MATCH ? AND k = ? + ORDER BY v.distance; + """ + cur.execute(query_sql, (emb_blob, k)) return cur.fetchall() finally: cur.close() connection.close() rows = await asyncio.to_thread(_execute_query) - chunks, scores = [], [] for row in rows: - if mode == VECTOR_SEARCH: - _id, chunk_json, distance = row - score = 1.0 / distance if distance != 0 else float("inf") - - if score < score_threshold: - continue - else: - _id, chunk_json, score = row - + _id, chunk_json, distance = row + score = 1.0 / distance if distance != 0 else float("inf") + if score < score_threshold: + continue try: chunk = Chunk.model_validate_json(chunk_json) except Exception as e: logger.error(f"Error parsing chunk JSON for id {_id}: {e}") continue - chunks.append(chunk) scores.append(score) + return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_keyword( + self, + query_string: str | None, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + """ + Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search. + """ + if query_string is None: + raise ValueError("query_string is required for keyword search.") + + def _execute_query(): + connection = _create_sqlite_connection(self.db_path) + cur = connection.cursor() + try: + query_sql = f""" + SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score + FROM {self.fts_table} AS f + JOIN {self.metadata_table} AS m ON m.id = f.id + WHERE f.content MATCH ? + ORDER BY score ASC + LIMIT ?; + """ + cur.execute(query_sql, (query_string, k)) + return cur.fetchall() + finally: + cur.close() + connection.close() + + rows = await asyncio.to_thread(_execute_query) + chunks, scores = [], [] + for row in rows: + _id, chunk_json, score = row + # BM25 scores returned by sqlite-vec are NEGATED (i.e., more relevant = more negative). + # This design is intentional to simplify sorting by ascending score. + # Reference: https://alexgarcia.xyz/blog/2024/sqlite-vec-hybrid-search/index.html + if score > -score_threshold: + continue + try: + chunk = Chunk.model_validate_json(chunk_json) + except Exception as e: + logger.error(f"Error parsing chunk JSON for id {_id}: {e}") + continue + chunks.append(chunk) + scores.append(score) return QueryChunksResponse(chunks=chunks, scores=scores) diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 669adc8ca..52aacbe59 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -55,9 +55,7 @@ class ChromaIndex(EmbeddingIndex): ) ) - async def query( - self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str - ) -> QueryChunksResponse: + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: results = await maybe_await( self.collection.query( query_embeddings=[embedding.tolist()], @@ -86,6 +84,14 @@ class ChromaIndex(EmbeddingIndex): async def delete(self): await maybe_await(self.client.delete_collection(self.collection.name)) + async def query_keyword( + self, + query_string: str | None, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in Chroma") + class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__( diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 8f186611d..67c5d4474 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -73,9 +73,7 @@ class MilvusIndex(EmbeddingIndex): logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}") raise e - async def query( - self, embedding: NDArray, query_str: Optional[str], k: int, score_threshold: float, mode: str - ) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: search_res = await asyncio.to_thread( self.client.search, collection_name=self.collection_name, @@ -88,6 +86,14 @@ class MilvusIndex(EmbeddingIndex): scores = [res["distance"] for res in search_res[0]] return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_keyword( + self, + query_string: str | None, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in Milvus") + class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__( diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index cadf768e2..150129c5c 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -99,9 +99,7 @@ class PGVectorIndex(EmbeddingIndex): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: execute_values(cur, query, values, template="(%s, %s, %s::vector)") - async def query( - self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str - ) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute( f""" @@ -122,6 +120,14 @@ class PGVectorIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_keyword( + self, + query_string: str | None, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in PGVector") + async def delete(self): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 5638a2831..4357ec03a 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -68,9 +68,7 @@ class QdrantIndex(EmbeddingIndex): await self.client.upsert(collection_name=self.collection_name, points=points) - async def query( - self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str - ) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: results = ( await self.client.query_points( collection_name=self.collection_name, @@ -97,6 +95,14 @@ class QdrantIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_keyword( + self, + query_string: str | None, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in Qdrant") + async def delete(self): await self.client.delete_collection(collection_name=self.collection_name) diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index a633f362b..f0d154b09 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -55,9 +55,7 @@ class WeaviateIndex(EmbeddingIndex): # TODO: make this async friendly collection.data.insert_many(data_objects) - async def query( - self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str - ) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: collection = self.client.collections.get(self.collection_name) results = collection.query.near_vector( @@ -86,6 +84,14 @@ class WeaviateIndex(EmbeddingIndex): collection = self.client.collections.get(self.collection_name) collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids)) + async def query_keyword( + self, + query_string: str | None, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in Weaviate") + class WeaviateVectorIOAdapter( VectorIO, diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 89b2aac57..d915942be 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -177,9 +177,11 @@ class EmbeddingIndex(ABC): raise NotImplementedError() @abstractmethod - async def query( - self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: Optional[str] - ) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + raise NotImplementedError() + + @abstractmethod + async def query_keyword(self, query_string: str | None, k: int, score_threshold: float) -> QueryChunksResponse: raise NotImplementedError() @abstractmethod @@ -215,6 +217,9 @@ class VectorDBWithIndex: mode = params.get("mode") score_threshold = params.get("score_threshold", 0.0) query_string = interleaved_content_as_str(query) - embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string]) - query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) - return await self.index.query(query_vector, query_string, k, score_threshold, mode) + if mode == "keyword": + return await self.index.query_keyword(query_string, k, score_threshold) + else: + embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string]) + query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) + return await self.index.query_vector(query_vector, k, score_threshold) diff --git a/tests/unit/providers/vector_io/test_qdrant.py b/tests/unit/providers/vector_io/test_qdrant.py index bc97719c0..34df9b52f 100644 --- a/tests/unit/providers/vector_io/test_qdrant.py +++ b/tests/unit/providers/vector_io/test_qdrant.py @@ -98,7 +98,7 @@ async def test_qdrant_adapter_returns_expected_chunks( response = await qdrant_adapter.query_chunks( query=__QUERY, vector_db_id=vector_db_id, - params={"max_chunks": max_query_chunks}, + params={"max_chunks": max_query_chunks, "mode": "vector"}, ) assert isinstance(response, QueryChunksResponse) assert len(response.chunks) == expected_chunks diff --git a/tests/unit/providers/vector_io/test_sqlite_vec.py b/tests/unit/providers/vector_io/test_sqlite_vec.py index 282ab6cd0..010a0ca42 100644 --- a/tests/unit/providers/vector_io/test_sqlite_vec.py +++ b/tests/unit/providers/vector_io/test_sqlite_vec.py @@ -60,7 +60,7 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension): await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) query_embedding = np.random.rand(embedding_dimension).astype(np.float32) - response = await sqlite_vec_index.query(query_embedding, query_string="", k=2, score_threshold=0.0, mode="vector") + response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0) assert isinstance(response, QueryChunksResponse) assert len(response.chunks) == 2 @@ -70,16 +70,14 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) query_string = "Sentence 5" - response = await sqlite_vec_index.query( - embedding=None, k=3, score_threshold=0.0, query_string=query_string, mode="keyword" - ) + response = await sqlite_vec_index.query_keyword(k=3, score_threshold=0.0, query_string=query_string) assert isinstance(response, QueryChunksResponse) - assert len(response.chunks) == 3, f"Expected at least one result, but got {len(response.chunks)}" + assert len(response.chunks) == 3, f"Expected three chunks, but got {len(response.chunks)}" non_existent_query_str = "blablabla" - response_no_results = await sqlite_vec_index.query( - embedding=None, query_string=non_existent_query_str, k=1, score_threshold=0.0, mode="keyword" + response_no_results = await sqlite_vec_index.query_keyword( + query_string=non_existent_query_str, k=1, score_threshold=0.0 ) assert isinstance(response_no_results, QueryChunksResponse) @@ -92,12 +90,10 @@ async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_i await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) query_str = "Sentence 1 from document 0" # Should match only one chunk - response = await sqlite_vec_index.query( - embedding=None, k=5, score_threshold=0.0, query_string=query_str, mode="keyword" - ) + response = await sqlite_vec_index.query_keyword(k=5, score_threshold=0.0, query_string=query_str) assert isinstance(response, QueryChunksResponse) - assert 0 < len(response.chunks) < 5, f"Expected <5 results but >0, got {len(response.chunks)}" + assert 0 < len(response.chunks) < 5, f"Expected results between [1, 4], got {len(response.chunks)}" assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found"