Fix broken pgvector provider and memory leaks (#947)

This PR fixes the broken pgvector provider (it still imported the old `PGVectorMemoryAdapter` name and built chunk IDs from `chunk.document_id` instead of `chunk.metadata["document_id"]`) and wraps every cursor creation in a context manager so that cursors are always closed, avoiding potential memory leaks.
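
For reference, the pattern applied throughout the provider: instead of keeping a single long-lived `self.cursor`, every call site now opens a short-lived cursor in a `with` block, which psycopg2 closes when the block exits (even if the query raises). A minimal sketch of the pattern, assuming an open psycopg2 connection `conn`; the `fetch_metadata` helper and its query are illustrative only, not code from this PR:

```
import psycopg2.extras


def fetch_metadata(conn, key: str):
    # The `with` block closes the cursor on exit, even on error, so cursor
    # objects do not accumulate on a long-lived connection.
    with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
        cur.execute("SELECT data FROM metadata_store WHERE key = %s", (key,))
        row = cur.fetchone()
    return row["data"] if row else None
```

psycopg2 cursors are cheap client-side objects, so opening one per operation costs little while removing the shared mutable cursor state from the adapter.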

```
> pytest llama_stack/providers/tests/vector_io/test_vector_io.py   -m "pgvector" --env EMBEDDING_DIMENSION=384 --env PGVECTOR_PORT=7432 --env PGVECTOR_DB=db --env PGVECTOR_USER=user --env PGVECTOR_PASSWORD=pass   -v -s --tb=short --disable-warnings

llama_stack/providers/tests/vector_io/test_vector_io.py::TestVectorIO::test_banks_list[-pgvector] PASSED
llama_stack/providers/tests/vector_io/test_vector_io.py::TestVectorIO::test_banks_register[-pgvector] PASSED
llama_stack/providers/tests/vector_io/test_vector_io.py::TestVectorIO::test_query_documents[-pgvector] The scores are: [0.8168284974053789, 0.8080469278964486, 0.8050996198466661]
PASSED
```
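
A note on the printed scores: in the updated `query` path the pgvector `<->` operator returns an L2 distance, and the adapter turns each distance into a score as `1.0 / distance`, so higher means closer. A tiny illustration of that conversion (the distances below are made up, not taken from the run above):

```
# Hypothetical L2 distances as returned by `embedding <-> %s::vector`.
distances = [1.22, 1.24, 1.25]

# Same conversion the adapter applies: score = 1 / distance, so smaller
# distances map to larger scores.
scores = [1.0 / d for d in distances]
print(scores)  # [0.819..., 0.806..., 0.8]
```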

---------

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
4 changed files with 73 additions and 62 deletions

pgvector provider: __init__.py

@@ -12,8 +12,8 @@ from .config import PGVectorConfig
 
 async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
-    from .pgvector import PGVectorMemoryAdapter
+    from .pgvector import PGVectorVectorDBAdapter
 
-    impl = PGVectorMemoryAdapter(config, deps[Api.inference])
+    impl = PGVectorVectorDBAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl

pgvector provider: pgvector.py

@@ -35,18 +35,19 @@ def check_extension_version(cur):
     return result[0] if result else None
 
 
-def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
-    query = sql.SQL(
-        """
-        INSERT INTO metadata_store (key, data)
-        VALUES %s
-        ON CONFLICT (key) DO UPDATE
-        SET data = EXCLUDED.data
-    """
-    )
+def upsert_models(conn, keys_models: List[Tuple[str, BaseModel]]):
+    with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        query = sql.SQL(
+            """
+            INSERT INTO metadata_store (key, data)
+            VALUES %s
+            ON CONFLICT (key) DO UPDATE
+            SET data = EXCLUDED.data
+        """
+        )
 
-    values = [(key, Json(model.model_dump())) for key, model in keys_models]
-    execute_values(cur, query, values, template="(%s, %s)")
+        values = [(key, Json(model.model_dump())) for key, model in keys_models]
+        execute_values(cur, query, values, template="(%s, %s)")
 
 
 def load_models(cur, cls):

@@ -56,19 +57,20 @@ def load_models(cur, cls):
 
 
 class PGVectorIndex(EmbeddingIndex):
-    def __init__(self, vector_db: VectorDB, dimension: int, cursor):
-        self.cursor = cursor
-        self.table_name = f"vector_store_{vector_db.identifier}"
+    def __init__(self, vector_db: VectorDB, dimension: int, conn):
+        self.conn = conn
+        with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+            self.table_name = f"vector_store_{vector_db.identifier}"
 
-        self.cursor.execute(
-            f"""
-            CREATE TABLE IF NOT EXISTS {self.table_name} (
-                id TEXT PRIMARY KEY,
-                document JSONB,
-                embedding vector({dimension})
-            )
-        """
-        )
+            cur.execute(
+                f"""
+                CREATE TABLE IF NOT EXISTS {self.table_name} (
+                    id TEXT PRIMARY KEY,
+                    document JSONB,
+                    embedding vector({dimension})
+                )
+            """
+            )
 
     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (

@@ -79,7 +81,7 @@ class PGVectorIndex(EmbeddingIndex):
         for i, chunk in enumerate(chunks):
             values.append(
                 (
-                    f"{chunk.document_id}:chunk-{i}",
+                    f"{chunk.metadata['document_id']}:chunk-{i}",
                     Json(chunk.model_dump()),
                     embeddings[i].tolist(),
                 )

@@ -92,37 +94,39 @@ class PGVectorIndex(EmbeddingIndex):
             ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
         """
         )
-        execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
+        with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+            execute_values(cur, query, values, template="(%s, %s, %s::vector)")
 
     async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
-        self.cursor.execute(
-            f"""
-        SELECT document, embedding <-> %s::vector AS distance
-        FROM {self.table_name}
-        ORDER BY distance
-        LIMIT %s
-    """,
-            (embedding.tolist(), k),
-        )
-        results = self.cursor.fetchall()
+        with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+            cur.execute(
+                f"""
+            SELECT document, embedding <-> %s::vector AS distance
+            FROM {self.table_name}
+            ORDER BY distance
+            LIMIT %s
+        """,
+                (embedding.tolist(), k),
+            )
+            results = cur.fetchall()
 
         chunks = []
         scores = []
         for doc, dist in results:
             chunks.append(Chunk(**doc))
             scores.append(1.0 / float(dist))
 
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
     async def delete(self):
-        self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+        with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
 
 
 class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.inference_api = inference_api
-        self.cursor = None
         self.conn = None
         self.cache = {}

@@ -137,22 +141,21 @@ class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
                 password=self.config.password,
             )
             self.conn.autocommit = True
-            self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
-
-            version = check_extension_version(self.cursor)
-            if version:
-                log.info(f"Vector extension version: {version}")
-            else:
-                raise RuntimeError("Vector extension is not installed.")
-
-            self.cursor.execute(
-                """
-                CREATE TABLE IF NOT EXISTS metadata_store (
-                    key TEXT PRIMARY KEY,
-                    data JSONB
-                )
-            """
-            )
+            with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+                version = check_extension_version(cur)
+                if version:
+                    log.info(f"Vector extension version: {version}")
+                else:
+                    raise RuntimeError("Vector extension is not installed.")
+
+                cur.execute(
+                    """
+                    CREATE TABLE IF NOT EXISTS metadata_store (
+                        key TEXT PRIMARY KEY,
+                        data JSONB
+                    )
+                """
+                )
         except Exception as e:
             log.exception("Could not connect to PGVector database server")
             raise RuntimeError("Could not connect to PGVector database server") from e

@@ -163,9 +166,9 @@ class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
         log.info("Connection to PGVector database server closed")
 
     async def register_vector_db(self, vector_db: VectorDB) -> None:
-        upsert_models(self.cursor, [(vector_db.identifier, vector_db)])
-        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
+        upsert_models(self.conn, [(vector_db.identifier, vector_db)])
+        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
         self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
 
     async def unregister_vector_db(self, vector_db_id: str) -> None:

@@ -195,6 +198,6 @@ class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
             return self.cache[vector_db_id]
 
         vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
-        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
+        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
         self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
         return self.cache[vector_db_id]

vector_io tests: conftest.py

@@ -25,6 +25,14 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="sentence_transformers",
         marks=pytest.mark.sentence_transformers,
     ),
+    pytest.param(
+        {
+            "inference": "ollama",
+            "vector_io": "pgvector",
+        },
+        id="pgvector",
+        marks=pytest.mark.pgvector,
+    ),
     pytest.param(
         {
             "inference": "ollama",

@@ -77,7 +85,7 @@ def pytest_generate_tests(metafunc):
     if model:
         params = [pytest.param(model, id="")]
     else:
-        params = [pytest.param("all-MiniLM-L6-v2", id="")]
+        params = [pytest.param("all-minilm:l6-v2", id="")]
 
     metafunc.parametrize("embedding_model", params, indirect=True)

vector_io tests: test_vector_io.py

@@ -17,8 +17,8 @@ from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks
 
 # How to run this test:
 #
-# pytest llama_stack/providers/tests/memory/test_memory.py
-#   -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
+# pytest llama_stack/providers/tests/vector_io/test_vector_io.py \
+#   -m "pgvector" --env EMBEDDING_DIMENSION=384 PGVECTOR_PORT=7432 \
 #   -v -s --tb=short --disable-warnings