mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-18 02:42:31 +00:00
fix: SQLiteVecIndex.create(..., bank_id="test_bank.123") - bank_id with a dot - leads to sqlite3.OperationalError (#2770) (#2771)
# What does this PR do? Resolves https://github.com/meta-llama/llama-stack/issues/2770. It replaces characters in SQLite table names that are not alphanumeric or underscores with underscores and quotes the table names with square brackets in SQL statements. Closes #[2770] ## Test Plan I added a ".123" suffix to the bank_id on the following line ``` index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank.123") ``` in tests/unit/providers/vector_io/test_sqlite_vec.py, which, without the fix in place, demonstrates the issue.
This commit is contained in:
parent
72e606355d
commit
30be1fd8b7
2 changed files with 25 additions and 20 deletions
|
@ -7,6 +7,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import struct
|
import struct
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
@ -117,6 +118,10 @@ def _rrf_rerank(
|
||||||
return rrf_scores
|
return rrf_scores
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sql_identifier(name: str) -> str:
|
||||||
|
return re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
||||||
|
|
||||||
|
|
||||||
class SQLiteVecIndex(EmbeddingIndex):
|
class SQLiteVecIndex(EmbeddingIndex):
|
||||||
"""
|
"""
|
||||||
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
|
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
|
||||||
|
@ -130,9 +135,9 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
self.dimension = dimension
|
self.dimension = dimension
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
self.bank_id = bank_id
|
self.bank_id = bank_id
|
||||||
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
|
self.metadata_table = _make_sql_identifier(f"chunks_{bank_id}")
|
||||||
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
|
self.vector_table = _make_sql_identifier(f"vec_chunks_{bank_id}")
|
||||||
self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
|
self.fts_table = _make_sql_identifier(f"fts_chunks_{bank_id}")
|
||||||
self.kvstore = kvstore
|
self.kvstore = kvstore
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -148,14 +153,14 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
try:
|
try:
|
||||||
# Create the table to store chunk metadata.
|
# Create the table to store chunk metadata.
|
||||||
cur.execute(f"""
|
cur.execute(f"""
|
||||||
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
|
CREATE TABLE IF NOT EXISTS [{self.metadata_table}] (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
chunk TEXT
|
chunk TEXT
|
||||||
);
|
);
|
||||||
""")
|
""")
|
||||||
# Create the virtual table for embeddings.
|
# Create the virtual table for embeddings.
|
||||||
cur.execute(f"""
|
cur.execute(f"""
|
||||||
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
|
CREATE VIRTUAL TABLE IF NOT EXISTS [{self.vector_table}]
|
||||||
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
||||||
""")
|
""")
|
||||||
connection.commit()
|
connection.commit()
|
||||||
|
@ -163,7 +168,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
# based on query. Implementation of the change on client side will allow passing the search_mode option
|
# based on query. Implementation of the change on client side will allow passing the search_mode option
|
||||||
# during initialization to make it easier to create the table that is required.
|
# during initialization to make it easier to create the table that is required.
|
||||||
cur.execute(f"""
|
cur.execute(f"""
|
||||||
CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table}
|
CREATE VIRTUAL TABLE IF NOT EXISTS [{self.fts_table}]
|
||||||
USING fts5(id, content);
|
USING fts5(id, content);
|
||||||
""")
|
""")
|
||||||
connection.commit()
|
connection.commit()
|
||||||
|
@ -178,9 +183,9 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
connection = _create_sqlite_connection(self.db_path)
|
connection = _create_sqlite_connection(self.db_path)
|
||||||
cur = connection.cursor()
|
cur = connection.cursor()
|
||||||
try:
|
try:
|
||||||
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
|
cur.execute(f"DROP TABLE IF EXISTS [{self.metadata_table}];")
|
||||||
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
|
cur.execute(f"DROP TABLE IF EXISTS [{self.vector_table}];")
|
||||||
cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};")
|
cur.execute(f"DROP TABLE IF EXISTS [{self.fts_table}];")
|
||||||
connection.commit()
|
connection.commit()
|
||||||
finally:
|
finally:
|
||||||
cur.close()
|
cur.close()
|
||||||
|
@ -212,7 +217,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks]
|
metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks]
|
||||||
cur.executemany(
|
cur.executemany(
|
||||||
f"""
|
f"""
|
||||||
INSERT INTO {self.metadata_table} (id, chunk)
|
INSERT INTO [{self.metadata_table}] (id, chunk)
|
||||||
VALUES (?, ?)
|
VALUES (?, ?)
|
||||||
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
|
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
|
||||||
""",
|
""",
|
||||||
|
@ -230,7 +235,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||||
]
|
]
|
||||||
cur.executemany(
|
cur.executemany(
|
||||||
f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);",
|
f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
|
||||||
embedding_data,
|
embedding_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -238,13 +243,13 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
|
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
|
||||||
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
|
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
|
||||||
cur.executemany(
|
cur.executemany(
|
||||||
f"DELETE FROM {self.fts_table} WHERE id = ?;",
|
f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
|
||||||
[(row[0],) for row in fts_data],
|
[(row[0],) for row in fts_data],
|
||||||
)
|
)
|
||||||
|
|
||||||
# INSERT new entries
|
# INSERT new entries
|
||||||
cur.executemany(
|
cur.executemany(
|
||||||
f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);",
|
f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
|
||||||
fts_data,
|
fts_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -280,8 +285,8 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
emb_blob = serialize_vector(emb_list)
|
emb_blob = serialize_vector(emb_list)
|
||||||
query_sql = f"""
|
query_sql = f"""
|
||||||
SELECT m.id, m.chunk, v.distance
|
SELECT m.id, m.chunk, v.distance
|
||||||
FROM {self.vector_table} AS v
|
FROM [{self.vector_table}] AS v
|
||||||
JOIN {self.metadata_table} AS m ON m.id = v.id
|
JOIN [{self.metadata_table}] AS m ON m.id = v.id
|
||||||
WHERE v.embedding MATCH ? AND k = ?
|
WHERE v.embedding MATCH ? AND k = ?
|
||||||
ORDER BY v.distance;
|
ORDER BY v.distance;
|
||||||
"""
|
"""
|
||||||
|
@ -322,9 +327,9 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
cur = connection.cursor()
|
cur = connection.cursor()
|
||||||
try:
|
try:
|
||||||
query_sql = f"""
|
query_sql = f"""
|
||||||
SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
|
SELECT DISTINCT m.id, m.chunk, bm25([{self.fts_table}]) AS score
|
||||||
FROM {self.fts_table} AS f
|
FROM [{self.fts_table}] AS f
|
||||||
JOIN {self.metadata_table} AS m ON m.id = f.id
|
JOIN [{self.metadata_table}] AS m ON m.id = f.id
|
||||||
WHERE f.content MATCH ?
|
WHERE f.content MATCH ?
|
||||||
ORDER BY score ASC
|
ORDER BY score ASC
|
||||||
LIMIT ?;
|
LIMIT ?;
|
||||||
|
|
|
@ -37,7 +37,7 @@ def loop():
|
||||||
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
|
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
|
||||||
temp_dir = tmp_path_factory.getbasetemp()
|
temp_dir = tmp_path_factory.getbasetemp()
|
||||||
db_path = str(temp_dir / "test_sqlite.db")
|
db_path = str(temp_dir / "test_sqlite.db")
|
||||||
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
|
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank.123")
|
||||||
yield index
|
yield index
|
||||||
await index.delete()
|
await index.delete()
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
|
||||||
cur = connection.cursor()
|
cur = connection.cursor()
|
||||||
|
|
||||||
# Retrieve all chunk IDs to check for duplicates
|
# Retrieve all chunk IDs to check for duplicates
|
||||||
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
|
cur.execute(f"SELECT id FROM [{sqlite_vec_index.metadata_table}]")
|
||||||
chunk_ids = [row[0] for row in cur.fetchall()]
|
chunk_ids = [row[0] for row in cur.fetchall()]
|
||||||
cur.close()
|
cur.close()
|
||||||
connection.close()
|
connection.close()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue