mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 00:34:44 +00:00
feat (RAG): Implement configurable search mode in RAGQueryConfig
Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
85b5f3172b
commit
e2a7022d3c
14 changed files with 210 additions and 43 deletions
7
docs/_static/llama-stack-spec.html
vendored
7
docs/_static/llama-stack-spec.html
vendored
|
@ -11601,6 +11601,7 @@
|
||||||
},
|
},
|
||||||
"max_chunks": {
|
"max_chunks": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
|
<<<<<<< HEAD
|
||||||
"default": 5,
|
"default": 5,
|
||||||
"description": "Maximum number of chunks to retrieve."
|
"description": "Maximum number of chunks to retrieve."
|
||||||
},
|
},
|
||||||
|
@ -11608,6 +11609,12 @@
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
|
"default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
|
||||||
"description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
|
"description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
|
||||||
|
=======
|
||||||
|
"default": 5
|
||||||
|
},
|
||||||
|
"mode": {
|
||||||
|
"type": "string"
|
||||||
|
>>>>>>> 1a0433d2 (feat (RAG): Implement configurable search mode in RAGQueryConfig)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
|
5
docs/_static/llama-stack-spec.yaml
vendored
5
docs/_static/llama-stack-spec.yaml
vendored
|
@ -8072,6 +8072,7 @@ components:
|
||||||
max_chunks:
|
max_chunks:
|
||||||
type: integer
|
type: integer
|
||||||
default: 5
|
default: 5
|
||||||
|
<<<<<<< HEAD
|
||||||
description: Maximum number of chunks to retrieve.
|
description: Maximum number of chunks to retrieve.
|
||||||
chunk_template:
|
chunk_template:
|
||||||
type: string
|
type: string
|
||||||
|
@ -8086,6 +8087,10 @@ components:
|
||||||
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
|
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
|
||||||
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
|
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
|
||||||
{chunk.content}\nMetadata: {metadata}\n"
|
{chunk.content}\nMetadata: {metadata}\n"
|
||||||
|
=======
|
||||||
|
mode:
|
||||||
|
type: string
|
||||||
|
>>>>>>> 1a0433d2 (feat (RAG): Implement configurable search mode in RAGQueryConfig)
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- query_generator_config
|
- query_generator_config
|
||||||
|
|
|
@ -66,6 +66,25 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
|
||||||
2. Configure your Llama Stack project to use SQLite-Vec.
|
2. Configure your Llama Stack project to use SQLite-Vec.
|
||||||
3. Start storing and querying vectors.
|
3. Start storing and querying vectors.
|
||||||
|
|
||||||
|
## Supported Search Modes
|
||||||
|
|
||||||
|
The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
|
||||||
|
|
||||||
|
When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in
|
||||||
|
`RAGQueryConfig`. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from llama_stack.apis.tool_runtime.rag import RAGQueryConfig
|
||||||
|
|
||||||
|
query_config = RAGQueryConfig(max_chunks=6, mode="vector")
|
||||||
|
|
||||||
|
results = client.tool_runtime.rag_tool.query(
|
||||||
|
vector_db_ids=[vector_db_id],
|
||||||
|
content="what is torchtune",
|
||||||
|
query_config=query_config,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
You can install SQLite-Vec using pip:
|
You can install SQLite-Vec using pip:
|
||||||
|
|
|
@ -84,6 +84,7 @@ class RAGQueryConfig(BaseModel):
|
||||||
max_tokens_in_context: int = 4096
|
max_tokens_in_context: int = 4096
|
||||||
max_chunks: int = 5
|
max_chunks: int = 5
|
||||||
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
||||||
|
mode: str | None = None
|
||||||
|
|
||||||
@field_validator("chunk_template")
|
@field_validator("chunk_template")
|
||||||
def validate_chunk_template(cls, v: str) -> str:
|
def validate_chunk_template(cls, v: str) -> str:
|
||||||
|
|
|
@ -122,6 +122,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
query=query,
|
query=query,
|
||||||
params={
|
params={
|
||||||
"max_chunks": query_config.max_chunks,
|
"max_chunks": query_config.max_chunks,
|
||||||
|
"mode": query_config.mode,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
for vector_db_id in vector_db_ids
|
for vector_db_id in vector_db_ids
|
||||||
|
|
|
@ -99,9 +99,15 @@ class FaissIndex(EmbeddingIndex):
|
||||||
# Save updated index
|
# Save updated index
|
||||||
await self._save_index()
|
await self._save_index()
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query(
|
||||||
|
self,
|
||||||
|
embedding: NDArray,
|
||||||
|
query_string: Optional[str],
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
mode: Optional[str],
|
||||||
|
) -> QueryChunksResponse:
|
||||||
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
scores = []
|
scores = []
|
||||||
for d, i in zip(distances[0], indices[0], strict=False):
|
for d, i in zip(distances[0], indices[0], strict=False):
|
||||||
|
|
|
@ -24,6 +24,11 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Specifying search mode is dependent on the VectorIO provider.
|
||||||
|
VECTOR_SEARCH = "vector"
|
||||||
|
KEYWORD_SEARCH = "keyword"
|
||||||
|
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
|
||||||
|
|
||||||
|
|
||||||
def serialize_vector(vector: list[float]) -> bytes:
|
def serialize_vector(vector: list[float]) -> bytes:
|
||||||
"""Serialize a list of floats into a compact binary representation."""
|
"""Serialize a list of floats into a compact binary representation."""
|
||||||
|
@ -45,6 +50,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
Two tables are used:
|
Two tables are used:
|
||||||
- A metadata table (chunks_{bank_id}) that holds the chunk JSON.
|
- A metadata table (chunks_{bank_id}) that holds the chunk JSON.
|
||||||
- A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
|
- A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
|
||||||
|
- An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dimension: int, db_path: str, bank_id: str):
|
def __init__(self, dimension: int, db_path: str, bank_id: str):
|
||||||
|
@ -53,6 +59,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
self.bank_id = bank_id
|
self.bank_id = bank_id
|
||||||
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
|
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
|
||||||
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
|
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
|
||||||
|
self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(cls, dimension: int, db_path: str, bank_id: str):
|
async def create(cls, dimension: int, db_path: str, bank_id: str):
|
||||||
|
@ -78,6 +85,14 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
||||||
""")
|
""")
|
||||||
connection.commit()
|
connection.commit()
|
||||||
|
# FTS5 table (for keyword search) - creating both the tables by default. Will use the relevant one
|
||||||
|
# based on query. Implementation of the change on client side will allow passing the search_mode option
|
||||||
|
# during initialization to make it easier to create the table that is required.
|
||||||
|
cur.execute(f"""
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table}
|
||||||
|
USING fts5(id, content);
|
||||||
|
""")
|
||||||
|
connection.commit()
|
||||||
finally:
|
finally:
|
||||||
cur.close()
|
cur.close()
|
||||||
connection.close()
|
connection.close()
|
||||||
|
@ -91,6 +106,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
try:
|
try:
|
||||||
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
|
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
|
||||||
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
|
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
|
||||||
|
cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};")
|
||||||
connection.commit()
|
connection.commit()
|
||||||
finally:
|
finally:
|
||||||
cur.close()
|
cur.close()
|
||||||
|
@ -104,6 +120,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
For each chunk, we insert its JSON into the metadata table and then insert its
|
For each chunk, we insert its JSON into the metadata table and then insert its
|
||||||
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
|
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
|
||||||
If any insert fails, the transaction is rolled back to maintain consistency.
|
If any insert fails, the transaction is rolled back to maintain consistency.
|
||||||
|
Also inserts chunk content into FTS table for keyword search support.
|
||||||
"""
|
"""
|
||||||
assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
|
assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
|
||||||
|
|
||||||
|
@ -112,18 +129,16 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
cur = connection.cursor()
|
cur = connection.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Start transaction a single transcation for all batches
|
|
||||||
cur.execute("BEGIN TRANSACTION")
|
cur.execute("BEGIN TRANSACTION")
|
||||||
for i in range(0, len(chunks), batch_size):
|
for i in range(0, len(chunks), batch_size):
|
||||||
batch_chunks = chunks[i : i + batch_size]
|
batch_chunks = chunks[i : i + batch_size]
|
||||||
batch_embeddings = embeddings[i : i + batch_size]
|
batch_embeddings = embeddings[i : i + batch_size]
|
||||||
# Prepare metadata inserts
|
|
||||||
|
# Insert metadata
|
||||||
metadata_data = [
|
metadata_data = [
|
||||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
|
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
|
||||||
for chunk in batch_chunks
|
for chunk in batch_chunks
|
||||||
if isinstance(chunk.content, str)
|
|
||||||
]
|
]
|
||||||
# Insert metadata (ON CONFLICT to avoid duplicates)
|
|
||||||
cur.executemany(
|
cur.executemany(
|
||||||
f"""
|
f"""
|
||||||
INSERT INTO {self.metadata_table} (id, chunk)
|
INSERT INTO {self.metadata_table} (id, chunk)
|
||||||
|
@ -132,52 +147,108 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
""",
|
""",
|
||||||
metadata_data,
|
metadata_data,
|
||||||
)
|
)
|
||||||
# Prepare embeddings inserts
|
|
||||||
|
# Insert vector embeddings
|
||||||
embedding_data = [
|
embedding_data = [
|
||||||
(
|
(
|
||||||
generate_chunk_id(chunk.metadata["document_id"], chunk.content),
|
(
|
||||||
serialize_vector(emb.tolist()),
|
generate_chunk_id(chunk.metadata["document_id"], chunk.content),
|
||||||
|
serialize_vector(emb.tolist()),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||||
if isinstance(chunk.content, str)
|
|
||||||
]
|
]
|
||||||
# Insert embeddings in batch
|
cur.executemany(
|
||||||
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
|
f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);",
|
||||||
|
embedding_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Insert FTS content
|
||||||
|
fts_data = [
|
||||||
|
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.content)
|
||||||
|
for chunk in batch_chunks
|
||||||
|
]
|
||||||
|
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
|
||||||
|
cur.executemany(
|
||||||
|
f"DELETE FROM {self.fts_table} WHERE id = ?;",
|
||||||
|
[(row[0],) for row in fts_data],
|
||||||
|
)
|
||||||
|
|
||||||
|
# INSERT new entries
|
||||||
|
cur.executemany(
|
||||||
|
f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);",
|
||||||
|
fts_data,
|
||||||
|
)
|
||||||
|
|
||||||
connection.commit()
|
connection.commit()
|
||||||
|
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
connection.rollback() # Rollback on failure
|
connection.rollback()
|
||||||
logger.error(f"Error inserting into {self.vector_table}: {e}")
|
logger.error(f"Error inserting chunk batch: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
cur.close()
|
cur.close()
|
||||||
connection.close()
|
connection.close()
|
||||||
|
|
||||||
# Process all batches in a single thread
|
# Run batch insertion in a background thread
|
||||||
await asyncio.to_thread(_execute_all_batch_inserts)
|
await asyncio.to_thread(_execute_all_batch_inserts)
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query(
|
||||||
|
self,
|
||||||
|
embedding: Optional[NDArray],
|
||||||
|
query_string: Optional[str],
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
mode: Optional[str],
|
||||||
|
) -> QueryChunksResponse:
|
||||||
"""
|
"""
|
||||||
Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query
|
Supports both vector-based and keyword-based searches.
|
||||||
against the virtual table. The SQL joins the metadata table to recover the chunk JSON.
|
|
||||||
|
1. Vector Search (`mode=VECTOR_SEARCH`):
|
||||||
|
Uses a virtual table for vector similarity, joined with metadata.
|
||||||
|
|
||||||
|
2. Keyword Search (`mode=KEYWORD_SEARCH`):
|
||||||
|
Uses SQLite FTS5 for relevance-ranked full-text search.
|
||||||
"""
|
"""
|
||||||
emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
|
|
||||||
emb_blob = serialize_vector(emb_list)
|
|
||||||
|
|
||||||
def _execute_query():
|
def _execute_query():
|
||||||
connection = _create_sqlite_connection(self.db_path)
|
connection = _create_sqlite_connection(self.db_path)
|
||||||
cur = connection.cursor()
|
cur = connection.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query_sql = f"""
|
if mode == VECTOR_SEARCH:
|
||||||
SELECT m.id, m.chunk, v.distance
|
if embedding is None:
|
||||||
FROM {self.vector_table} AS v
|
raise ValueError("embedding is required for vector search.")
|
||||||
JOIN {self.metadata_table} AS m ON m.id = v.id
|
emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
|
||||||
WHERE v.embedding MATCH ? AND k = ?
|
emb_blob = serialize_vector(emb_list)
|
||||||
ORDER BY v.distance;
|
|
||||||
"""
|
query_sql = f"""
|
||||||
cur.execute(query_sql, (emb_blob, k))
|
SELECT m.id, m.chunk, v.distance
|
||||||
|
FROM {self.vector_table} AS v
|
||||||
|
JOIN {self.metadata_table} AS m ON m.id = v.id
|
||||||
|
WHERE v.embedding MATCH ? AND k = ?
|
||||||
|
ORDER BY v.distance;
|
||||||
|
"""
|
||||||
|
cur.execute(query_sql, (emb_blob, k))
|
||||||
|
|
||||||
|
elif mode == KEYWORD_SEARCH:
|
||||||
|
if query_string is None:
|
||||||
|
raise ValueError("query_string is required for keyword search.")
|
||||||
|
|
||||||
|
query_sql = f"""
|
||||||
|
SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
|
||||||
|
FROM {self.fts_table} AS f
|
||||||
|
JOIN {self.metadata_table} AS m ON m.id = f.id
|
||||||
|
WHERE f.content MATCH ?
|
||||||
|
ORDER BY score ASC
|
||||||
|
LIMIT ?;
|
||||||
|
"""
|
||||||
|
cur.execute(query_sql, (query_string, k))
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid search_mode: {mode} please select from {SEARCH_MODES}")
|
||||||
|
|
||||||
return cur.fetchall()
|
return cur.fetchall()
|
||||||
finally:
|
finally:
|
||||||
cur.close()
|
cur.close()
|
||||||
|
@ -186,16 +257,25 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
rows = await asyncio.to_thread(_execute_query)
|
rows = await asyncio.to_thread(_execute_query)
|
||||||
|
|
||||||
chunks, scores = [], []
|
chunks, scores = [], []
|
||||||
for _id, chunk_json, distance in rows:
|
for row in rows:
|
||||||
|
if mode == VECTOR_SEARCH:
|
||||||
|
_id, chunk_json, distance = row
|
||||||
|
score = 1.0 / distance if distance != 0 else float("inf")
|
||||||
|
|
||||||
|
if score < score_threshold:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
_id, chunk_json, score = row
|
||||||
|
|
||||||
try:
|
try:
|
||||||
chunk = Chunk.model_validate_json(chunk_json)
|
chunk = Chunk.model_validate_json(chunk_json)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
# Mimic the Faiss scoring: score = 1/distance (avoid division by zero)
|
|
||||||
score = 1.0 / distance if distance != 0 else float("inf")
|
|
||||||
scores.append(score)
|
scores.append(score)
|
||||||
|
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,9 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
|
||||||
|
) -> QueryChunksResponse:
|
||||||
results = await maybe_await(
|
results = await maybe_await(
|
||||||
self.collection.query(
|
self.collection.query(
|
||||||
query_embeddings=[embedding.tolist()],
|
query_embeddings=[embedding.tolist()],
|
||||||
|
|
|
@ -73,7 +73,9 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, query_str: Optional[str], k: int, score_threshold: float, mode: str
|
||||||
|
) -> QueryChunksResponse:
|
||||||
search_res = await asyncio.to_thread(
|
search_res = await asyncio.to_thread(
|
||||||
self.client.search,
|
self.client.search,
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
|
|
|
@ -99,7 +99,9 @@ class PGVectorIndex(EmbeddingIndex):
|
||||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||||
execute_values(cur, query, values, template="(%s, %s, %s::vector)")
|
execute_values(cur, query, values, template="(%s, %s, %s::vector)")
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
|
||||||
|
) -> QueryChunksResponse:
|
||||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
f"""
|
f"""
|
||||||
|
|
|
@ -68,7 +68,9 @@ class QdrantIndex(EmbeddingIndex):
|
||||||
|
|
||||||
await self.client.upsert(collection_name=self.collection_name, points=points)
|
await self.client.upsert(collection_name=self.collection_name, points=points)
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
|
||||||
|
) -> QueryChunksResponse:
|
||||||
results = (
|
results = (
|
||||||
await self.client.query_points(
|
await self.client.query_points(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
|
|
|
@ -55,7 +55,9 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
# TODO: make this async friendly
|
# TODO: make this async friendly
|
||||||
collection.data.insert_many(data_objects)
|
collection.data.insert_many(data_objects)
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: str
|
||||||
|
) -> QueryChunksResponse:
|
||||||
collection = self.client.collections.get(self.collection_name)
|
collection = self.client.collections.get(self.collection_name)
|
||||||
|
|
||||||
results = collection.query.near_vector(
|
results = collection.query.near_vector(
|
||||||
|
|
|
@ -177,7 +177,9 @@ class EmbeddingIndex(ABC):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, query_string: Optional[str], k: int, score_threshold: float, mode: Optional[str]
|
||||||
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -210,9 +212,9 @@ class VectorDBWithIndex:
|
||||||
if params is None:
|
if params is None:
|
||||||
params = {}
|
params = {}
|
||||||
k = params.get("max_chunks", 3)
|
k = params.get("max_chunks", 3)
|
||||||
|
mode = params.get("mode")
|
||||||
score_threshold = params.get("score_threshold", 0.0)
|
score_threshold = params.get("score_threshold", 0.0)
|
||||||
|
query_string = interleaved_content_as_str(query)
|
||||||
query_str = interleaved_content_as_str(query)
|
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
||||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_str])
|
|
||||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||||
return await self.index.query(query_vector, k, score_threshold)
|
return await self.index.query(query_vector, query_string, k, score_threshold, mode)
|
||||||
|
|
|
@ -57,14 +57,50 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
|
async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||||
response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
|
response = await sqlite_vec_index.query(query_embedding, query_string="", k=2, score_threshold=0.0, mode="vector")
|
||||||
assert isinstance(response, QueryChunksResponse)
|
assert isinstance(response, QueryChunksResponse)
|
||||||
assert len(response.chunks) == 2
|
assert len(response.chunks) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
query_string = "Sentence 5"
|
||||||
|
response = await sqlite_vec_index.query(
|
||||||
|
embedding=None, k=3, score_threshold=0.0, query_string=query_string, mode="keyword"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 3, f"Expected at least one result, but got {len(response.chunks)}"
|
||||||
|
|
||||||
|
non_existent_query_str = "blablabla"
|
||||||
|
response_no_results = await sqlite_vec_index.query(
|
||||||
|
embedding=None, query_string=non_existent_query_str, k=1, score_threshold=0.0, mode="keyword"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response_no_results, QueryChunksResponse)
|
||||||
|
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
|
# Re-initialize with a clean index
|
||||||
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
query_str = "Sentence 1 from document 0" # Should match only one chunk
|
||||||
|
response = await sqlite_vec_index.query(
|
||||||
|
embedding=None, k=5, score_threshold=0.0, query_string=query_str, mode="keyword"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert 0 < len(response.chunks) < 5, f"Expected <5 results but >0, got {len(response.chunks)}"
|
||||||
|
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
|
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
|
||||||
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
|
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue