From e2a7022d3c409aac8283f44f65230a5dcfd8ff71 Mon Sep 17 00:00:00 2001
From: Varsha Prasad Narsing
Date: Mon, 14 Apr 2025 16:53:17 -0700
Subject: [PATCH] feat (RAG): Implement configurable search mode in RAGQueryConfig

Signed-off-by: Varsha Prasad Narsing
---
 docs/_static/llama-stack-spec.html            |   5 ++++-
 docs/_static/llama-stack-spec.yaml            |   2 ++
 docs/source/providers/vector_io/sqlite-vec.md |  19 +++
 llama_stack/apis/tools/rag_tool.py            |   1 +
 .../inline/tool_runtime/rag/memory.py         |   1 +
 .../providers/inline/vector_io/faiss/faiss.py |  10 +-
 .../inline/vector_io/sqlite_vec/sqlite_vec.py | 138 ++++++++++++++----
 .../remote/vector_io/chroma/chroma.py         |   4 +-
 .../remote/vector_io/milvus/milvus.py         |   4 +-
 .../remote/vector_io/pgvector/pgvector.py     |   4 +-
 .../remote/vector_io/qdrant/qdrant.py         |   4 +-
 .../remote/vector_io/weaviate/weaviate.py     |   4 +-
 .../providers/utils/memory/vector_store.py    |  12 +-
 .../providers/vector_io/test_sqlite_vec.py    |  40 ++++-
 14 files changed, 204 insertions(+), 44 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 6adfe9b2b..3f2599a57 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -11601,13 +11601,16 @@
         },
         "max_chunks": {
           "type": "integer",
           "default": 5,
           "description": "Maximum number of chunks to retrieve."
         },
         "chunk_template": {
           "type": "string",
           "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
           "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
-        }
+        },
+        "mode": {
+          "type": "string"
+        }
       },
       "additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 31ca3f52a..b7da34c23 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -8072,16 +8072,18 @@ components:
         max_chunks:
           type: integer
           default: 5
           description: Maximum number of chunks to retrieve.
         chunk_template:
           type: string
           default: >
             "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
           description: >-
             Template for formatting each retrieved chunk in the context. Available
             placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
             content string), {metadata} (chunk metadata dict). Default: "Result
             {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
+        mode:
+          type: string
       additionalProperties: false
       required:
         - query_generator_config
diff --git a/docs/source/providers/vector_io/sqlite-vec.md b/docs/source/providers/vector_io/sqlite-vec.md
index 43d10c751..f2b0baf4e 100644
--- a/docs/source/providers/vector_io/sqlite-vec.md
+++ b/docs/source/providers/vector_io/sqlite-vec.md
@@ -66,6 +66,25 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
 2. Configure your Llama Stack project to use SQLite-Vec.
 3. Start storing and querying vectors.
 
+## Supported Search Modes
+
+The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
+
+When using the RAGTool interface, you can specify the desired search behavior via the `mode`
+parameter in `RAGQueryConfig`. Valid values are `"vector"` (the default) and `"keyword"`.
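+
+For a keyword (full-text) query, only the `mode` value changes (a minimal sketch; the client
+setup is the same as in the vector example below):
+
+```python
+query_config = RAGQueryConfig(max_chunks=6, mode="keyword")
+```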
+
+A full example using vector search mode:
+
+```python
+from llama_stack.apis.tools.rag_tool import RAGQueryConfig
+
+query_config = RAGQueryConfig(max_chunks=6, mode="vector")
+
+results = client.tool_runtime.rag_tool.query(
+    vector_db_ids=[vector_db_id],
+    content="what is torchtune",
+    query_config=query_config,
+)
+```
+
 ## Installation
 
 You can install SQLite-Vec using pip:
diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py
index de3e4c62c..5d5280205 100644
--- a/llama_stack/apis/tools/rag_tool.py
+++ b/llama_stack/apis/tools/rag_tool.py
@@ -84,6 +84,7 @@ class RAGQueryConfig(BaseModel):
     max_tokens_in_context: int = 4096
     max_chunks: int = 5
     chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
+    mode: str | None = None
 
     @field_validator("chunk_template")
     def validate_chunk_template(cls, v: str) -> str:
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index c46960f75..fe16c76b8 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -122,6 +122,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                     query=query,
                     params={
                         "max_chunks": query_config.max_chunks,
+                        "mode": query_config.mode,
                     },
                 )
                 for vector_db_id in vector_db_ids
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index d3dc7e694..ef9ca2855 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -99,9 +99,15 @@ class FaissIndex(EmbeddingIndex):
         # Save updated index
         await self._save_index()
 
-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query(
+        self,
+        embedding: NDArray,
+        query_string: Optional[str],
+        k: int,
+        score_threshold: float,
+        mode: Optional[str],
+    ) -> QueryChunksResponse:
         distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
-
         chunks = []
         scores = []
         for d, i in zip(distances[0], indices[0], strict=False):
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index ab4384021..c351d7e2e 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -24,6 +24,11 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect
 
 logger = logging.getLogger(__name__)
 
+# Supported search modes; which modes are available depends on the VectorIO provider.
+VECTOR_SEARCH = "vector"
+KEYWORD_SEARCH = "keyword"
+SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
+
 
 def serialize_vector(vector: list[float]) -> bytes:
     """Serialize a list of floats into a compact binary representation."""
@@ -45,6 +50,7 @@ class SQLiteVecIndex(EmbeddingIndex):
-    Two tables are used:
+    Three tables are used:
     - A metadata table (chunks_{bank_id}) that holds the chunk JSON.
     - A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
+    - An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search.
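+
+    The `mode` argument passed to `query` selects which table serves a request at query time.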
""" def __init__(self, dimension: int, db_path: str, bank_id: str): @@ -53,6 +59,7 @@ class SQLiteVecIndex(EmbeddingIndex): self.bank_id = bank_id self.metadata_table = f"chunks_{bank_id}".replace("-", "_") self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_") + self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_") @classmethod async def create(cls, dimension: int, db_path: str, bank_id: str): @@ -78,6 +85,14 @@ class SQLiteVecIndex(EmbeddingIndex): USING vec0(embedding FLOAT[{self.dimension}], id TEXT); """) connection.commit() + # FTS5 table (for keyword search) - creating both the tables by default. Will use the relevant one + # based on query. Implementation of the change on client side will allow passing the search_mode option + # during initialization to make it easier to create the table that is required. + cur.execute(f""" + CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table} + USING fts5(id, content); + """) + connection.commit() finally: cur.close() connection.close() @@ -91,6 +106,7 @@ class SQLiteVecIndex(EmbeddingIndex): try: cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};") cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};") + cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};") connection.commit() finally: cur.close() @@ -104,6 +120,7 @@ class SQLiteVecIndex(EmbeddingIndex): For each chunk, we insert its JSON into the metadata table and then insert its embedding (serialized to raw bytes) into the virtual table using the assigned rowid. If any insert fails, the transaction is rolled back to maintain consistency. + Also inserts chunk content into FTS table for keyword search support. """ assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks" @@ -112,18 +129,16 @@ class SQLiteVecIndex(EmbeddingIndex): cur = connection.cursor() try: - # Start transaction a single transcation for all batches cur.execute("BEGIN TRANSACTION") for i in range(0, len(chunks), batch_size): batch_chunks = chunks[i : i + batch_size] batch_embeddings = embeddings[i : i + batch_size] - # Prepare metadata inserts + + # Insert metadata metadata_data = [ (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json()) for chunk in batch_chunks - if isinstance(chunk.content, str) ] - # Insert metadata (ON CONFLICT to avoid duplicates) cur.executemany( f""" INSERT INTO {self.metadata_table} (id, chunk) @@ -132,52 +147,108 @@ class SQLiteVecIndex(EmbeddingIndex): """, metadata_data, ) - # Prepare embeddings inserts + + # Insert vector embeddings embedding_data = [ ( - generate_chunk_id(chunk.metadata["document_id"], chunk.content), - serialize_vector(emb.tolist()), + ( + generate_chunk_id(chunk.metadata["document_id"], chunk.content), + serialize_vector(emb.tolist()), + ) ) for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True) - if isinstance(chunk.content, str) ] - # Insert embeddings in batch - cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data) + cur.executemany( + f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", + embedding_data, + ) + + # Insert FTS content + fts_data = [ + (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.content) + for chunk in batch_chunks + ] + # DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT) + cur.executemany( + f"DELETE FROM {self.fts_table} WHERE id = ?;", + [(row[0],) for row in fts_data], + ) + + # INSERT new entries + cur.executemany( + f"INSERT INTO 
+                        f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);",
+                        fts_data,
+                    )
+
                 connection.commit()
             except sqlite3.Error as e:
-                connection.rollback()  # Rollback on failure
-                logger.error(f"Error inserting into {self.vector_table}: {e}")
+                connection.rollback()
+                logger.error(f"Error inserting chunk batch: {e}")
                 raise
             finally:
                 cur.close()
                 connection.close()
 
-        # Process all batches in a single thread
+        # Run batch insertion in a background thread
         await asyncio.to_thread(_execute_all_batch_inserts)
 
-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query(
+        self,
+        embedding: Optional[NDArray],
+        query_string: Optional[str],
+        k: int,
+        score_threshold: float,
+        mode: Optional[str],
+    ) -> QueryChunksResponse:
         """
-        Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query
-        against the virtual table. The SQL joins the metadata table to recover the chunk JSON.
+        Supports both vector-based and keyword-based searches.
+
+        1. Vector Search (`mode=VECTOR_SEARCH`):
+           Uses a virtual table for vector similarity, joined with metadata.
+
+        2. Keyword Search (`mode=KEYWORD_SEARCH`):
+           Uses SQLite FTS5 for relevance-ranked full-text search.
         """
-        emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
-        emb_blob = serialize_vector(emb_list)
+        # Fall back to vector search when no mode is given, preserving the previous
+        # default behavior for callers that do not set RAGQueryConfig.mode.
+        mode = mode or VECTOR_SEARCH
 
         def _execute_query():
             connection = _create_sqlite_connection(self.db_path)
             cur = connection.cursor()
             try:
-                query_sql = f"""
-                    SELECT m.id, m.chunk, v.distance
-                    FROM {self.vector_table} AS v
-                    JOIN {self.metadata_table} AS m ON m.id = v.id
-                    WHERE v.embedding MATCH ? AND k = ?
-                    ORDER BY v.distance;
-                """
-                cur.execute(query_sql, (emb_blob, k))
+                if mode == VECTOR_SEARCH:
+                    if embedding is None:
+                        raise ValueError("embedding is required for vector search.")
+                    emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
+                    emb_blob = serialize_vector(emb_list)
+
+                    query_sql = f"""
+                        SELECT m.id, m.chunk, v.distance
+                        FROM {self.vector_table} AS v
+                        JOIN {self.metadata_table} AS m ON m.id = v.id
+                        WHERE v.embedding MATCH ? AND k = ?
+                        ORDER BY v.distance;
+                    """
+                    cur.execute(query_sql, (emb_blob, k))
+
+                elif mode == KEYWORD_SEARCH:
+                    if query_string is None:
+                        raise ValueError("query_string is required for keyword search.")
+
+                    query_sql = f"""
+                        SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
+                        FROM {self.fts_table} AS f
+                        JOIN {self.metadata_table} AS m ON m.id = f.id
+                        WHERE f.content MATCH ?
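+                        -- bm25() returns more-negative scores for better matches,
+                        -- so ascending order ranks the most relevant chunks first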
+                        ORDER BY score ASC
+                        LIMIT ?;
+                    """
+                    cur.execute(query_sql, (query_string, k))
+
+                else:
+                    raise ValueError(f"Invalid search mode: {mode}; please choose one of {SEARCH_MODES}")
+
                 return cur.fetchall()
             finally:
                 cur.close()
                 connection.close()
 
@@ -186,16 +257,25 @@
         rows = await asyncio.to_thread(_execute_query)
         chunks, scores = [], []
-        for _id, chunk_json, distance in rows:
+        for row in rows:
+            if mode == VECTOR_SEARCH:
+                _id, chunk_json, distance = row
+                # Mimic the Faiss scoring: score = 1/distance (avoid division by zero)
+                score = 1.0 / distance if distance != 0 else float("inf")
+                if score < score_threshold:
+                    continue
+            else:
+                _id, chunk_json, score = row
+
             try:
                 chunk = Chunk.model_validate_json(chunk_json)
             except Exception as e:
                 logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
                 continue
+
             chunks.append(chunk)
-            # Mimic the Faiss scoring: score = 1/distance (avoid division by zero)
-            score = 1.0 / distance if distance != 0 else float("inf")
             scores.append(score)
+
         return QueryChunksResponse(chunks=chunks, scores=scores)
diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index a919963ab..669adc8ca 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -55,7 +55,9 @@ class ChromaIndex(EmbeddingIndex):
             )
         )
 
-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query(
+        self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: Optional[str]
+    ) -> QueryChunksResponse:
         results = await maybe_await(
             self.collection.query(
                 query_embeddings=[embedding.tolist()],
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index c98417b56..8f186611d 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -73,7 +73,9 @@ class MilvusIndex(EmbeddingIndex):
             logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
             raise e
 
-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query(
+        self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: Optional[str]
+    ) -> QueryChunksResponse:
         search_res = await asyncio.to_thread(
             self.client.search,
             collection_name=self.collection_name,
diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
index 94546c6cf..cadf768e2 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
@@ -99,7 +99,9 @@ class PGVectorIndex(EmbeddingIndex):
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             execute_values(cur, query, values, template="(%s, %s, %s::vector)")
 
-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query(
+        self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: Optional[str]
+    ) -> QueryChunksResponse:
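+        # NOTE: query_string and mode are accepted to satisfy the shared EmbeddingIndex
+        # interface; the remote providers in this patch still perform vector search only.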
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             cur.execute(
                 f"""
diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
index 514a6c70d..5638a2831 100644
--- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
+++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
@@ -68,7 +68,9 @@ class QdrantIndex(EmbeddingIndex):
 
         await self.client.upsert(collection_name=self.collection_name, points=points)
 
-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query(
+        self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: Optional[str]
+    ) -> QueryChunksResponse:
         results = (
             await self.client.query_points(
                 collection_name=self.collection_name,
diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
index 308d2eb3d..a633f362b 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
@@ -55,7 +55,9 @@ class WeaviateIndex(EmbeddingIndex):
         # TODO: make this async friendly
         collection.data.insert_many(data_objects)
 
-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query(
+        self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: Optional[str]
+    ) -> QueryChunksResponse:
         collection = self.client.collections.get(self.collection_name)
 
         results = collection.query.near_vector(
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index e0e9d0679..89b2aac57 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -177,7 +177,9 @@ class EmbeddingIndex(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query(
+        self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: Optional[str]
+    ) -> QueryChunksResponse:
         raise NotImplementedError()
 
     @abstractmethod
@@ -210,9 +212,9 @@ class VectorDBWithIndex:
         if params is None:
             params = {}
         k = params.get("max_chunks", 3)
+        mode = params.get("mode")
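+        # The mode value is forwarded to the index as-is; vector-only providers ignore it,
+        # and the sqlite-vec provider treats a missing mode as vector search.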
         score_threshold = params.get("score_threshold", 0.0)
-
-        query_str = interleaved_content_as_str(query)
-        embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_str])
+        query_string = interleaved_content_as_str(query)
+        embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
         query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
-        return await self.index.query(query_vector, k, score_threshold)
+        return await self.index.query(query_vector, query_string, k, score_threshold, mode)
diff --git a/tests/unit/providers/vector_io/test_sqlite_vec.py b/tests/unit/providers/vector_io/test_sqlite_vec.py
index 32b60ffa5..282ab6cd0 100644
--- a/tests/unit/providers/vector_io/test_sqlite_vec.py
+++ b/tests/unit/providers/vector_io/test_sqlite_vec.py
@@ -57,14 +57,50 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
 
 
 @pytest.mark.asyncio
-async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
+async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
     await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
     query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
-    response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
+    response = await sqlite_vec_index.query(query_embedding, query_string="", k=2, score_threshold=0.0, mode="vector")
     assert isinstance(response, QueryChunksResponse)
     assert len(response.chunks) == 2
 
 
+@pytest.mark.asyncio
+async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+
+    query_string = "Sentence 5"
+    response = await sqlite_vec_index.query(
+        embedding=None, k=3, score_threshold=0.0, query_string=query_string, mode="keyword"
+    )
+
+    assert isinstance(response, QueryChunksResponse)
+    assert len(response.chunks) == 3, f"Expected 3 results, but got {len(response.chunks)}"
+
+    non_existent_query_str = "blablabla"
+    response_no_results = await sqlite_vec_index.query(
+        embedding=None, query_string=non_existent_query_str, k=1, score_threshold=0.0, mode="keyword"
+    )
+
+    assert isinstance(response_no_results, QueryChunksResponse)
+    assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_full_text_search_k_greater_than_results(
+    sqlite_vec_index, sample_chunks, sample_embeddings
+):
+    # Re-initialize with a clean index
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+
+    query_str = "Sentence 1 from document 0"  # Should match only one chunk
+    response = await sqlite_vec_index.query(
+        embedding=None, k=5, score_threshold=0.0, query_string=query_str, mode="keyword"
+    )
+
+    assert isinstance(response, QueryChunksResponse)
+    assert 0 < len(response.chunks) < 5, f"Expected between 1 and 4 results, got {len(response.chunks)}"
+    assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found"
+
+
 @pytest.mark.asyncio
 async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
     """Test that chunk IDs do not conflict across batches when inserting chunks."""