forked from phoenix-oss/llama-stack-mirror
Fix broken pgvector provider and memory leaks (#947)
This PR fixes the broken pgvector provider and wraps every cursor creation in a context manager so that cursors are always closed properly, avoiding potential memory leaks.

```
> pytest llama_stack/providers/tests/vector_io/test_vector_io.py -m "pgvector" --env EMBEDDING_DIMENSION=384 --env PGVECTOR_PORT=7432 --env PGVECTOR_DB=db --env PGVECTOR_USER=user --env PGVECTOR_PASSWORD=pass -v -s --tb=short --disable-warnings

llama_stack/providers/tests/vector_io/test_vector_io.py::TestVectorIO::test_banks_list[-pgvector] PASSED
llama_stack/providers/tests/vector_io/test_vector_io.py::TestVectorIO::test_banks_register[-pgvector] PASSED
llama_stack/providers/tests/vector_io/test_vector_io.py::TestVectorIO::test_query_documents[-pgvector] The scores are: [0.8168284974053789, 0.8080469278964486, 0.8050996198466661]
PASSED
```

---------

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
parent
5c8e35a9e2
commit
a79a083e39
4 changed files with 73 additions and 62 deletions
|
@ -12,8 +12,8 @@ from .config import PGVectorConfig
|
|||
|
||||
|
||||
async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
|
||||
from .pgvector import PGVectorMemoryAdapter
|
||||
from .pgvector import PGVectorVectorDBAdapter
|
||||
|
||||
impl = PGVectorMemoryAdapter(config, deps[Api.inference])
|
||||
impl = PGVectorVectorDBAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -35,18 +35,19 @@ def check_extension_version(cur):
|
|||
return result[0] if result else None
|
||||
|
||||
|
||||
def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
|
||||
query = sql.SQL(
|
||||
def upsert_models(conn, keys_models: List[Tuple[str, BaseModel]]):
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
query = sql.SQL(
|
||||
"""
|
||||
INSERT INTO metadata_store (key, data)
|
||||
VALUES %s
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET data = EXCLUDED.data
|
||||
"""
|
||||
INSERT INTO metadata_store (key, data)
|
||||
VALUES %s
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET data = EXCLUDED.data
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
values = [(key, Json(model.model_dump())) for key, model in keys_models]
|
||||
execute_values(cur, query, values, template="(%s, %s)")
|
||||
values = [(key, Json(model.model_dump())) for key, model in keys_models]
|
||||
execute_values(cur, query, values, template="(%s, %s)")
|
||||
|
||||
|
||||
def load_models(cur, cls):
|
||||
|
@ -56,19 +57,20 @@ def load_models(cur, cls):
|
|||
|
||||
|
||||
class PGVectorIndex(EmbeddingIndex):
|
||||
def __init__(self, vector_db: VectorDB, dimension: int, cursor):
|
||||
self.cursor = cursor
|
||||
self.table_name = f"vector_store_{vector_db.identifier}"
|
||||
def __init__(self, vector_db: VectorDB, dimension: int, conn):
|
||||
self.conn = conn
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
self.table_name = f"vector_store_{vector_db.identifier}"
|
||||
|
||||
self.cursor.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id TEXT PRIMARY KEY,
|
||||
document JSONB,
|
||||
embedding vector({dimension})
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id TEXT PRIMARY KEY,
|
||||
document JSONB,
|
||||
embedding vector({dimension})
|
||||
)
|
||||
"""
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
|
@ -79,7 +81,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
for i, chunk in enumerate(chunks):
|
||||
values.append(
|
||||
(
|
||||
f"{chunk.document_id}:chunk-{i}",
|
||||
f"{chunk.metadata['document_id']}:chunk-{i}",
|
||||
Json(chunk.model_dump()),
|
||||
embeddings[i].tolist(),
|
||||
)
|
||||
|
@ -92,37 +94,39 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
|
||||
"""
|
||||
)
|
||||
execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
execute_values(cur, query, values, template="(%s, %s, %s::vector)")
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
self.cursor.execute(
|
||||
f"""
|
||||
SELECT document, embedding <-> %s::vector AS distance
|
||||
FROM {self.table_name}
|
||||
ORDER BY distance
|
||||
LIMIT %s
|
||||
""",
|
||||
(embedding.tolist(), k),
|
||||
)
|
||||
results = self.cursor.fetchall()
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT document, embedding <-> %s::vector AS distance
|
||||
FROM {self.table_name}
|
||||
ORDER BY distance
|
||||
LIMIT %s
|
||||
""",
|
||||
(embedding.tolist(), k),
|
||||
)
|
||||
results = cur.fetchall()
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc, dist in results:
|
||||
chunks.append(Chunk(**doc))
|
||||
scores.append(1.0 / float(dist))
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc, dist in results:
|
||||
chunks.append(Chunk(**doc))
|
||||
scores.append(1.0 / float(dist))
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self):
|
||||
self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
||||
|
||||
class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cursor = None
|
||||
self.conn = None
|
||||
self.cache = {}
|
||||
|
||||
|
@ -137,22 +141,21 @@ class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
password=self.config.password,
|
||||
)
|
||||
self.conn.autocommit = True
|
||||
self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
version = check_extension_version(cur)
|
||||
if version:
|
||||
log.info(f"Vector extension version: {version}")
|
||||
else:
|
||||
raise RuntimeError("Vector extension is not installed.")
|
||||
|
||||
version = check_extension_version(self.cursor)
|
||||
if version:
|
||||
log.info(f"Vector extension version: {version}")
|
||||
else:
|
||||
raise RuntimeError("Vector extension is not installed.")
|
||||
|
||||
self.cursor.execute(
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS metadata_store (
|
||||
key TEXT PRIMARY KEY,
|
||||
data JSONB
|
||||
)
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS metadata_store (
|
||||
key TEXT PRIMARY KEY,
|
||||
data JSONB
|
||||
)
|
||||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception("Could not connect to PGVector database server")
|
||||
raise RuntimeError("Could not connect to PGVector database server") from e
|
||||
|
@ -163,9 +166,9 @@ class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
log.info("Connection to PGVector database server closed")
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
upsert_models(self.cursor, [(vector_db.identifier, vector_db)])
|
||||
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
|
||||
|
||||
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
|
||||
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
|
@ -195,6 +198,6 @@ class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
return self.cache[vector_db_id]
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
|
||||
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
|
||||
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
return self.cache[vector_db_id]
|
||||
|
|
|
@ -25,6 +25,14 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
id="sentence_transformers",
|
||||
marks=pytest.mark.sentence_transformers,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "ollama",
|
||||
"vector_io": "pgvector",
|
||||
},
|
||||
id="pgvector",
|
||||
marks=pytest.mark.pgvector,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "ollama",
|
||||
|
@ -77,7 +85,7 @@ def pytest_generate_tests(metafunc):
|
|||
if model:
|
||||
params = [pytest.param(model, id="")]
|
||||
else:
|
||||
params = [pytest.param("all-MiniLM-L6-v2", id="")]
|
||||
params = [pytest.param("all-minilm:l6-v2", id="")]
|
||||
|
||||
metafunc.parametrize("embedding_model", params, indirect=True)
|
||||
|
||||
|
|
|
@ -17,8 +17,8 @@ from llama_stack.providers.utils.memory.vector_store import make_overlapped_chun
|
|||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest llama_stack/providers/tests/memory/test_memory.py
|
||||
# -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
|
||||
# pytest llama_stack/providers/tests/vector_io/test_vector_io.py \
|
||||
# -m "pgvector" --env EMBEDDING_DIMENSION=384 --env PGVECTOR_PORT=7432 \
|
||||
# -v -s --tb=short --disable-warnings
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue