fix: SQLiteVecIndex.create(..., bank_id="test_bank.123") - bank_id with a dot - leads to sqlite3.OperationalError (#2770) (#2771)

# What does this PR do?
Resolves https://github.com/meta-llama/llama-stack/issues/2770. The fix replaces any character in a SQLite table name that is not alphanumeric or an underscore with an underscore, and quotes the table names with square brackets in the SQL statements.

Closes #2770
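
A minimal standalone sketch of that approach, assuming the same helper and `chunks_` table-name prefix as in the diff below (the in-memory database and the direct `sqlite3` calls are only for illustration):

```
import re
import sqlite3


def _make_sql_identifier(name: str) -> str:
    # Replace every character that is not alphanumeric or an underscore.
    return re.sub(r"[^a-zA-Z0-9_]", "_", name)


bank_id = "test_bank.123"
metadata_table = _make_sql_identifier(f"chunks_{bank_id}")  # "chunks_test_bank_123"

connection = sqlite3.connect(":memory:")
# Square brackets quote the identifier in SQLite, so the statement stays valid
# even for table names that would otherwise need escaping.
connection.execute(f"CREATE TABLE IF NOT EXISTS [{metadata_table}] (id TEXT PRIMARY KEY, chunk TEXT);")
connection.close()
```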

## Test Plan
I added a ".123" suffix to the bank_id on the following line
```
    index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank.123")
```
in `tests/unit/providers/vector_io/test_sqlite_vec.py`; without the fix in place, this change demonstrates the issue.
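
With the fix in place, the updated test passes; one way to run just this test file (assuming the repository's usual pytest setup, from the repository root):

```
pytest tests/unit/providers/vector_io/test_sqlite_vec.py
```
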
Sergey Yedrikov, 2025-07-16 11:25:44 -04:00, committed via GitHub. Commit 30be1fd8b7 (parent 72e606355d).
2 changed files with 25 additions and 20 deletions


@@ -7,6 +7,7 @@
import asyncio
import json
import logging
+import re
import sqlite3
import struct
from typing import Any
@@ -117,6 +118,10 @@ def _rrf_rerank(
return rrf_scores
+def _make_sql_identifier(name: str) -> str:
+    return re.sub(r"[^a-zA-Z0-9_]", "_", name)
class SQLiteVecIndex(EmbeddingIndex):
"""
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
@@ -130,9 +135,9 @@ class SQLiteVecIndex(EmbeddingIndex):
self.dimension = dimension
self.db_path = db_path
self.bank_id = bank_id
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
self.metadata_table = _make_sql_identifier(f"chunks_{bank_id}")
self.vector_table = _make_sql_identifier(f"vec_chunks_{bank_id}")
self.fts_table = _make_sql_identifier(f"fts_chunks_{bank_id}")
self.kvstore = kvstore
@classmethod
@@ -148,14 +153,14 @@ class SQLiteVecIndex(EmbeddingIndex):
try:
# Create the table to store chunk metadata.
cur.execute(f"""
-CREATE TABLE IF NOT EXISTS {self.metadata_table} (
+CREATE TABLE IF NOT EXISTS [{self.metadata_table}] (
id TEXT PRIMARY KEY,
chunk TEXT
);
""")
# Create the virtual table for embeddings.
cur.execute(f"""
-CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
+CREATE VIRTUAL TABLE IF NOT EXISTS [{self.vector_table}]
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
""")
connection.commit()
@@ -163,7 +168,7 @@ class SQLiteVecIndex(EmbeddingIndex):
# based on query. Implementation of the change on client side will allow passing the search_mode option
# during initialization to make it easier to create the table that is required.
cur.execute(f"""
-CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table}
+CREATE VIRTUAL TABLE IF NOT EXISTS [{self.fts_table}]
USING fts5(id, content);
""")
connection.commit()
@@ -178,9 +183,9 @@ class SQLiteVecIndex(EmbeddingIndex):
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};")
cur.execute(f"DROP TABLE IF EXISTS [{self.metadata_table}];")
cur.execute(f"DROP TABLE IF EXISTS [{self.vector_table}];")
cur.execute(f"DROP TABLE IF EXISTS [{self.fts_table}];")
connection.commit()
finally:
cur.close()
@@ -212,7 +217,7 @@ class SQLiteVecIndex(EmbeddingIndex):
metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks]
cur.executemany(
f"""
-INSERT INTO {self.metadata_table} (id, chunk)
+INSERT INTO [{self.metadata_table}] (id, chunk)
VALUES (?, ?)
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
""",
@@ -230,7 +235,7 @@ class SQLiteVecIndex(EmbeddingIndex):
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
]
cur.executemany(
f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);",
f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
embedding_data,
)
@@ -238,13 +243,13 @@ class SQLiteVecIndex(EmbeddingIndex):
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
cur.executemany(
f"DELETE FROM {self.fts_table} WHERE id = ?;",
f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
[(row[0],) for row in fts_data],
)
# INSERT new entries
cur.executemany(
f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);",
f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
fts_data,
)
@@ -280,8 +285,8 @@ class SQLiteVecIndex(EmbeddingIndex):
emb_blob = serialize_vector(emb_list)
query_sql = f"""
SELECT m.id, m.chunk, v.distance
-FROM {self.vector_table} AS v
-JOIN {self.metadata_table} AS m ON m.id = v.id
+FROM [{self.vector_table}] AS v
+JOIN [{self.metadata_table}] AS m ON m.id = v.id
WHERE v.embedding MATCH ? AND k = ?
ORDER BY v.distance;
"""
@@ -322,9 +327,9 @@ class SQLiteVecIndex(EmbeddingIndex):
cur = connection.cursor()
try:
query_sql = f"""
-SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
-FROM {self.fts_table} AS f
-JOIN {self.metadata_table} AS m ON m.id = f.id
+SELECT DISTINCT m.id, m.chunk, bm25([{self.fts_table}]) AS score
+FROM [{self.fts_table}] AS f
+JOIN [{self.metadata_table}] AS m ON m.id = f.id
WHERE f.content MATCH ?
ORDER BY score ASC
LIMIT ?;

tests/unit/providers/vector_io/test_sqlite_vec.py

@@ -37,7 +37,7 @@ def loop():
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_sqlite.db")
-index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
+index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank.123")
yield index
await index.delete()
@@ -110,7 +110,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
cur = connection.cursor()
# Retrieve all chunk IDs to check for duplicates
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
cur.execute(f"SELECT id FROM [{sqlite_vec_index.metadata_table}]")
chunk_ids = [row[0] for row in cur.fetchall()]
cur.close()
connection.close()