From 3e65c70b2d07acf3f73cb27edc9a19a7db0e7046 Mon Sep 17 00:00:00 2001 From: qifengleqifengle <2472846459@qq.com> Date: Mon, 14 Jul 2025 16:50:29 +0800 Subject: [PATCH] feat(vector-io): add OpenGauss vector database provider Implement OpenGauss vector database integration for Llama Stack with the following features: - Add OpenGaussVectorIOAdapter for vector storage and retrieval - Support native vector similarity search operations - Provide configuration template for easy setup - Add comprehensive unit tests - Align with the latest Llama Stack provider architecture, including KVStore and OpenAI Vector Store Mixin. The implementation allows Llama Stack users to leverage OpenGauss as an enterprise-grade vector database for RAG applications. --- .../workflows/integration-vector-io-tests.yml | 32 +- docs/source/providers/vector_io/index.md | 1 + .../providers/vector_io/remote_opengauss.md | 58 ++++ llama_stack/providers/registry/vector_io.py | 38 +++ .../remote/vector_io/opengauss/__init__.py | 18 ++ .../remote/vector_io/opengauss/config.py | 48 +++ .../remote/vector_io/opengauss/opengauss.py | 283 ++++++++++++++++++ .../vector_io/test_openai_vector_stores.py | 2 + tests/unit/providers/vector_io/conftest.py | 93 +++++- .../providers/vector_io/test_opengauss.py | 219 ++++++++++++++ 10 files changed, 790 insertions(+), 2 deletions(-) create mode 100644 docs/source/providers/vector_io/remote_opengauss.md create mode 100644 llama_stack/providers/remote/vector_io/opengauss/__init__.py create mode 100644 llama_stack/providers/remote/vector_io/opengauss/config.py create mode 100644 llama_stack/providers/remote/vector_io/opengauss/opengauss.py create mode 100644 tests/unit/providers/vector_io/test_opengauss.py diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index aa239572b..5e2ae0418 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -24,7 +24,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"] + vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant", "remote::opengauss"] python-version: ["3.12", "3.13"] fail-fast: false # we want to run all tests regardless of failure @@ -86,6 +86,30 @@ jobs: PGPASSWORD=llamastack psql -h localhost -U llamastack -d llamastack \ -c "CREATE EXTENSION IF NOT EXISTS vector;" + - name: Start OpenGauss DB + if: matrix.vector-io-provider == 'remote::opengauss' + run: | + docker run -d \ + --name opengauss \ + -e GS_PASSWORD=Enmo@123 \ + -e GS_DB=llamastack \ + -e GS_USER=llamastack \ + -p 5432:5432 \ + enmotech/opengauss:latest + + - name: Wait for OpenGauss to be ready + if: matrix.vector-io-provider == 'remote::opengauss' + run: | + echo "Waiting for OpenGauss to be ready..." + for i in {1..30}; do + if docker exec opengauss gsql -d llamastack -U llamastack -W Enmo@123 -c "SELECT version();" > /dev/null 2>&1; then + echo "OpenGauss is ready!" + break + fi + echo "Not ready yet... ($i)" + sleep 2 + done + - name: Setup Qdrant if: matrix.vector-io-provider == 'remote::qdrant' run: | @@ -163,6 +187,12 @@ jobs: QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }} ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }} WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }} + ENABLE_OPENGAUSS: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'true' || '' }} + OPENGAUSS_HOST: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'localhost' || '' }} + OPENGAUSS_PORT: ${{ matrix.vector-io-provider == 'remote::opengauss' && '5432' || '' }} + OPENGAUSS_DB: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'llamastack' || '' }} + OPENGAUSS_USER: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'llamastack' || '' }} + OPENGAUSS_PASSWORD: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'Enmo@123' || '' }} run: | uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ tests/integration/vector_io \ diff --git a/docs/source/providers/vector_io/index.md b/docs/source/providers/vector_io/index.md index 28ae523d7..acd90cb48 100644 --- a/docs/source/providers/vector_io/index.md +++ b/docs/source/providers/vector_io/index.md @@ -18,6 +18,7 @@ inline_sqlite-vec inline_sqlite_vec remote_chromadb remote_milvus +remote_opengauss remote_pgvector remote_qdrant remote_weaviate diff --git a/docs/source/providers/vector_io/remote_opengauss.md b/docs/source/providers/vector_io/remote_opengauss.md new file mode 100644 index 000000000..30291de79 --- /dev/null +++ b/docs/source/providers/vector_io/remote_opengauss.md @@ -0,0 +1,58 @@ +# remote::opengauss + +## Description + + +[OpenGauss](https://opengauss.org/en/) is a remote vector database provider for Llama Stack. It +allows you to store and query vectors directly in memory. +That means you'll get fast and efficient vector retrieval. + +## Features + +- Easy to use +- Fully integrated with Llama Stack + +## Usage + +To use OpenGauss in your Llama Stack project, follow these steps: + +1. Install the necessary dependencies. +2. Configure your Llama Stack project to use OpenGauss. +3. Start storing and querying vectors. + +## Installation + +You can install OpenGauss using docker: + +```bash +docker pull opengauss/opengauss:latest +``` +## Documentation +See [OpenGauss' documentation](https://docs.opengauss.org/en/docs/5.0.0/docs/GettingStarted/understanding-opengauss.html) for more details about OpenGauss in general. + + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `host` | `str \| None` | No | localhost | | +| `port` | `int \| None` | No | 5432 | | +| `db` | `str \| None` | No | postgres | | +| `user` | `str \| None` | No | postgres | | +| `password` | `str \| None` | No | mysecretpassword | | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) | + +## Sample Configuration + +```yaml +host: ${env.OPENGAUSS_HOST:=localhost} +port: ${env.OPENGAUSS_PORT:=5432} +db: ${env.OPENGAUSS_DB} +user: ${env.OPENGAUSS_USER} +password: ${env.OPENGAUSS_PASSWORD} +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/opengauss_registry.db + +``` + diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index ed170b508..bc82d5dbd 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -424,6 +424,44 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de api_dependencies=[Api.inference], optional_api_dependencies=[Api.files], ), + remote_provider_spec( + Api.vector_io, + AdapterSpec( + adapter_type="opengauss", + pip_packages=["psycopg2-binary"], + module="llama_stack.providers.remote.vector_io.opengauss", + config_class="llama_stack.providers.remote.vector_io.opengauss.OpenGaussVectorIOConfig", + description=""" +[OpenGauss](https://opengauss.org/en/) is a remote vector database provider for Llama Stack. It +allows you to store and query vectors directly in memory. +That means you'll get fast and efficient vector retrieval. + +## Features + +- Easy to use +- Fully integrated with Llama Stack + +## Usage + +To use OpenGauss in your Llama Stack project, follow these steps: + +1. Install the necessary dependencies. +2. Configure your Llama Stack project to use OpenGauss. +3. Start storing and querying vectors. + +## Installation + +You can install OpenGauss using docker: + +```bash +docker pull opengauss/opengauss:latest +``` +## Documentation +See [OpenGauss' documentation](https://docs.opengauss.org/en/docs/5.0.0/docs/GettingStarted/understanding-opengauss.html) for more details about OpenGauss in general. +""", + ), + api_dependencies=[Api.inference], + ), remote_provider_spec( Api.vector_io, AdapterSpec( diff --git a/llama_stack/providers/remote/vector_io/opengauss/__init__.py b/llama_stack/providers/remote/vector_io/opengauss/__init__.py new file mode 100644 index 000000000..4a58cfdbc --- /dev/null +++ b/llama_stack/providers/remote/vector_io/opengauss/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.datatypes import Api + +from .config import OpenGaussVectorIOConfig + + +async def get_adapter_impl(config: OpenGaussVectorIOConfig, deps): + from .opengauss import OpenGaussVectorIOAdapter + + files_api = deps.get(Api.files) + impl = OpenGaussVectorIOAdapter(config, deps[Api.inference], files_api) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/vector_io/opengauss/config.py b/llama_stack/providers/remote/vector_io/opengauss/config.py new file mode 100644 index 000000000..8c32002fe --- /dev/null +++ b/llama_stack/providers/remote/vector_io/opengauss/config.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) +from llama_stack.schema_utils import json_schema_type + + +@json_schema_type +class OpenGaussVectorIOConfig(BaseModel): + host: str | None = Field(default="localhost") + port: int | None = Field(default=5432) + db: str | None = Field(default="postgres") + user: str | None = Field(default="postgres") + password: str | None = Field(default="mysecretpassword") + kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None) + + @classmethod + def sample_run_config( + cls, + __distro_dir__: str, + host: str = "${env.OPENGAUSS_HOST:=localhost}", + port: str = "${env.OPENGAUSS_PORT:=5432}", + db: str = "${env.OPENGAUSS_DB}", + user: str = "${env.OPENGAUSS_USER}", + password: str = "${env.OPENGAUSS_PASSWORD}", + **kwargs: Any, + ) -> dict[str, Any]: + return { + "host": host, + "port": port, + "db": db, + "user": user, + "password": password, + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="opengauss_registry.db", + ), + } diff --git a/llama_stack/providers/remote/vector_io/opengauss/opengauss.py b/llama_stack/providers/remote/vector_io/opengauss/opengauss.py new file mode 100644 index 000000000..a60d0a4a7 --- /dev/null +++ b/llama_stack/providers/remote/vector_io/opengauss/opengauss.py @@ -0,0 +1,283 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import logging +from typing import Any + +import psycopg2 +from numpy.typing import NDArray +from psycopg2 import sql +from psycopg2.extras import Json, execute_values +from pydantic import BaseModel, TypeAdapter + +from llama_stack.apis.common.errors import VectorStoreNotFoundError +from llama_stack.apis.files.files import Files +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import ( + Chunk, + QueryChunksResponse, + VectorIO, +) +from llama_stack.providers.datatypes import VectorDBsProtocolPrivate +from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.kvstore.api import KVStore +from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin +from llama_stack.providers.utils.memory.vector_store import ( + EmbeddingIndex, + VectorDBWithIndex, +) + +from .config import OpenGaussVectorIOConfig + +log = logging.getLogger(__name__) + +VERSION = "v3" +VECTOR_DBS_PREFIX = f"vector_dbs:opengauss:{VERSION}::" +VECTOR_INDEX_PREFIX = f"vector_index:opengauss:{VERSION}::" +OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:opengauss:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:opengauss:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:opengauss:{VERSION}::" + + +def upsert_models(conn, keys_models: list[tuple[str, BaseModel]]): + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + query = sql.SQL( + """ + MERGE INTO metadata_store AS target + USING (VALUES %s) AS source (key, data) + ON (target.key = source.key) + WHEN MATCHED THEN + UPDATE SET data = source.data + WHEN NOT MATCHED THEN + INSERT (key, data) VALUES (source.key, source.data); + """ + ) + + values = [(key, Json(model.model_dump())) for key, model in keys_models] + execute_values(cur, query, values, template="(%s, %s::JSONB)") + + +def load_models(cur, cls): + cur.execute("SELECT key, data FROM metadata_store") + rows = cur.fetchall() + return [TypeAdapter(cls).validate_python(row["data"]) for row in rows] + + +class OpenGaussIndex(EmbeddingIndex): + def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None): + self.conn = conn + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + sanitized_identifier = vector_db.identifier.replace("-", "_") + self.table_name = f"vector_store_{sanitized_identifier}" + self.kvstore = kvstore + + cur.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + id TEXT PRIMARY KEY, + document JSONB, + embedding vector({dimension}) + ) + """ + ) + + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): + assert len(chunks) == len(embeddings), ( + f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" + ) + + values = [] + for i, chunk in enumerate(chunks): + values.append( + ( + f"{chunk.chunk_id}", + Json(chunk.model_dump()), + embeddings[i].tolist(), + ) + ) + + query = sql.SQL( + f""" + MERGE INTO {self.table_name} AS target + USING (VALUES %s) AS source (id, document, embedding) + ON (target.id = source.id) + WHEN MATCHED THEN + UPDATE SET document = source.document, embedding = source.embedding + WHEN NOT MATCHED THEN + INSERT (id, document, embedding) VALUES (source.id, source.document, source.embedding); + """ + ) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + execute_values(cur, query, values, template="(%s, %s::JSONB, %s::VECTOR)") + + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + cur.execute( + f""" + SELECT document, embedding <=> %s::VECTOR AS distance + FROM {self.table_name} + ORDER BY distance + LIMIT %s + """, + (embedding.tolist(), k), + ) + results = cur.fetchall() + + chunks = [] + scores = [] + for doc, dist in results: + score = 1.0 / float(dist) if dist != 0 else float("inf") + if score < score_threshold: + continue + chunks.append(Chunk(**doc)) + scores.append(score) + + return QueryChunksResponse(chunks=chunks, scores=scores) + + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in OpenGauss") + + async def query_hybrid( + self, + embedding: NDArray, + query_string: str, + k: int, + score_threshold: float, + reranker_type: str, + reranker_params: dict[str, Any] | None = None, + ) -> QueryChunksResponse: + raise NotImplementedError("Hybrid search is not supported in OpenGauss") + + async def delete(self): + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") + + async def delete_chunk(self, chunk_id: str) -> None: + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,)) + + +class OpenGaussVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): + def __init__( + self, + config: OpenGaussVectorIOConfig, + inference_api: Any, + files_api: Files | None = None, + ) -> None: + self.config = config + self.inference_api = inference_api + self.conn = None + self.cache: dict[str, VectorDBWithIndex] = {} + self.files_api = files_api + self.kvstore: KVStore | None = None + self.vector_db_store = None + self.openai_vector_store: dict[str, dict[str, Any]] = {} + self.metadatadata_collection_name = "openai_vector_stores_metadata" + + async def initialize(self) -> None: + log.info(f"Initializing OpenGauss memory adapter with config: {self.config}") + if self.config.kvstore is not None: + self.kvstore = await kvstore_impl(self.config.kvstore) + await self.initialize_openai_vector_stores() + + try: + self.conn = psycopg2.connect( + host=self.config.host, + port=self.config.port, + database=self.config.db, + user=self.config.user, + password=self.config.password, + ) + if self.conn: + self.conn.autocommit = True + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + cur.execute("SELECT version();") + version = cur.fetchone()[0] + log.info(f"OpenGauss server version: {version}") + log.info("Assuming native vector support is enabled in this OpenGauss instance.") + + cur.execute( + """ + CREATE TABLE IF NOT EXISTS metadata_store ( + key TEXT PRIMARY KEY, + data JSONB + ) + """ + ) + except Exception as e: + log.exception("Could not connect to OpenGauss database server") + raise RuntimeError("Could not connect to OpenGauss database server") from e + + async def shutdown(self) -> None: + if self.conn is not None: + self.conn.close() + log.info("Connection to OpenGauss database server closed") + + async def register_vector_db(self, vector_db: VectorDB) -> None: + assert self.kvstore is not None + upsert_models(self.conn, [(vector_db.identifier, vector_db)]) + + index = VectorDBWithIndex( + vector_db, + index=OpenGaussIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore), + inference_api=self.inference_api, + ) + self.cache[vector_db.identifier] = index + + async def unregister_vector_db(self, vector_db_id: str) -> None: + if vector_db_id in self.cache: + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] + + assert self.kvstore is not None + await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}") + + async def insert_chunks( + self, + vector_db_id: str, + chunks: list[Chunk], + ttl_seconds: int | None = None, + ) -> None: + index = await self._get_and_cache_vector_db_index(vector_db_id) + await index.insert_chunks(chunks) + + async def query_chunks( + self, + vector_db_id: str, + query: InterleavedContent, + params: dict[str, Any] | None = None, + ) -> QueryChunksResponse: + index = await self._get_and_cache_vector_db_index(vector_db_id) + return await index.query_chunks(query, params) + + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex: + if vector_db_id in self.cache: + return self.cache[vector_db_id] + + if self.vector_db_store is None: + raise RuntimeError("Vector DB store not initialized") + + vector_db = self.vector_db_store.get_vector_db(vector_db_id) + if vector_db is None: + raise VectorStoreNotFoundError(vector_db_id) + + index = OpenGaussIndex(vector_db, vector_db.embedding_dimension, self.conn) + self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api) + return self.cache[vector_db_id] + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + index = await self._get_and_cache_vector_db_index(store_id) + if not index: + raise VectorStoreNotFoundError(store_id) + + for chunk_id in chunk_ids: + await index.index.delete_chunk(chunk_id) diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 3212a7568..942621e3a 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -26,6 +26,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): "inline::milvus", "inline::chromadb", "remote::pgvector", + "remote::opengauss", "remote::chromadb", "remote::qdrant", "inline::qdrant", @@ -47,6 +48,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode "inline::chromadb", "inline::qdrant", "remote::pgvector", + "remote::opengauss", "remote::chromadb", "remote::weaviate", "remote::qdrant", diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index f71073651..f83e92176 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -4,7 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import os import random +from unittest.mock import AsyncMock import numpy as np import pytest @@ -22,6 +24,8 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConf from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter +from llama_stack.providers.remote.vector_io.opengauss.config import OpenGaussVectorIOConfig +from llama_stack.providers.remote.vector_io.opengauss.opengauss import OpenGaussIndex, OpenGaussVectorIOAdapter from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter EMBEDDING_DIMENSION = 384 @@ -29,7 +33,7 @@ COLLECTION_PREFIX = "test_collection" MILVUS_ALIAS = "test_milvus" -@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"]) +@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "opengauss"]) def vector_provider(request): return request.param @@ -333,6 +337,92 @@ async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension): await index.delete() +@pytest.fixture +def opengauss_vec_db_path(): + return { + "host": "localhost", + "port": 5432, + "db": "test_db", + "user": "test_user", + "password": "test_password", + } + + +@pytest.fixture +async def opengauss_vec_index(embedding_dimension, opengauss_vec_db_path): + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_conn.cursor.return_value.__enter__.return_value = mock_cursor + + vector_db = VectorDB( + identifier=f"test_opengauss_db_{np.random.randint(1e6)}", + provider_id="opengauss", + embedding_model="test_model", + embedding_dimension=embedding_dimension, + ) + + if all( + os.getenv(var) + for var in ["OPENGAUSS_HOST", "OPENGAUSS_PORT", "OPENGAUSS_DB", "OPENGAUSS_USER", "OPENGAUSS_PASSWORD"] + ): + import psycopg2 + + real_conn = psycopg2.connect(**opengauss_vec_db_path) + real_conn.autocommit = True + index = OpenGaussIndex(vector_db, embedding_dimension, real_conn) + yield index + await index.delete() + real_conn.close() + else: + index = OpenGaussIndex(vector_db, embedding_dimension, mock_conn) + yield index + + +@pytest.fixture +async def opengauss_vec_adapter(mock_inference_api, embedding_dimension, tmp_path_factory): + temp_dir = tmp_path_factory.getbasetemp() + kv_db_path = str(temp_dir / f"opengauss_kv_{np.random.randint(1e6)}.db") + + config = OpenGaussVectorIOConfig( + host=os.getenv("OPENGAUSS_HOST", "localhost"), + port=int(os.getenv("OPENGAUSS_PORT", "5432")), + db=os.getenv("OPENGAUSS_DB", "test_db"), + user=os.getenv("OPENGAUSS_USER", "test_user"), + password=os.getenv("OPENGAUSS_PASSWORD", "test_password"), + kvstore=SqliteKVStoreConfig(db_path=kv_db_path), + ) + + if all( + os.getenv(var) + for var in ["OPENGAUSS_HOST", "OPENGAUSS_PORT", "OPENGAUSS_DB", "OPENGAUSS_USER", "OPENGAUSS_PASSWORD"] + ): + adapter = OpenGaussVectorIOAdapter(config, mock_inference_api) + await adapter.initialize() + + collection_id = f"opengauss_test_collection_{np.random.randint(1e6)}" + await adapter.register_vector_db( + VectorDB( + identifier=collection_id, + provider_id="opengauss", + embedding_model="test_model", + embedding_dimension=embedding_dimension, + ) + ) + adapter.test_collection_id = collection_id + yield adapter + + try: + await adapter.unregister_vector_db(collection_id) + except Exception: + pass + await adapter.shutdown() + + if os.path.exists(kv_db_path): + os.remove(kv_db_path) + else: + pytest.skip("OpenGauss connection not available for integration testing") + + @pytest.fixture def vector_io_adapter(vector_provider, request): """Returns the appropriate vector IO adapter based on the provider parameter.""" @@ -342,6 +432,7 @@ def vector_io_adapter(vector_provider, request): "sqlite_vec": "sqlite_vec_adapter", "chroma": "chroma_vec_adapter", "qdrant": "qdrant_vec_adapter", + "opengauss": "opengauss_vec_adapter", } return request.getfixturevalue(vector_provider_dict[vector_provider]) diff --git a/tests/unit/providers/vector_io/test_opengauss.py b/tests/unit/providers/vector_io/test_opengauss.py new file mode 100644 index 000000000..5843eb91c --- /dev/null +++ b/tests/unit/providers/vector_io/test_opengauss.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import random +from unittest.mock import AsyncMock + +import numpy as np +import pytest +import pytest_asyncio + +from llama_stack.apis.inference import EmbeddingsResponse, Inference +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import Chunk, QueryChunksResponse +from llama_stack.providers.remote.vector_io.opengauss.config import ( + OpenGaussVectorIOConfig, +) +from llama_stack.providers.remote.vector_io.opengauss.opengauss import ( + OpenGaussIndex, + OpenGaussVectorIOAdapter, +) +from llama_stack.providers.utils.kvstore.config import ( + SqliteKVStoreConfig, +) + +# Skip all tests in this file if the required environment variables are not set. +pytestmark = pytest.mark.skipif( + not all( + os.getenv(var) + for var in [ + "OPENGAUSS_HOST", + "OPENGAUSS_PORT", + "OPENGAUSS_DB", + "OPENGAUSS_USER", + "OPENGAUSS_PASSWORD", + ] + ), + reason="OpenGauss connection environment variables not set", +) + + +@pytest.fixture(scope="session") +def embedding_dimension() -> int: + return 128 + + +@pytest.fixture +def sample_chunks(): + """Provides a list of sample chunks for testing.""" + return [ + Chunk( + content="The sky is blue.", + metadata={"document_id": "doc1", "topic": "nature"}, + ), + Chunk( + content="An apple a day keeps the doctor away.", + metadata={"document_id": "doc2", "topic": "health"}, + ), + Chunk( + content="Quantum computing is a new frontier.", + metadata={"document_id": "doc3", "topic": "technology"}, + ), + ] + + +@pytest.fixture +def sample_embeddings(embedding_dimension, sample_chunks): + """Provides a deterministic set of embeddings for the sample chunks.""" + # Use a fixed seed for reproducibility + rng = np.random.default_rng(42) + return rng.random((len(sample_chunks), embedding_dimension), dtype=np.float32) + + +@pytest.fixture +def mock_inference_api(sample_embeddings): + """Mocks the inference API to return dummy embeddings.""" + mock_api = AsyncMock(spec=Inference) + mock_api.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings.tolist())) + return mock_api + + +@pytest.fixture +def vector_db(embedding_dimension): + """Provides a sample VectorDB object for registration.""" + return VectorDB( + identifier=f"test_db_{random.randint(1, 10000)}", + embedding_model="test_embedding_model", + embedding_dimension=embedding_dimension, + provider_id="opengauss", + ) + + +@pytest_asyncio.fixture +async def opengauss_connection(): + """Creates and manages a connection to the OpenGauss database.""" + import psycopg2 + + conn = psycopg2.connect( + host=os.getenv("OPENGAUSS_HOST"), + port=int(os.getenv("OPENGAUSS_PORT")), + database=os.getenv("OPENGAUSS_DB"), + user=os.getenv("OPENGAUSS_USER"), + password=os.getenv("OPENGAUSS_PASSWORD"), + ) + conn.autocommit = True + yield conn + conn.close() + + +@pytest_asyncio.fixture +async def opengauss_index(opengauss_connection, vector_db): + """Fixture to create and clean up an OpenGaussIndex instance.""" + index = OpenGaussIndex(vector_db, vector_db.embedding_dimension, opengauss_connection) + yield index + await index.delete() + + +@pytest_asyncio.fixture +async def opengauss_adapter(mock_inference_api): + """Fixture to set up and tear down the OpenGaussVectorIOAdapter.""" + config = OpenGaussVectorIOConfig( + host=os.getenv("OPENGAUSS_HOST"), + port=int(os.getenv("OPENGAUSS_PORT")), + db=os.getenv("OPENGAUSS_DB"), + user=os.getenv("OPENGAUSS_USER"), + password=os.getenv("OPENGAUSS_PASSWORD"), + kvstore=SqliteKVStoreConfig(db_name="opengauss_test.db"), + ) + adapter = OpenGaussVectorIOAdapter(config, mock_inference_api) + await adapter.initialize() + yield adapter + if adapter.conn and not adapter.conn.closed: + for db_id in list(adapter.cache.keys()): + try: + await adapter.unregister_vector_db(db_id) + except Exception as e: + print(f"Error during cleanup of {db_id}: {e}") + await adapter.shutdown() + # Clean up the sqlite db file + if os.path.exists("opengauss_test.db"): + os.remove("opengauss_test.db") + + +@pytest.mark.asyncio +class TestOpenGaussIndex: + async def test_add_and_query_vector(self, opengauss_index, sample_chunks, sample_embeddings): + """Test adding chunks with embeddings and querying for the most similar one.""" + await opengauss_index.add_chunks(sample_chunks, sample_embeddings) + + # Query with the embedding of the first chunk + query_embedding = sample_embeddings[0] + response = await opengauss_index.query_vector(query_embedding, k=1, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 1 + assert response.chunks[0].content == sample_chunks[0].content + # The distance to itself should be 0, resulting in infinite score + assert response.scores[0] == float("inf") + + +@pytest.mark.asyncio +class TestOpenGaussVectorIOAdapter: + async def test_initialization(self, opengauss_adapter): + """Test that the adapter initializes and connects to the database.""" + assert opengauss_adapter.conn is not None + assert not opengauss_adapter.conn.closed + + async def test_register_and_unregister_vector_db(self, opengauss_adapter, vector_db): + """Test the registration and unregistration of a vector database.""" + await opengauss_adapter.register_vector_db(vector_db) + assert vector_db.identifier in opengauss_adapter.cache + + table_name = opengauss_adapter.cache[vector_db.identifier].index.table_name + with opengauss_adapter.conn.cursor() as cur: + cur.execute( + "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND tablename = %s);", + (table_name,), + ) + assert cur.fetchone()[0] + + await opengauss_adapter.unregister_vector_db(vector_db.identifier) + assert vector_db.identifier not in opengauss_adapter.cache + + with opengauss_adapter.conn.cursor() as cur: + cur.execute( + "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND tablename = %s);", + (table_name,), + ) + assert not cur.fetchone()[0] + + @pytest.mark.asyncio + async def test_adapter_end_to_end_query(self, opengauss_adapter, vector_db, sample_chunks): + """ + Tests the full adapter flow: text query -> embedding generation -> vector search. + """ + # 1. Register the DB and insert chunks. The adapter will use the mocked + # inference_api to generate embeddings for these chunks. + await opengauss_adapter.register_vector_db(vector_db) + await opengauss_adapter.insert_chunks(vector_db.identifier, sample_chunks) + + # 2. The user query is a text string. + query_text = "What is the color of the sky?" + + # 3. The adapter will now internally call the (mocked) inference_api + # to get an embedding for the query_text. + response = await opengauss_adapter.query_chunks(vector_db.identifier, query_text) + + # 4. Assertions + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) > 0 + + # Because the mocked inference_api returns random embeddings, we can't + # deterministically know which chunk is "closest". However, in a real + # integration test with a real model, this assertion would be more specific. + # For this unit test, we just confirm that the process completes and returns data. + assert response.chunks[0].content in [c.content for c in sample_chunks]