Fix broken pgvector provider and memory leaks (#947)

This PR fixes the broken pgvector provider and wraps every cursor creation
in a context manager so that cursors are always closed, avoiding potential
memory leaks.
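
For context: psycopg2 cursors are context managers, so `with conn.cursor(...) as cur:` guarantees `cur.close()` on block exit. A minimal sketch of the pattern this PR applies throughout (connection parameters are hypothetical, mirroring the test environment below):

```python
import psycopg2
import psycopg2.extras

# Hypothetical connection parameters for illustration only.
conn = psycopg2.connect(
    host="localhost", port=7432, dbname="db", user="user", password="pass"
)
conn.autocommit = True

# Before: a cursor stored on the adapter for its whole lifetime was never
# closed. After: the cursor lives only inside this block and is closed on
# exit, even when an exception is raised.
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
    cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
    print(cur.fetchone())
```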

```
> pytest llama_stack/providers/tests/vector_io/test_vector_io.py \
    -m "pgvector" --env EMBEDDING_DIMENSION=384 --env PGVECTOR_PORT=7432 \
    --env PGVECTOR_DB=db --env PGVECTOR_USER=user --env PGVECTOR_PASSWORD=pass \
    -v -s --tb=short --disable-warnings

llama_stack/providers/tests/vector_io/test_vector_io.py::TestVectorIO::test_banks_list[-pgvector] PASSED
llama_stack/providers/tests/vector_io/test_vector_io.py::TestVectorIO::test_banks_register[-pgvector] PASSED
llama_stack/providers/tests/vector_io/test_vector_io.py::TestVectorIO::test_query_documents[-pgvector] The scores are: [0.8168284974053789, 0.8080469278964486, 0.8050996198466661]
PASSED
```

---------

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Yuan Tang, 2025-02-05 12:32:05 -05:00, committed by GitHub
commit a79a083e39 (parent 5c8e35a9e2)
4 changed files with 73 additions and 62 deletions


```diff
@@ -12,8 +12,8 @@ from .config import PGVectorConfig
 async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
-    from .pgvector import PGVectorMemoryAdapter
+    from .pgvector import PGVectorVectorDBAdapter

-    impl = PGVectorMemoryAdapter(config, deps[Api.inference])
+    impl = PGVectorVectorDBAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
```


```diff
@@ -35,18 +35,19 @@ def check_extension_version(cur):
     return result[0] if result else None


-def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
-    query = sql.SQL(
-        """
-        INSERT INTO metadata_store (key, data)
-        VALUES %s
-        ON CONFLICT (key) DO UPDATE
-        SET data = EXCLUDED.data
-        """
-    )
-
-    values = [(key, Json(model.model_dump())) for key, model in keys_models]
-    execute_values(cur, query, values, template="(%s, %s)")
+def upsert_models(conn, keys_models: List[Tuple[str, BaseModel]]):
+    with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        query = sql.SQL(
+            """
+            INSERT INTO metadata_store (key, data)
+            VALUES %s
+            ON CONFLICT (key) DO UPDATE
+            SET data = EXCLUDED.data
+            """
+        )
+
+        values = [(key, Json(model.model_dump())) for key, model in keys_models]
+        execute_values(cur, query, values, template="(%s, %s)")


 def load_models(cur, cls):
```
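
Aside: `psycopg2.extras.execute_values` expands the single `VALUES %s` placeholder into one batched multi-row statement, rendering each tuple through `template`. A minimal sketch with hypothetical keys and payloads (assumes the `conn` and `metadata_store` table from above):

```python
from psycopg2.extras import Json, execute_values

# Hypothetical rows; each (key, Json) tuple is rendered as "(%s, %s)".
values = [
    ("bank-1", Json({"identifier": "bank-1"})),
    ("bank-2", Json({"identifier": "bank-2"})),
]
with conn.cursor() as cur:
    execute_values(
        cur,
        "INSERT INTO metadata_store (key, data) VALUES %s "
        "ON CONFLICT (key) DO UPDATE SET data = EXCLUDED.data",
        values,
        template="(%s, %s)",
    )
```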
```diff
@@ -56,19 +57,20 @@ def load_models(cur, cls):

 class PGVectorIndex(EmbeddingIndex):
-    def __init__(self, vector_db: VectorDB, dimension: int, cursor):
-        self.cursor = cursor
-        self.table_name = f"vector_store_{vector_db.identifier}"
-
-        self.cursor.execute(
-            f"""
-            CREATE TABLE IF NOT EXISTS {self.table_name} (
-                id TEXT PRIMARY KEY,
-                document JSONB,
-                embedding vector({dimension})
-            )
-            """
-        )
+    def __init__(self, vector_db: VectorDB, dimension: int, conn):
+        self.conn = conn
+        with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+            self.table_name = f"vector_store_{vector_db.identifier}"
+
+            cur.execute(
+                f"""
+                CREATE TABLE IF NOT EXISTS {self.table_name} (
+                    id TEXT PRIMARY KEY,
+                    document JSONB,
+                    embedding vector({dimension})
+                )
+                """
+            )

     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
```
```diff
@@ -79,7 +81,7 @@ class PGVectorIndex(EmbeddingIndex):
         for i, chunk in enumerate(chunks):
             values.append(
                 (
-                    f"{chunk.document_id}:chunk-{i}",
+                    f"{chunk.metadata['document_id']}:chunk-{i}",
                     Json(chunk.model_dump()),
                     embeddings[i].tolist(),
                 )
```
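
The `add_chunks` change above reflects that a chunk's `document_id` now lives in its `metadata` dict rather than as a top-level attribute; the row id is derived from it. A trivial sketch with a hypothetical payload:

```python
# Hypothetical chunk metadata after the schema change.
chunk_metadata = {"document_id": "doc-123"}
row_id = f"{chunk_metadata['document_id']}:chunk-0"  # -> "doc-123:chunk-0"
```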
```diff
@@ -92,37 +94,39 @@ class PGVectorIndex(EmbeddingIndex):
             ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
             """
         )
-        execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
+        with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+            execute_values(cur, query, values, template="(%s, %s, %s::vector)")

     async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
-        self.cursor.execute(
-            f"""
-            SELECT document, embedding <-> %s::vector AS distance
-            FROM {self.table_name}
-            ORDER BY distance
-            LIMIT %s
-            """,
-            (embedding.tolist(), k),
-        )
-        results = self.cursor.fetchall()
-
-        chunks = []
-        scores = []
-        for doc, dist in results:
-            chunks.append(Chunk(**doc))
-            scores.append(1.0 / float(dist))
-
-        return QueryChunksResponse(chunks=chunks, scores=scores)
+        with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+            cur.execute(
+                f"""
+                SELECT document, embedding <-> %s::vector AS distance
+                FROM {self.table_name}
+                ORDER BY distance
+                LIMIT %s
+                """,
+                (embedding.tolist(), k),
+            )
+            results = cur.fetchall()
+
+            chunks = []
+            scores = []
+            for doc, dist in results:
+                chunks.append(Chunk(**doc))
+                scores.append(1.0 / float(dist))
+
+            return QueryChunksResponse(chunks=chunks, scores=scores)

     async def delete(self):
-        self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+        with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")


 class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.inference_api = inference_api
-        self.cursor = None
+        self.conn = None
         self.cache = {}
```
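
Note on the query: `<->` is pgvector's Euclidean (L2) distance operator, so `ORDER BY distance` returns nearest neighbors first, and `1.0 / float(dist)` merely converts "smaller is better" distances into "larger is better" scores. A rough sketch of that post-processing, assuming rows of `(document, distance)`:

```python
# Hypothetical fetchall() results: (document, distance) pairs, nearest first.
results = [({"content": "doc-a"}, 0.25), ({"content": "doc-b"}, 0.50)]

# Invert distance into a similarity-style score. This assumes dist > 0;
# an exact match with dist == 0 would divide by zero.
scores = [1.0 / float(dist) for _, dist in results]  # [4.0, 2.0]
```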
```diff
@@ -137,22 +141,21 @@ class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
                 password=self.config.password,
             )
             self.conn.autocommit = True
-            self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
-
-            version = check_extension_version(self.cursor)
-            if version:
-                log.info(f"Vector extension version: {version}")
-            else:
-                raise RuntimeError("Vector extension is not installed.")
-
-            self.cursor.execute(
-                """
-                CREATE TABLE IF NOT EXISTS metadata_store (
-                    key TEXT PRIMARY KEY,
-                    data JSONB
-                )
-            """
-            )
+            with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+                version = check_extension_version(cur)
+                if version:
+                    log.info(f"Vector extension version: {version}")
+                else:
+                    raise RuntimeError("Vector extension is not installed.")
+
+                cur.execute(
+                    """
+                    CREATE TABLE IF NOT EXISTS metadata_store (
+                        key TEXT PRIMARY KEY,
+                        data JSONB
+                    )
+                    """
+                )
         except Exception as e:
             log.exception("Could not connect to PGVector database server")
             raise RuntimeError("Could not connect to PGVector database server") from e
```
```diff
@@ -163,9 +166,9 @@ class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
         log.info("Connection to PGVector database server closed")

     async def register_vector_db(self, vector_db: VectorDB) -> None:
-        upsert_models(self.cursor, [(vector_db.identifier, vector_db)])
+        upsert_models(self.conn, [(vector_db.identifier, vector_db)])

-        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
+        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
         self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

     async def unregister_vector_db(self, vector_db_id: str) -> None:
```
```diff
@@ -195,6 +198,6 @@ class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
             return self.cache[vector_db_id]

         vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
-        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
+        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
         self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
         return self.cache[vector_db_id]
```


```diff
@@ -25,6 +25,14 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="sentence_transformers",
         marks=pytest.mark.sentence_transformers,
     ),
+    pytest.param(
+        {
+            "inference": "ollama",
+            "vector_io": "pgvector",
+        },
+        id="pgvector",
+        marks=pytest.mark.pgvector,
+    ),
     pytest.param(
         {
             "inference": "ollama",
```
```diff
@@ -77,7 +85,7 @@ def pytest_generate_tests(metafunc):
     if model:
         params = [pytest.param(model, id="")]
     else:
-        params = [pytest.param("all-MiniLM-L6-v2", id="")]
+        params = [pytest.param("all-minilm:l6-v2", id="")]

     metafunc.parametrize("embedding_model", params, indirect=True)
```


```diff
@@ -17,8 +17,8 @@ from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks

 # How to run this test:
 #
-# pytest llama_stack/providers/tests/memory/test_memory.py
-#   -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
+# pytest llama_stack/providers/tests/vector_io/test_vector_io.py \
+#   -m "pgvector" --env EMBEDDING_DIMENSION=384 PGVECTOR_PORT=7432 \
 #   -v -s --tb=short --disable-warnings
```