mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
fix sqlite_vec by using local thread
This commit is contained in:
parent
f42dc48986
commit
6ff7ea127f
1 changed files with 21 additions and 5 deletions
|
@ -8,6 +8,7 @@ import hashlib
|
||||||
import logging
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import struct
|
import struct
|
||||||
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
@ -158,11 +159,21 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
self.cache: Dict[str, VectorDBWithIndex] = {}
|
self.cache: Dict[str, VectorDBWithIndex] = {}
|
||||||
self.connection: Optional[sqlite3.Connection] = None
|
self._local = threading.local()
|
||||||
|
|
||||||
|
def _get_connection(self):
|
||||||
|
"""Get a thread-local database connection."""
|
||||||
|
if not hasattr(self._local, "conn"):
|
||||||
|
try:
|
||||||
|
self._local.conn = sqlite3.connect(self.config.db_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error connecting to SQLite database: {e}")
|
||||||
|
raise e
|
||||||
|
return self._local.conn
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
# Open a connection to the SQLite database (the file is specified in the config).
|
# Open a connection to the SQLite database (the file is specified in the config).
|
||||||
self.connection = sqlite3.connect(self.config.db_path, check_same_thread=False)
|
self.connection = self._get_connection()
|
||||||
self.connection.enable_load_extension(True)
|
self.connection.enable_load_extension(True)
|
||||||
sqlite_vec.load(self.connection)
|
sqlite_vec.load(self.connection)
|
||||||
self.connection.enable_load_extension(False)
|
self.connection.enable_load_extension(False)
|
||||||
|
@ -185,9 +196,14 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
if self.connection:
|
# We can't access other threads' connections, so we just close our own
|
||||||
self.connection.close()
|
if hasattr(self._local, "conn"):
|
||||||
self.connection = None
|
try:
|
||||||
|
self._local.conn.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error closing SQLite connection: {e}")
|
||||||
|
finally:
|
||||||
|
del self._local.conn
|
||||||
|
|
||||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||||
if self.connection is None:
|
if self.connection is None:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue