diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 7e8230ff9..347d7ab2b 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -8,6 +8,7 @@ import hashlib import logging import sqlite3 import struct +import threading import uuid from typing import Any, Dict, List, Optional @@ -158,11 +159,21 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): self.config = config self.inference_api = inference_api self.cache: Dict[str, VectorDBWithIndex] = {} - self.connection: Optional[sqlite3.Connection] = None + self._local = threading.local() + + def _get_connection(self): + """Get a thread-local database connection.""" + if not hasattr(self._local, "conn"): + try: + self._local.conn = sqlite3.connect(self.config.db_path) + except Exception as e: + print(f"Error connecting to SQLite database: {e}") + raise e + return self._local.conn async def initialize(self) -> None: # Open a connection to the SQLite database (the file is specified in the config). - self.connection = sqlite3.connect(self.config.db_path, check_same_thread=False) + self.connection = self._get_connection() self.connection.enable_load_extension(True) sqlite_vec.load(self.connection) self.connection.enable_load_extension(False) @@ -185,9 +196,14 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) async def shutdown(self) -> None: - if self.connection: - self.connection.close() - self.connection = None + # We can't access other threads' connections, so we just close our own + if hasattr(self._local, "conn"): + try: + self._local.conn.close() + except Exception as e: + print(f"Error closing SQLite connection: {e}") + finally: + del self._local.conn async def register_vector_db(self, vector_db: VectorDB) -> None: if self.connection is None: