mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-17 18:38:11 +00:00
fix: SQLiteVecIndex.create(..., bank_id="test_bank.123") - bank_id with a dot - leads to sqlite3.OperationalError (#2770) (#2771)
# What does this PR do? Resolves https://github.com/meta-llama/llama-stack/issues/2770. It replaces characters in SQLite table names that are not alphanumeric or underscores with underscores and quotes the table names with square brackets in SQL statements. Closes #[2770] ## Test Plan I added a ".123" suffix to the bank_id on the following line ``` index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank.123") ``` in tests/unit/providers/vector_io/test_sqlite_vec.py, which, without the fix in place, demonstrates the issue.
This commit is contained in:
parent
72e606355d
commit
30be1fd8b7
2 changed files with 25 additions and 20 deletions
|
@ -7,6 +7,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
import struct
|
||||
from typing import Any
|
||||
|
@ -117,6 +118,10 @@ def _rrf_rerank(
|
|||
return rrf_scores
|
||||
|
||||
|
||||
def _make_sql_identifier(name: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
||||
|
||||
|
||||
class SQLiteVecIndex(EmbeddingIndex):
|
||||
"""
|
||||
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
|
||||
|
@ -130,9 +135,9 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
self.dimension = dimension
|
||||
self.db_path = db_path
|
||||
self.bank_id = bank_id
|
||||
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
|
||||
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
|
||||
self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
|
||||
self.metadata_table = _make_sql_identifier(f"chunks_{bank_id}")
|
||||
self.vector_table = _make_sql_identifier(f"vec_chunks_{bank_id}")
|
||||
self.fts_table = _make_sql_identifier(f"fts_chunks_{bank_id}")
|
||||
self.kvstore = kvstore
|
||||
|
||||
@classmethod
|
||||
|
@ -148,14 +153,14 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
try:
|
||||
# Create the table to store chunk metadata.
|
||||
cur.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
|
||||
CREATE TABLE IF NOT EXISTS [{self.metadata_table}] (
|
||||
id TEXT PRIMARY KEY,
|
||||
chunk TEXT
|
||||
);
|
||||
""")
|
||||
# Create the virtual table for embeddings.
|
||||
cur.execute(f"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS [{self.vector_table}]
|
||||
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
||||
""")
|
||||
connection.commit()
|
||||
|
@ -163,7 +168,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
# based on query. Implementation of the change on client side will allow passing the search_mode option
|
||||
# during initialization to make it easier to create the table that is required.
|
||||
cur.execute(f"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table}
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS [{self.fts_table}]
|
||||
USING fts5(id, content);
|
||||
""")
|
||||
connection.commit()
|
||||
|
@ -178,9 +183,9 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};")
|
||||
cur.execute(f"DROP TABLE IF EXISTS [{self.metadata_table}];")
|
||||
cur.execute(f"DROP TABLE IF EXISTS [{self.vector_table}];")
|
||||
cur.execute(f"DROP TABLE IF EXISTS [{self.fts_table}];")
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
|
@ -212,7 +217,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks]
|
||||
cur.executemany(
|
||||
f"""
|
||||
INSERT INTO {self.metadata_table} (id, chunk)
|
||||
INSERT INTO [{self.metadata_table}] (id, chunk)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
|
||||
""",
|
||||
|
@ -230,7 +235,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
]
|
||||
cur.executemany(
|
||||
f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);",
|
||||
f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
|
||||
embedding_data,
|
||||
)
|
||||
|
||||
|
@ -238,13 +243,13 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
|
||||
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
|
||||
cur.executemany(
|
||||
f"DELETE FROM {self.fts_table} WHERE id = ?;",
|
||||
f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
|
||||
[(row[0],) for row in fts_data],
|
||||
)
|
||||
|
||||
# INSERT new entries
|
||||
cur.executemany(
|
||||
f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);",
|
||||
f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
|
||||
fts_data,
|
||||
)
|
||||
|
||||
|
@ -280,8 +285,8 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
emb_blob = serialize_vector(emb_list)
|
||||
query_sql = f"""
|
||||
SELECT m.id, m.chunk, v.distance
|
||||
FROM {self.vector_table} AS v
|
||||
JOIN {self.metadata_table} AS m ON m.id = v.id
|
||||
FROM [{self.vector_table}] AS v
|
||||
JOIN [{self.metadata_table}] AS m ON m.id = v.id
|
||||
WHERE v.embedding MATCH ? AND k = ?
|
||||
ORDER BY v.distance;
|
||||
"""
|
||||
|
@ -322,9 +327,9 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
cur = connection.cursor()
|
||||
try:
|
||||
query_sql = f"""
|
||||
SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
|
||||
FROM {self.fts_table} AS f
|
||||
JOIN {self.metadata_table} AS m ON m.id = f.id
|
||||
SELECT DISTINCT m.id, m.chunk, bm25([{self.fts_table}]) AS score
|
||||
FROM [{self.fts_table}] AS f
|
||||
JOIN [{self.metadata_table}] AS m ON m.id = f.id
|
||||
WHERE f.content MATCH ?
|
||||
ORDER BY score ASC
|
||||
LIMIT ?;
|
||||
|
|
|
@ -37,7 +37,7 @@ def loop():
|
|||
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / "test_sqlite.db")
|
||||
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
|
||||
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank.123")
|
||||
yield index
|
||||
await index.delete()
|
||||
|
||||
|
@ -110,7 +110,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
|
|||
cur = connection.cursor()
|
||||
|
||||
# Retrieve all chunk IDs to check for duplicates
|
||||
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
|
||||
cur.execute(f"SELECT id FROM [{sqlite_vec_index.metadata_table}]")
|
||||
chunk_ids = [row[0] for row in cur.fetchall()]
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue