From 35a0a6cb7b4ea9bbf15c61f5b894a309854ba26b Mon Sep 17 00:00:00 2001 From: qifengleqifengle <2472846459@qq.com> Date: Mon, 14 Jul 2025 16:50:29 +0800 Subject: [PATCH] feat(vector-io): add OpenGauss vector database provider Implement OpenGauss vector database integration for Llama Stack with the following features: - Add OpenGaussVectorIOAdapter for vector storage and retrieval - Support native vector similarity search operations - Provide configuration template for easy setup - Add comprehensive unit tests - Align with the latest Llama Stack provider architecture, including KVStore and OpenAI Vector Store Mixin. The implementation allows Llama Stack users to leverage OpenGauss as an enterprise-grade vector database for RAG applications. --- .../workflows/integration-vector-io-tests.yml | 32 +- .../distributions/k8s-benchmark/benchmark.py | 0 docs/source/providers/agents/index.md | 12 +- docs/source/providers/batches/index.md | 8 +- docs/source/providers/inference/index.md | 6 +- docs/source/providers/vector_io/index.md | 1 + .../providers/vector_io/remote_opengauss.md | 58 ++++ llama_stack/providers/registry/vector_io.py | 38 +++ .../remote/vector_io/opengauss/__init__.py | 18 ++ .../remote/vector_io/opengauss/config.py | 48 +++ .../remote/vector_io/opengauss/opengauss.py | 286 ++++++++++++++++++ .../vector_io/test_openai_vector_stores.py | 2 + tests/unit/providers/vector_io/conftest.py | 93 +++++- .../providers/vector_io/test_opengauss.py | 215 +++++++++++++ 14 files changed, 802 insertions(+), 15 deletions(-) mode change 100644 => 100755 docs/source/distributions/k8s-benchmark/benchmark.py create mode 100644 docs/source/providers/vector_io/remote_opengauss.md create mode 100644 llama_stack/providers/remote/vector_io/opengauss/__init__.py create mode 100644 llama_stack/providers/remote/vector_io/opengauss/config.py create mode 100644 llama_stack/providers/remote/vector_io/opengauss/opengauss.py create mode 100644 tests/unit/providers/vector_io/test_opengauss.py diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index 61b8e004e..c915aa3e7 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -27,7 +27,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"] + vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant", "remote::opengauss"] python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} fail-fast: false # we want to run all tests regardless of failure @@ -89,6 +89,30 @@ jobs: PGPASSWORD=llamastack psql -h localhost -U llamastack -d llamastack \ -c "CREATE EXTENSION IF NOT EXISTS vector;" + - name: Start OpenGauss DB + if: matrix.vector-io-provider == 'remote::opengauss' + run: | + docker run -d \ + --name opengauss \ + -e GS_PASSWORD=Enmo@123 \ + -e GS_DB=llamastack \ + -e GS_USER=llamastack \ + -p 5432:5432 \ + enmotech/opengauss:latest + + - name: Wait for OpenGauss to be ready + if: matrix.vector-io-provider == 'remote::opengauss' + run: | + echo "Waiting for OpenGauss to be ready..." + for i in {1..30}; do + if docker exec opengauss gsql -d llamastack -U llamastack -W Enmo@123 -c "SELECT version();" > /dev/null 2>&1; then + echo "OpenGauss is ready!" + break + fi + echo "Not ready yet... ($i)" + sleep 2 + done + - name: Setup Qdrant if: matrix.vector-io-provider == 'remote::qdrant' run: | @@ -166,6 +190,12 @@ jobs: QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }} ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }} WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }} + ENABLE_OPENGAUSS: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'true' || '' }} + OPENGAUSS_HOST: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'localhost' || '' }} + OPENGAUSS_PORT: ${{ matrix.vector-io-provider == 'remote::opengauss' && '5432' || '' }} + OPENGAUSS_DB: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'llamastack' || '' }} + OPENGAUSS_USER: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'llamastack' || '' }} + OPENGAUSS_PASSWORD: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'Enmo@123' || '' }} run: | uv run --no-sync \ pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ diff --git a/docs/source/distributions/k8s-benchmark/benchmark.py b/docs/source/distributions/k8s-benchmark/benchmark.py old mode 100644 new mode 100755 diff --git a/docs/source/providers/agents/index.md b/docs/source/providers/agents/index.md index a2c48d4b9..046db6bff 100644 --- a/docs/source/providers/agents/index.md +++ b/docs/source/providers/agents/index.md @@ -4,12 +4,12 @@ Agents API for creating and interacting with agentic systems. - Main functionalities provided by this API: - - Create agents with specific instructions and ability to use tools. - - Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn". - - Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details). - - Agents can be provided with various shields (see the Safety API for more details). - - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details. +Main functionalities provided by this API: +- Create agents with specific instructions and ability to use tools. +- Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn". +- Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details). +- Agents can be provided with various shields (see the Safety API for more details). +- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details. This section contains documentation for all available providers for the **agents** API. diff --git a/docs/source/providers/batches/index.md b/docs/source/providers/batches/index.md index 2a39a626c..f427a599b 100644 --- a/docs/source/providers/batches/index.md +++ b/docs/source/providers/batches/index.md @@ -4,11 +4,11 @@ Protocol for batch processing API operations. - The Batches API enables efficient processing of multiple requests in a single operation, - particularly useful for processing large datasets, batch evaluation workflows, and - cost-effective inference at scale. +The Batches API enables efficient processing of multiple requests in a single operation, +particularly useful for processing large datasets, batch evaluation workflows, and +cost-effective inference at scale. - Note: This API is currently under active development and may undergo changes. +Note: This API is currently under active development and may undergo changes. This section contains documentation for all available providers for the **batches** API. diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md index b6d215474..291e8e525 100644 --- a/docs/source/providers/inference/index.md +++ b/docs/source/providers/inference/index.md @@ -4,9 +4,9 @@ Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: - - LLM models: these models generate "raw" and "chat" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search. +This API provides the raw interface to the underlying models. Two kinds of models are supported: +- LLM models: these models generate "raw" and "chat" (conversational) completions. +- Embedding models: these models generate embeddings to be used for semantic search. This section contains documentation for all available providers for the **inference** API. diff --git a/docs/source/providers/vector_io/index.md b/docs/source/providers/vector_io/index.md index 28ae523d7..acd90cb48 100644 --- a/docs/source/providers/vector_io/index.md +++ b/docs/source/providers/vector_io/index.md @@ -18,6 +18,7 @@ inline_sqlite-vec inline_sqlite_vec remote_chromadb remote_milvus +remote_opengauss remote_pgvector remote_qdrant remote_weaviate diff --git a/docs/source/providers/vector_io/remote_opengauss.md b/docs/source/providers/vector_io/remote_opengauss.md new file mode 100644 index 000000000..30291de79 --- /dev/null +++ b/docs/source/providers/vector_io/remote_opengauss.md @@ -0,0 +1,58 @@ +# remote::opengauss + +## Description + + +[OpenGauss](https://opengauss.org/en/) is a remote vector database provider for Llama Stack. It +allows you to store and query vectors directly in memory. +That means you'll get fast and efficient vector retrieval. + +## Features + +- Easy to use +- Fully integrated with Llama Stack + +## Usage + +To use OpenGauss in your Llama Stack project, follow these steps: + +1. Install the necessary dependencies. +2. Configure your Llama Stack project to use OpenGauss. +3. Start storing and querying vectors. + +## Installation + +You can install OpenGauss using docker: + +```bash +docker pull opengauss/opengauss:latest +``` +## Documentation +See [OpenGauss' documentation](https://docs.opengauss.org/en/docs/5.0.0/docs/GettingStarted/understanding-opengauss.html) for more details about OpenGauss in general. + + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `host` | `str \| None` | No | localhost | | +| `port` | `int \| None` | No | 5432 | | +| `db` | `str \| None` | No | postgres | | +| `user` | `str \| None` | No | postgres | | +| `password` | `str \| None` | No | mysecretpassword | | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) | + +## Sample Configuration + +```yaml +host: ${env.OPENGAUSS_HOST:=localhost} +port: ${env.OPENGAUSS_PORT:=5432} +db: ${env.OPENGAUSS_DB} +user: ${env.OPENGAUSS_USER} +password: ${env.OPENGAUSS_PASSWORD} +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/opengauss_registry.db + +``` + diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 70148eb15..d59ff4aa4 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -426,6 +426,44 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de api_dependencies=[Api.inference], optional_api_dependencies=[Api.files], ), + remote_provider_spec( + Api.vector_io, + AdapterSpec( + adapter_type="opengauss", + pip_packages=["psycopg2-binary"], + module="llama_stack.providers.remote.vector_io.opengauss", + config_class="llama_stack.providers.remote.vector_io.opengauss.OpenGaussVectorIOConfig", + description=""" +[OpenGauss](https://opengauss.org/en/) is a remote vector database provider for Llama Stack. It +allows you to store and query vectors directly in memory. +That means you'll get fast and efficient vector retrieval. + +## Features + +- Easy to use +- Fully integrated with Llama Stack + +## Usage + +To use OpenGauss in your Llama Stack project, follow these steps: + +1. Install the necessary dependencies. +2. Configure your Llama Stack project to use OpenGauss. +3. Start storing and querying vectors. + +## Installation + +You can install OpenGauss using docker: + +```bash +docker pull opengauss/opengauss:latest +``` +## Documentation +See [OpenGauss' documentation](https://docs.opengauss.org/en/docs/5.0.0/docs/GettingStarted/understanding-opengauss.html) for more details about OpenGauss in general. +""", + ), + api_dependencies=[Api.inference], + ), remote_provider_spec( Api.vector_io, AdapterSpec( diff --git a/llama_stack/providers/remote/vector_io/opengauss/__init__.py b/llama_stack/providers/remote/vector_io/opengauss/__init__.py new file mode 100644 index 000000000..4a58cfdbc --- /dev/null +++ b/llama_stack/providers/remote/vector_io/opengauss/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.datatypes import Api + +from .config import OpenGaussVectorIOConfig + + +async def get_adapter_impl(config: OpenGaussVectorIOConfig, deps): + from .opengauss import OpenGaussVectorIOAdapter + + files_api = deps.get(Api.files) + impl = OpenGaussVectorIOAdapter(config, deps[Api.inference], files_api) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/vector_io/opengauss/config.py b/llama_stack/providers/remote/vector_io/opengauss/config.py new file mode 100644 index 000000000..8c32002fe --- /dev/null +++ b/llama_stack/providers/remote/vector_io/opengauss/config.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) +from llama_stack.schema_utils import json_schema_type + + +@json_schema_type +class OpenGaussVectorIOConfig(BaseModel): + host: str | None = Field(default="localhost") + port: int | None = Field(default=5432) + db: str | None = Field(default="postgres") + user: str | None = Field(default="postgres") + password: str | None = Field(default="mysecretpassword") + kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None) + + @classmethod + def sample_run_config( + cls, + __distro_dir__: str, + host: str = "${env.OPENGAUSS_HOST:=localhost}", + port: str = "${env.OPENGAUSS_PORT:=5432}", + db: str = "${env.OPENGAUSS_DB}", + user: str = "${env.OPENGAUSS_USER}", + password: str = "${env.OPENGAUSS_PASSWORD}", + **kwargs: Any, + ) -> dict[str, Any]: + return { + "host": host, + "port": port, + "db": db, + "user": user, + "password": password, + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="opengauss_registry.db", + ), + } diff --git a/llama_stack/providers/remote/vector_io/opengauss/opengauss.py b/llama_stack/providers/remote/vector_io/opengauss/opengauss.py new file mode 100644 index 000000000..5773c9191 --- /dev/null +++ b/llama_stack/providers/remote/vector_io/opengauss/opengauss.py @@ -0,0 +1,286 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import logging +from typing import Any + +import psycopg2 +from numpy.typing import NDArray +from psycopg2 import sql +from psycopg2.extras import Json, execute_values +from pydantic import BaseModel, TypeAdapter + +from llama_stack.apis.common.errors import VectorStoreNotFoundError +from llama_stack.apis.files.files import Files +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import ( + Chunk, + QueryChunksResponse, + VectorIO, +) +from llama_stack.providers.datatypes import VectorDBsProtocolPrivate +from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.kvstore.api import KVStore +from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin +from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, + EmbeddingIndex, + VectorDBWithIndex, +) + +from .config import OpenGaussVectorIOConfig + +log = logging.getLogger(__name__) + +VERSION = "v3" +VECTOR_DBS_PREFIX = f"vector_dbs:opengauss:{VERSION}::" +VECTOR_INDEX_PREFIX = f"vector_index:opengauss:{VERSION}::" +OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:opengauss:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:opengauss:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:opengauss:{VERSION}::" + + +def upsert_models(conn, keys_models: list[tuple[str, BaseModel]]): + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + query = sql.SQL( + """ + MERGE INTO metadata_store AS target + USING (VALUES %s) AS source (key, data) + ON (target.key = source.key) + WHEN MATCHED THEN + UPDATE SET data = source.data + WHEN NOT MATCHED THEN + INSERT (key, data) VALUES (source.key, source.data); + """ + ) + + values = [(key, Json(model.model_dump())) for key, model in keys_models] + execute_values(cur, query, values, template="(%s, %s::JSONB)") + + +def load_models(cur, cls): + cur.execute("SELECT key, data FROM metadata_store") + rows = cur.fetchall() + return [TypeAdapter(cls).validate_python(row["data"]) for row in rows] + + +class OpenGaussIndex(EmbeddingIndex): + def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None): + self.conn = conn + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + sanitized_identifier = vector_db.identifier.replace("-", "_") + self.table_name = f"vector_store_{sanitized_identifier}" + self.kvstore = kvstore + + cur.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + id TEXT PRIMARY KEY, + document JSONB, + embedding vector({dimension}) + ) + """ + ) + + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): + assert len(chunks) == len(embeddings), ( + f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" + ) + + values = [] + for i, chunk in enumerate(chunks): + values.append( + ( + f"{chunk.chunk_id}", + Json(chunk.model_dump()), + embeddings[i].tolist(), + ) + ) + + query = sql.SQL( + f""" + MERGE INTO {self.table_name} AS target + USING (VALUES %s) AS source (id, document, embedding) + ON (target.id = source.id) + WHEN MATCHED THEN + UPDATE SET document = source.document, embedding = source.embedding + WHEN NOT MATCHED THEN + INSERT (id, document, embedding) VALUES (source.id, source.document, source.embedding); + """ + ) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + execute_values(cur, query, values, template="(%s, %s::JSONB, %s::VECTOR)") + + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + cur.execute( + f""" + SELECT document, embedding <=> %s::VECTOR AS distance + FROM {self.table_name} + ORDER BY distance + LIMIT %s + """, + (embedding.tolist(), k), + ) + results = cur.fetchall() + + chunks = [] + scores = [] + for doc, dist in results: + score = 1.0 / float(dist) if dist != 0 else float("inf") + if score < score_threshold: + continue + chunks.append(Chunk(**doc)) + scores.append(score) + + return QueryChunksResponse(chunks=chunks, scores=scores) + + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in OpenGauss") + + async def query_hybrid( + self, + embedding: NDArray, + query_string: str, + k: int, + score_threshold: float, + reranker_type: str, + reranker_params: dict[str, Any] | None = None, + ) -> QueryChunksResponse: + raise NotImplementedError("Hybrid search is not supported in OpenGauss") + + async def delete(self): + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") + + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Remove chunks from the OpenGauss table.""" + chunk_ids = [c.chunk_id for c in chunks_for_deletion] + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,)) + + +class OpenGaussVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): + def __init__( + self, + config: OpenGaussVectorIOConfig, + inference_api: Any, + files_api: Files | None = None, + ) -> None: + self.config = config + self.inference_api = inference_api + self.conn = None + self.cache: dict[str, VectorDBWithIndex] = {} + self.files_api = files_api + self.kvstore: KVStore | None = None + self.vector_db_store = None + self.openai_vector_store: dict[str, dict[str, Any]] = {} + self.metadatadata_collection_name = "openai_vector_stores_metadata" + + async def initialize(self) -> None: + log.info(f"Initializing OpenGauss memory adapter with config: {self.config}") + if self.config.kvstore is not None: + self.kvstore = await kvstore_impl(self.config.kvstore) + await self.initialize_openai_vector_stores() + + try: + self.conn = psycopg2.connect( + host=self.config.host, + port=self.config.port, + database=self.config.db, + user=self.config.user, + password=self.config.password, + ) + if self.conn: + self.conn.autocommit = True + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + cur.execute("SELECT version();") + version = cur.fetchone()[0] + log.info(f"OpenGauss server version: {version}") + log.info("Assuming native vector support is enabled in this OpenGauss instance.") + + cur.execute( + """ + CREATE TABLE IF NOT EXISTS metadata_store ( + key TEXT PRIMARY KEY, + data JSONB + ) + """ + ) + except Exception as e: + log.exception("Could not connect to OpenGauss database server") + raise RuntimeError("Could not connect to OpenGauss database server") from e + + async def shutdown(self) -> None: + if self.conn is not None: + self.conn.close() + log.info("Connection to OpenGauss database server closed") + + async def register_vector_db(self, vector_db: VectorDB) -> None: + assert self.kvstore is not None + upsert_models(self.conn, [(vector_db.identifier, vector_db)]) + + index = VectorDBWithIndex( + vector_db, + index=OpenGaussIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore), + inference_api=self.inference_api, + ) + self.cache[vector_db.identifier] = index + + async def unregister_vector_db(self, vector_db_id: str) -> None: + if vector_db_id in self.cache: + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] + + assert self.kvstore is not None + await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}") + + async def insert_chunks( + self, + vector_db_id: str, + chunks: list[Chunk], + ttl_seconds: int | None = None, + ) -> None: + index = await self._get_and_cache_vector_db_index(vector_db_id) + await index.insert_chunks(chunks) + + async def query_chunks( + self, + vector_db_id: str, + query: InterleavedContent, + params: dict[str, Any] | None = None, + ) -> QueryChunksResponse: + index = await self._get_and_cache_vector_db_index(vector_db_id) + return await index.query_chunks(query, params) + + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex: + if vector_db_id in self.cache: + return self.cache[vector_db_id] + + if self.vector_db_store is None: + raise RuntimeError("Vector DB store not initialized") + + vector_db = self.vector_db_store.get_vector_db(vector_db_id) + if vector_db is None: + raise VectorStoreNotFoundError(vector_db_id) + + index = OpenGaussIndex(vector_db, vector_db.embedding_dimension, self.conn) + self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api) + return self.cache[vector_db_id] + + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from an OpenGauss vector store.""" + index = await self._get_and_cache_vector_db_index(store_id) + if not index: + raise VectorStoreNotFoundError(store_id) + + await index.index.delete_chunks(chunks_for_deletion) diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 7ccca9077..f8739dfcc 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -27,6 +27,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): "inline::milvus", "inline::chromadb", "remote::pgvector", + "remote::opengauss", "remote::chromadb", "remote::qdrant", "inline::qdrant", @@ -48,6 +49,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode "inline::chromadb", "inline::qdrant", "remote::pgvector", + "remote::opengauss", "remote::chromadb", "remote::weaviate", "remote::qdrant", diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index f71073651..f83e92176 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -4,7 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import os import random +from unittest.mock import AsyncMock import numpy as np import pytest @@ -22,6 +24,8 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConf from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter +from llama_stack.providers.remote.vector_io.opengauss.config import OpenGaussVectorIOConfig +from llama_stack.providers.remote.vector_io.opengauss.opengauss import OpenGaussIndex, OpenGaussVectorIOAdapter from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter EMBEDDING_DIMENSION = 384 @@ -29,7 +33,7 @@ COLLECTION_PREFIX = "test_collection" MILVUS_ALIAS = "test_milvus" -@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"]) +@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "opengauss"]) def vector_provider(request): return request.param @@ -333,6 +337,92 @@ async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension): await index.delete() +@pytest.fixture +def opengauss_vec_db_path(): + return { + "host": "localhost", + "port": 5432, + "db": "test_db", + "user": "test_user", + "password": "test_password", + } + + +@pytest.fixture +async def opengauss_vec_index(embedding_dimension, opengauss_vec_db_path): + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + mock_conn.cursor.return_value.__enter__.return_value = mock_cursor + + vector_db = VectorDB( + identifier=f"test_opengauss_db_{np.random.randint(1e6)}", + provider_id="opengauss", + embedding_model="test_model", + embedding_dimension=embedding_dimension, + ) + + if all( + os.getenv(var) + for var in ["OPENGAUSS_HOST", "OPENGAUSS_PORT", "OPENGAUSS_DB", "OPENGAUSS_USER", "OPENGAUSS_PASSWORD"] + ): + import psycopg2 + + real_conn = psycopg2.connect(**opengauss_vec_db_path) + real_conn.autocommit = True + index = OpenGaussIndex(vector_db, embedding_dimension, real_conn) + yield index + await index.delete() + real_conn.close() + else: + index = OpenGaussIndex(vector_db, embedding_dimension, mock_conn) + yield index + + +@pytest.fixture +async def opengauss_vec_adapter(mock_inference_api, embedding_dimension, tmp_path_factory): + temp_dir = tmp_path_factory.getbasetemp() + kv_db_path = str(temp_dir / f"opengauss_kv_{np.random.randint(1e6)}.db") + + config = OpenGaussVectorIOConfig( + host=os.getenv("OPENGAUSS_HOST", "localhost"), + port=int(os.getenv("OPENGAUSS_PORT", "5432")), + db=os.getenv("OPENGAUSS_DB", "test_db"), + user=os.getenv("OPENGAUSS_USER", "test_user"), + password=os.getenv("OPENGAUSS_PASSWORD", "test_password"), + kvstore=SqliteKVStoreConfig(db_path=kv_db_path), + ) + + if all( + os.getenv(var) + for var in ["OPENGAUSS_HOST", "OPENGAUSS_PORT", "OPENGAUSS_DB", "OPENGAUSS_USER", "OPENGAUSS_PASSWORD"] + ): + adapter = OpenGaussVectorIOAdapter(config, mock_inference_api) + await adapter.initialize() + + collection_id = f"opengauss_test_collection_{np.random.randint(1e6)}" + await adapter.register_vector_db( + VectorDB( + identifier=collection_id, + provider_id="opengauss", + embedding_model="test_model", + embedding_dimension=embedding_dimension, + ) + ) + adapter.test_collection_id = collection_id + yield adapter + + try: + await adapter.unregister_vector_db(collection_id) + except Exception: + pass + await adapter.shutdown() + + if os.path.exists(kv_db_path): + os.remove(kv_db_path) + else: + pytest.skip("OpenGauss connection not available for integration testing") + + @pytest.fixture def vector_io_adapter(vector_provider, request): """Returns the appropriate vector IO adapter based on the provider parameter.""" @@ -342,6 +432,7 @@ def vector_io_adapter(vector_provider, request): "sqlite_vec": "sqlite_vec_adapter", "chroma": "chroma_vec_adapter", "qdrant": "qdrant_vec_adapter", + "opengauss": "opengauss_vec_adapter", } return request.getfixturevalue(vector_provider_dict[vector_provider]) diff --git a/tests/unit/providers/vector_io/test_opengauss.py b/tests/unit/providers/vector_io/test_opengauss.py new file mode 100644 index 000000000..a6319f114 --- /dev/null +++ b/tests/unit/providers/vector_io/test_opengauss.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import random +from unittest.mock import AsyncMock + +import numpy as np +import pytest + +from llama_stack.apis.inference import EmbeddingsResponse, Inference +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import Chunk, QueryChunksResponse +from llama_stack.providers.remote.vector_io.opengauss.config import ( + OpenGaussVectorIOConfig, +) +from llama_stack.providers.remote.vector_io.opengauss.opengauss import ( + OpenGaussIndex, + OpenGaussVectorIOAdapter, +) +from llama_stack.providers.utils.kvstore.config import ( + SqliteKVStoreConfig, +) + +# Skip all tests in this file if the required environment variables are not set. +pytestmark = pytest.mark.skipif( + not all( + os.getenv(var) + for var in [ + "OPENGAUSS_HOST", + "OPENGAUSS_PORT", + "OPENGAUSS_DB", + "OPENGAUSS_USER", + "OPENGAUSS_PASSWORD", + ] + ), + reason="OpenGauss connection environment variables not set", +) + + +@pytest.fixture(scope="session") +def embedding_dimension() -> int: + return 128 + + +@pytest.fixture +def sample_chunks(): + """Provides a list of sample chunks for testing.""" + return [ + Chunk( + content="The sky is blue.", + metadata={"document_id": "doc1", "topic": "nature"}, + ), + Chunk( + content="An apple a day keeps the doctor away.", + metadata={"document_id": "doc2", "topic": "health"}, + ), + Chunk( + content="Quantum computing is a new frontier.", + metadata={"document_id": "doc3", "topic": "technology"}, + ), + ] + + +@pytest.fixture +def sample_embeddings(embedding_dimension, sample_chunks): + """Provides a deterministic set of embeddings for the sample chunks.""" + # Use a fixed seed for reproducibility + rng = np.random.default_rng(42) + return rng.random((len(sample_chunks), embedding_dimension), dtype=np.float32) + + +@pytest.fixture +def mock_inference_api(sample_embeddings): + """Mocks the inference API to return dummy embeddings.""" + mock_api = AsyncMock(spec=Inference) + mock_api.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings.tolist())) + return mock_api + + +@pytest.fixture +def vector_db(embedding_dimension): + """Provides a sample VectorDB object for registration.""" + return VectorDB( + identifier=f"test_db_{random.randint(1, 10000)}", + embedding_model="test_embedding_model", + embedding_dimension=embedding_dimension, + provider_id="opengauss", + ) + + +@pytest.fixture +async def opengauss_connection(): + """Creates and manages a connection to the OpenGauss database.""" + import psycopg2 + + conn = psycopg2.connect( + host=os.getenv("OPENGAUSS_HOST"), + port=int(os.getenv("OPENGAUSS_PORT")), + database=os.getenv("OPENGAUSS_DB"), + user=os.getenv("OPENGAUSS_USER"), + password=os.getenv("OPENGAUSS_PASSWORD"), + ) + conn.autocommit = True + yield conn + conn.close() + + +@pytest.fixture +async def opengauss_index(opengauss_connection, vector_db): + """Fixture to create and clean up an OpenGaussIndex instance.""" + index = OpenGaussIndex(vector_db, vector_db.embedding_dimension, opengauss_connection) + yield index + await index.delete() + + +@pytest.fixture +async def opengauss_adapter(mock_inference_api): + """Fixture to set up and tear down the OpenGaussVectorIOAdapter.""" + config = OpenGaussVectorIOConfig( + host=os.getenv("OPENGAUSS_HOST"), + port=int(os.getenv("OPENGAUSS_PORT")), + db=os.getenv("OPENGAUSS_DB"), + user=os.getenv("OPENGAUSS_USER"), + password=os.getenv("OPENGAUSS_PASSWORD"), + kvstore=SqliteKVStoreConfig(db_name="opengauss_test.db"), + ) + adapter = OpenGaussVectorIOAdapter(config, mock_inference_api) + await adapter.initialize() + yield adapter + if adapter.conn and not adapter.conn.closed: + for db_id in list(adapter.cache.keys()): + try: + await adapter.unregister_vector_db(db_id) + except Exception as e: + print(f"Error during cleanup of {db_id}: {e}") + await adapter.shutdown() + # Clean up the sqlite db file + if os.path.exists("opengauss_test.db"): + os.remove("opengauss_test.db") + + +class TestOpenGaussIndex: + async def test_add_and_query_vector(self, opengauss_index, sample_chunks, sample_embeddings): + """Test adding chunks with embeddings and querying for the most similar one.""" + await opengauss_index.add_chunks(sample_chunks, sample_embeddings) + + # Query with the embedding of the first chunk + query_embedding = sample_embeddings[0] + response = await opengauss_index.query_vector(query_embedding, k=1, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 1 + assert response.chunks[0].content == sample_chunks[0].content + # The distance to itself should be 0, resulting in infinite score + assert response.scores[0] == float("inf") + + +class TestOpenGaussVectorIOAdapter: + async def test_initialization(self, opengauss_adapter): + """Test that the adapter initializes and connects to the database.""" + assert opengauss_adapter.conn is not None + assert not opengauss_adapter.conn.closed + + async def test_register_and_unregister_vector_db(self, opengauss_adapter, vector_db): + """Test the registration and unregistration of a vector database.""" + await opengauss_adapter.register_vector_db(vector_db) + assert vector_db.identifier in opengauss_adapter.cache + + table_name = opengauss_adapter.cache[vector_db.identifier].index.table_name + with opengauss_adapter.conn.cursor() as cur: + cur.execute( + "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND tablename = %s);", + (table_name,), + ) + assert cur.fetchone()[0] + + await opengauss_adapter.unregister_vector_db(vector_db.identifier) + assert vector_db.identifier not in opengauss_adapter.cache + + with opengauss_adapter.conn.cursor() as cur: + cur.execute( + "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND tablename = %s);", + (table_name,), + ) + assert not cur.fetchone()[0] + + async def test_adapter_end_to_end_query(self, opengauss_adapter, vector_db, sample_chunks): + """ + Tests the full adapter flow: text query -> embedding generation -> vector search. + """ + # 1. Register the DB and insert chunks. The adapter will use the mocked + # inference_api to generate embeddings for these chunks. + await opengauss_adapter.register_vector_db(vector_db) + await opengauss_adapter.insert_chunks(vector_db.identifier, sample_chunks) + + # 2. The user query is a text string. + query_text = "What is the color of the sky?" + + # 3. The adapter will now internally call the (mocked) inference_api + # to get an embedding for the query_text. + response = await opengauss_adapter.query_chunks(vector_db.identifier, query_text) + + # 4. Assertions + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) > 0 + + # Because the mocked inference_api returns random embeddings, we can't + # deterministically know which chunk is "closest". However, in a real + # integration test with a real model, this assertion would be more specific. + # For this unit test, we just confirm that the process completes and returns data. + assert response.chunks[0].content in [c.content for c in sample_chunks]