From 5e9c394500e4abeec7857ccecbc00afd993c4674 Mon Sep 17 00:00:00 2001 From: qifengleqifengle <2472846459@qq.com> Date: Mon, 14 Jul 2025 16:50:29 +0800 Subject: [PATCH] feat(vector-io): add OpenGauss vector database provider Implement OpenGauss vector database integration for Llama Stack with the following features: - Add OpenGaussVectorIOAdapter for vector storage and retrieval - Support native vector similarity search operations - Implement connection and query management with psycopg2 - Provide configuration template for easy setup - Add comprehensive unit tests The implementation allows Llama Stack users to leverage OpenGauss as an enterprise-grade vector database for RAG applications. Users can configure their environment through a simple YAML configuration and environment variables. --- docs/source/providers/vector_io/index.md | 1 + .../providers/vector_io/remote_opengauss.md | 54 +++ llama_stack/providers/registry/vector_io.py | 38 ++ .../remote/vector_io/opengauss/__init__.py | 17 + .../remote/vector_io/opengauss/config.py | 32 ++ .../remote/vector_io/opengauss/opengauss.py | 351 ++++++++++++++++++ .../templates/opengauss-demo/__init__.py | 9 + .../templates/opengauss-demo/build.yaml | 27 ++ llama_stack/templates/opengauss-demo/run.yaml | 96 +++++ .../providers/vector_io/test_opengauss.py | 229 ++++++++++++ 10 files changed, 854 insertions(+) create mode 100644 docs/source/providers/vector_io/remote_opengauss.md create mode 100644 llama_stack/providers/remote/vector_io/opengauss/__init__.py create mode 100644 llama_stack/providers/remote/vector_io/opengauss/config.py create mode 100644 llama_stack/providers/remote/vector_io/opengauss/opengauss.py create mode 100644 llama_stack/templates/opengauss-demo/__init__.py create mode 100644 llama_stack/templates/opengauss-demo/build.yaml create mode 100644 llama_stack/templates/opengauss-demo/run.yaml create mode 100644 tests/unit/providers/vector_io/test_opengauss.py diff --git 
a/docs/source/providers/vector_io/index.md b/docs/source/providers/vector_io/index.md index 870d04401..49706458b 100644 --- a/docs/source/providers/vector_io/index.md +++ b/docs/source/providers/vector_io/index.md @@ -11,6 +11,7 @@ This section contains documentation for all available providers for the **vector - [inline::sqlite_vec](inline_sqlite_vec.md) - [remote::chromadb](remote_chromadb.md) - [remote::milvus](remote_milvus.md) +- [remote::opengauss](remote_opengauss.md) - [remote::pgvector](remote_pgvector.md) - [remote::qdrant](remote_qdrant.md) - [remote::weaviate](remote_weaviate.md) \ No newline at end of file diff --git a/docs/source/providers/vector_io/remote_opengauss.md b/docs/source/providers/vector_io/remote_opengauss.md new file mode 100644 index 000000000..7f6576b50 --- /dev/null +++ b/docs/source/providers/vector_io/remote_opengauss.md @@ -0,0 +1,54 @@ +# remote::opengauss + +## Description + + +[OpenGauss](https://opengauss.org/en/) is a remote vector database provider for Llama Stack. It +allows you to store and query vectors directly in memory. +That means you'll get fast and efficient vector retrieval. + +## Features + +- Easy to use +- Fully integrated with Llama Stack + +## Usage + +To use OpenGauss in your Llama Stack project, follow these steps: + +1. Install the necessary dependencies. +2. Configure your Llama Stack project to use OpenGauss. +3. Start storing and querying vectors. + +## Installation + +You can install OpenGauss using docker: + +```bash +docker pull opengauss/opengauss:latest +``` +## Documentation +See [OpenGauss' documentation](https://docs.opengauss.org/en/docs/5.0.0/docs/GettingStarted/understanding-opengauss.html) for more details about OpenGauss in general. 
+ + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `host` | `str \| None` | No | localhost | | +| `port` | `int \| None` | No | 5432 | | +| `db` | `str \| None` | No | postgres | | +| `user` | `str \| None` | No | postgres | | +| `password` | `str \| None` | No | mysecretpassword | | + +## Sample Configuration + +```yaml +host: ${env.OPENGAUSS_HOST:=localhost} +port: ${env.OPENGAUSS_PORT:=5432} +db: ${env.OPENGAUSS_DB} +user: ${env.OPENGAUSS_USER} +password: ${env.OPENGAUSS_PASSWORD} + +``` + diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index c13e65bbc..97f5280f0 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -407,6 +407,44 @@ docker pull pgvector/pgvector:pg17 ``` ## Documentation See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general. +""", + ), + api_dependencies=[Api.inference], + ), + remote_provider_spec( + Api.vector_io, + AdapterSpec( + adapter_type="opengauss", + pip_packages=["psycopg2-binary"], + module="llama_stack.providers.remote.vector_io.opengauss", + config_class="llama_stack.providers.remote.vector_io.opengauss.OpenGaussVectorIOConfig", + description=""" +[OpenGauss](https://opengauss.org/en/) is a remote vector database provider for Llama Stack. It +allows you to store and query vectors directly in memory. +That means you'll get fast and efficient vector retrieval. + +## Features + +- Easy to use +- Fully integrated with Llama Stack + +## Usage + +To use OpenGauss in your Llama Stack project, follow these steps: + +1. Install the necessary dependencies. +2. Configure your Llama Stack project to use OpenGauss. +3. Start storing and querying vectors. 
+ +## Installation + +You can install OpenGauss using docker: + +```bash +docker pull opengauss/opengauss:latest +``` +## Documentation +See [OpenGauss' documentation](https://docs.opengauss.org/en/docs/5.0.0/docs/GettingStarted/understanding-opengauss.html) for more details about OpenGauss in general. """, ), api_dependencies=[Api.inference], diff --git a/llama_stack/providers/remote/vector_io/opengauss/__init__.py b/llama_stack/providers/remote/vector_io/opengauss/__init__.py new file mode 100644 index 000000000..b87eee54b --- /dev/null +++ b/llama_stack/providers/remote/vector_io/opengauss/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.datatypes import Api, ProviderSpec + +from .config import OpenGaussVectorIOConfig + + +async def get_adapter_impl(config: OpenGaussVectorIOConfig, deps: dict[Api, ProviderSpec]): + from .opengauss import OpenGaussVectorIOAdapter + + impl = OpenGaussVectorIOAdapter(config, deps[Api.inference]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/vector_io/opengauss/config.py b/llama_stack/providers/remote/vector_io/opengauss/config.py new file mode 100644 index 000000000..967e21f89 --- /dev/null +++ b/llama_stack/providers/remote/vector_io/opengauss/config.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
from typing import Any

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class OpenGaussVectorIOConfig(BaseModel):
    """Connection settings for the OpenGauss vector-IO provider.

    Every field is optional and falls back to the default shown below.
    """

    host: str | None = Field(default="localhost")
    port: int | None = Field(default=5432)
    db: str | None = Field(default="postgres")
    user: str | None = Field(default="postgres")
    # repr=False keeps the secret out of repr()/str() of the model, so that
    # formatting the config into a log line or an error message does not leak
    # it. Serialization (model_dump) is unaffected.
    password: str | None = Field(default="mysecretpassword", repr=False)

    @classmethod
    def sample_run_config(
        cls,
        host: str = "${env.OPENGAUSS_HOST:=localhost}",
        port: str = "${env.OPENGAUSS_PORT:=5432}",
        db: str = "${env.OPENGAUSS_DB}",
        user: str = "${env.OPENGAUSS_USER}",
        password: str = "${env.OPENGAUSS_PASSWORD}",
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Return a sample provider config made of environment placeholders.

        Args:
            host: Placeholder (or literal) for the server host.
            port: Placeholder (or literal) for the server port.
            db: Placeholder (or literal) for the database name.
            user: Placeholder (or literal) for the login role.
            password: Placeholder (or literal) for the login password.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with the keys ``host``, ``port``, ``db``, ``user`` and
            ``password``.
        """
        return {"host": host, "port": port, "db": db, "user": user, "password": password}


# --- patch metadata carried over from the original text (header of the next file) ---
# diff --git a/llama_stack/providers/remote/vector_io/opengauss/opengauss.py b/llama_stack/providers/remote/vector_io/opengauss/opengauss.py
# new file mode 100644
# index 000000000..b019cf92f
# --- /dev/null
# +++ b/llama_stack/providers/remote/vector_io/opengauss/opengauss.py
# @@ -0,0 +1,351 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import logging
import re
from typing import Any

import psycopg2
from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import DictCursor, Json, execute_values
from pydantic import BaseModel, TypeAdapter

from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    SearchRankingOptions,
    VectorIO,
    VectorStoreChunkingStrategy,
    VectorStoreDeleteResponse,
    VectorStoreFileContentsResponse,
    VectorStoreFileDeleteResponse,
    VectorStoreFileObject,
    VectorStoreFileStatus,
    VectorStoreListFilesResponse,
    VectorStoreListResponse,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
    EmbeddingIndex,
    VectorDBWithIndex,
)

from .config import OpenGaussVectorIOConfig

log = logging.getLogger(__name__)

# The per-store table name is spliced into SQL text (table names cannot be
# bound as query parameters), so the identifier part is restricted to
# characters that are always safe in an unquoted SQL identifier.
_SAFE_IDENTIFIER = re.compile(r"[A-Za-z0-9_]+")

_OPENAI_UNSUPPORTED = "OpenAI Vector Stores API is not supported in OpenGauss"


def upsert_models(conn, keys_models: list[tuple[str, BaseModel]]):
    """Insert or update ``(key, model)`` pairs in the ``metadata_store`` table.

    Args:
        conn: An open psycopg2 connection.
        keys_models: Pairs of primary key and pydantic model; each model is
            stored as JSONB via ``model_dump()``.
    """
    query = sql.SQL(
        """
        MERGE INTO metadata_store AS target
        USING (VALUES %s) AS source (key, data)
        ON (target.key = source.key)
        WHEN MATCHED THEN
            UPDATE SET data = source.data
        WHEN NOT MATCHED THEN
            INSERT (key, data) VALUES (source.key, source.data);
        """
    )
    values = [(key, Json(model.model_dump())) for key, model in keys_models]
    with conn.cursor(cursor_factory=DictCursor) as cur:
        execute_values(cur, query, values, template="(%s, %s::JSONB)")


def load_models(cur, cls):
    """Read every row of ``metadata_store`` and validate its ``data`` as ``cls``.

    Args:
        cur: A cursor whose rows support access by column name.
        cls: The type each stored JSON document is validated against.

    Returns:
        A list with one validated object per row.
    """
    cur.execute("SELECT key, data FROM metadata_store")
    # Build the adapter once instead of once per row.
    adapter = TypeAdapter(cls)
    return [adapter.validate_python(row["data"]) for row in cur.fetchall()]


class OpenGaussIndex(EmbeddingIndex):
    """Embedding index backed by one OpenGauss table per vector store.

    The table is named ``vector_store_<identifier>`` (hyphens in the identifier
    become underscores) and has the columns ``id``, ``document`` (JSONB) and
    ``embedding`` (``vector(<dimension>)``).

    NOTE(review): every method is ``async`` but issues blocking psycopg2 calls
    on the shared connection.
    """

    def __init__(self, vector_db: VectorDB, dimension: int, conn):
        """Create the backing table if it does not exist yet.

        Args:
            vector_db: The vector store this index belongs to.
            dimension: Width of the ``embedding`` column.
            conn: An open psycopg2 connection.

        Raises:
            ValueError: If the identifier, after replacing ``-`` with ``_``,
                contains anything other than ASCII letters, digits and ``_``.
        """
        self.conn = conn

        sanitized_identifier = vector_db.identifier.replace("-", "_")
        if not _SAFE_IDENTIFIER.fullmatch(sanitized_identifier):
            raise ValueError(
                f"Vector store identifier {vector_db.identifier!r} cannot be used to build a table name; "
                "only letters, digits, '-' and '_' are allowed."
            )
        self.table_name = f"vector_store_{sanitized_identifier}"
        # int() guarantees that only a number reaches the DDL text below.
        dimension = int(dimension)

        log.info(
            "Creating table '%s' for vector store '%s' if it does not exist.",
            self.table_name,
            vector_db.identifier,
        )
        with conn.cursor(cursor_factory=DictCursor) as cur:
            cur.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {self.table_name} (
                    id TEXT PRIMARY KEY,
                    document JSONB,
                    embedding vector({dimension})
                )
                """
            )

    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
        """Upsert ``chunks`` together with their ``embeddings``.

        Row ids are ``<document_id>:chunk-<position in this call>``.

        NOTE(review): because the id uses the position within this call, a
        second call for the same ``document_id`` overwrites the rows written by
        the first one. Confirm whether that is intended.

        Raises:
            ValueError: If ``chunks`` and ``embeddings`` differ in length.
            KeyError: If a chunk's metadata has no ``document_id``.
        """
        if len(chunks) != len(embeddings):
            # Raised explicitly: an ``assert`` would vanish under ``python -O``.
            raise ValueError(f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}")
        if not chunks:
            return

        values = [
            (
                f"{chunk.metadata['document_id']}:chunk-{i}",
                Json(chunk.model_dump()),
                embeddings[i].tolist(),
            )
            for i, chunk in enumerate(chunks)
        ]

        query = sql.SQL(
            f"""
            MERGE INTO {self.table_name} AS target
            USING (VALUES %s) AS source (id, document, embedding)
            ON (target.id = source.id)
            WHEN MATCHED THEN
                UPDATE SET document = source.document, embedding = source.embedding
            WHEN NOT MATCHED THEN
                INSERT (id, document, embedding) VALUES (source.id, source.document, source.embedding);
            """
        )
        with self.conn.cursor(cursor_factory=DictCursor) as cur:
            execute_values(cur, query, values, template="(%s, %s::JSONB, %s::VECTOR)")

    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        """Return up to ``k`` chunks ordered by ascending ``<=>`` distance.

        Each score is ``1 / distance``; a distance that is not positive scores
        ``inf``. Results scoring below ``score_threshold`` are dropped.
        """
        with self.conn.cursor(cursor_factory=DictCursor) as cur:
            cur.execute(
                f"""
                SELECT document, embedding <=> %s::VECTOR AS distance
                FROM {self.table_name}
                ORDER BY distance
                LIMIT %s
                """,
                (embedding.tolist(), k),
            )
            results = cur.fetchall()

        chunks = []
        scores = []
        for doc, dist in results:
            distance = float(dist)
            # A non-positive distance is treated as an exact match and scored
            # inf, which also avoids dividing by zero.
            score = 1.0 / distance if distance > 0 else float("inf")
            if score < score_threshold:
                continue
            chunks.append(Chunk(**doc))
            scores.append(score)

        return QueryChunksResponse(chunks=chunks, scores=scores)

    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
        """Keyword search is not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError("Keyword search is not supported in this OpenGauss provider")

    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """Hybrid search is not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError("Hybrid search is not supported in this OpenGauss provider")

    async def delete(self):
        """Drop the backing table if it exists."""
        with self.conn.cursor(cursor_factory=DictCursor) as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")


class OpenGaussVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    """Vector-IO provider that stores embeddings in an OpenGauss database.

    One psycopg2 connection (autocommit) is opened in ``initialize`` and shared
    by every index. Registered vector stores are recorded in the
    ``metadata_store`` table and cached in memory by identifier.
    """

    def __init__(self, config: OpenGaussVectorIOConfig, inference_api: Any) -> None:
        self.config = config
        self.inference_api = inference_api
        self.conn = None
        self.cache: dict[str, VectorDBWithIndex] = {}

    async def initialize(self) -> None:
        """Connect to the server and make sure ``metadata_store`` exists.

        Raises:
            RuntimeError: If connecting or the initial statements fail; the
                underlying exception is chained as the cause.
        """
        # Log the connection target field by field: formatting the whole config
        # object would write the password into the log.
        log.info(
            "Initializing OpenGauss native vector adapter (host=%s, port=%s, db=%s, user=%s)",
            self.config.host,
            self.config.port,
            self.config.db,
            self.config.user,
        )
        try:
            self.conn = psycopg2.connect(
                host=self.config.host,
                port=self.config.port,
                database=self.config.db,
                user=self.config.user,
                password=self.config.password,
            )
            self.conn.autocommit = True
            with self.conn.cursor(cursor_factory=DictCursor) as cur:
                cur.execute("SELECT version();")
                version = cur.fetchone()[0]
                log.info(f"Successfully connected to OpenGauss. Server version: {version}")
                log.info("Assuming native vector support is enabled in this OpenGauss instance.")

                cur.execute(
                    """
                    CREATE TABLE IF NOT EXISTS metadata_store (
                        key TEXT PRIMARY KEY,
                        data JSONB
                    )
                    """
                )
        except Exception as e:
            log.exception("Could not connect to OpenGauss database server")
            # Do not leave a half-initialized connection open behind the error.
            if self.conn is not None:
                self.conn.close()
                self.conn = None
            raise RuntimeError("Could not connect to OpenGauss database server") from e

    async def shutdown(self) -> None:
        """Close the database connection if one was opened."""
        if self.conn is not None:
            self.conn.close()
            log.info("Connection to OpenGauss database server closed")

    async def register_vector_db(self, vector_db: VectorDB) -> None:
        """Create the store's table, record it in ``metadata_store`` and cache it.

        Raises:
            RuntimeError: If the adapter has no database connection.
            ValueError: If the identifier cannot be turned into a table name.
        """
        if not self.conn:
            raise RuntimeError("Database connection not initialized.")

        # Build the index first so that a rejected identifier or a failed
        # CREATE TABLE does not leave a metadata row without a table.
        index = OpenGaussIndex(vector_db, vector_db.embedding_dimension, self.conn)
        upsert_models(self.conn, [(vector_db.identifier, vector_db)])
        self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        """Drop the store's table and forget the store.

        Raises:
            KeyError: If ``vector_db_id`` is not in the in-memory cache.
        """
        await self.cache[vector_db_id].index.delete()
        del self.cache[vector_db_id]

        # register_vector_db() wrote a metadata row under this key; remove it
        # so that the table and its metadata disappear together.
        if self.conn is not None:
            with self.conn.cursor() as cur:
                cur.execute("DELETE FROM metadata_store WHERE key = %s", (vector_db_id,))

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        """Insert ``chunks`` into the given store (``ttl_seconds`` is not used)."""
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        await index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """Run ``query`` against the given store and return the matching chunks."""
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        return await index.query_chunks(query, params)

    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
        """Return the cached index for ``vector_db_id``, building it on a miss.

        Raises:
            RuntimeError: If the vector DB store or the connection is missing.
            ValueError: If the store does not know ``vector_db_id``.
        """
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]

        if self.vector_db_store is None:
            raise RuntimeError("Vector DB store has not been initialized.")

        vector_db = self.vector_db_store.get_vector_db(vector_db_id)
        # NOTE(review): the store implementation is not visible from this file.
        # If get_vector_db is a coroutine function, the un-awaited result would
        # be truthy and fail later on attribute access, so await it when needed.
        if inspect.isawaitable(vector_db):
            vector_db = await vector_db
        if not vector_db:
            raise ValueError(f"Vector DB with id {vector_db_id} not found.")

        if not self.conn:
            raise RuntimeError("Database connection not initialized.")

        index = OpenGaussIndex(vector_db, vector_db.embedding_dimension, self.conn)
        self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
        return self.cache[vector_db_id]

    # ------------------------------------------------------------------
    # OpenAI-compatible Vector Stores API: none of it is implemented here;
    # every method below raises NotImplementedError.
    # ------------------------------------------------------------------

    async def openai_create_vector_store(
        self,
        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorStoreObject:
        raise NotImplementedError(_OPENAI_UNSUPPORTED)

    async def openai_list_vector_stores(
        self,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
    ) -> VectorStoreListResponse:
        raise NotImplementedError(_OPENAI_UNSUPPORTED)

    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        raise NotImplementedError(_OPENAI_UNSUPPORTED)

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        raise NotImplementedError(_OPENAI_UNSUPPORTED)

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        raise NotImplementedError(_OPENAI_UNSUPPORTED)

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: SearchRankingOptions | None = None,
        rewrite_query: bool | None = False,
        search_mode: str | None = "vector",
    ) -> VectorStoreSearchResponsePage:
        raise NotImplementedError(_OPENAI_UNSUPPORTED)

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        raise NotImplementedError(_OPENAI_UNSUPPORTED)

    async def openai_list_files_in_vector_store(
        self,
        vector_store_id: str,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
        filter: VectorStoreFileStatus | None = None,
    ) -> VectorStoreListFilesResponse:
        raise NotImplementedError(_OPENAI_UNSUPPORTED)

    async def openai_retrieve_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileObject:
        raise NotImplementedError(_OPENAI_UNSUPPORTED)

    async def openai_retrieve_vector_store_file_contents(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileContentsResponse:
        raise NotImplementedError(_OPENAI_UNSUPPORTED)

    async def openai_update_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
    ) -> VectorStoreFileObject:
        raise NotImplementedError(_OPENAI_UNSUPPORTED)

    async def openai_delete_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileDeleteResponse:
        raise NotImplementedError(_OPENAI_UNSUPPORTED)


# --- patch metadata carried over from the original text (header of the next file) ---
# diff --git a/llama_stack/templates/opengauss-demo/__init__.py b/llama_stack/templates/opengauss-demo/__init__.py
# new file mode 100644
# index 000000000..996f7ed8a
# --- /dev/null
# +++
b/llama_stack/templates/opengauss-demo/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# from .mymilvus_demo import get_distribution_template + +# __all__ = ["get_distribution_template"] diff --git a/llama_stack/templates/opengauss-demo/build.yaml b/llama_stack/templates/opengauss-demo/build.yaml new file mode 100644 index 000000000..e80b51db6 --- /dev/null +++ b/llama_stack/templates/opengauss-demo/build.yaml @@ -0,0 +1,27 @@ +version: 2 +distribution_spec: + description: "Custom configuration using Opengauss for vector storage and metadata" + providers: + inference: + - remote::together-openai-compat + vector_io: + - remote::opengauss + agents: + - inline::meta-reference + tool_runtime: + - inline::rag-runtime + files: + - inline::localfs + telemetry: + - inline::meta-reference + safety: + - inline::llama-guard + +image_type: venv +additional_pip_packages: + - psycopg2-binary>=2.9.3 + - pgvector>=0.2.0 + - asyncpg>=0.27.0 +# rm -rf ~/.llama/distributions/opengauss-demo +# uv run --with llama-stack llama stack build --template opengauss-demo --image-type venv +# uv run --env-file .env --with llama-stack llama stack run /home/gt/.llama/distributions/opengauss-demo/opengauss-demo-run.yaml diff --git a/llama_stack/templates/opengauss-demo/run.yaml b/llama_stack/templates/opengauss-demo/run.yaml new file mode 100644 index 000000000..037242076 --- /dev/null +++ b/llama_stack/templates/opengauss-demo/run.yaml @@ -0,0 +1,96 @@ +version: 2 +image_name: opengauss-demo +apis: + - agents + - inference + - vector_io + - tool_runtime + - files + - models + - telemetry + - safety + +providers: + inference: + - provider_id: together + provider_type: remote::together + config: + api_key: ${env.TOGETHER_API_KEY} + url: ${env.TOGETHER_API_BASE_URL} + - provider_id: sentence-transformers + 
provider_type: inline::sentence-transformers + config: {} + + vector_io: + - provider_id: opengauss + provider_type: remote::opengauss + config: + host: ${env.OPENGAUSS_HOST} + port: ${env.OPENGAUSS_PORT:=5432} + db: ${env.OPENGAUSS_DB} + user: ${env.OPENGAUSS_USER} + password: ${env.OPENGAUSS_PASSWORD} + + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/opengauss-demo}/agent_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/opengauss-demo}/agent_responses.db + tool_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/opengauss-demo}/agent_tools.db + + tool_runtime: + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + + files: + - provider_id: localfs + provider_type: inline::localfs + config: + storage_dir: ${env.LOCALFS_BASE_PATH:=~/.llama/distributions/opengauss-demo/files} + metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/opengauss-demo}/files_metadata.db + + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: {} + + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + +models: + - model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo + model_type: llm + provider_id: together + provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo + metadata: {} + + - model_id: all-MiniLM-L6-v2 + model_type: embedding + provider_id: sentence-transformers + provider_model_id: all-MiniLM-L6-v2 + metadata: + embedding_dimension: 384 + +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/opengauss-demo}/registry.db + +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/opengauss-demo}/inference_store.db + +tool_groups: + - toolgroup_id: builtin::rag + 
provider_id: rag-runtime diff --git a/tests/unit/providers/vector_io/test_opengauss.py b/tests/unit/providers/vector_io/test_opengauss.py new file mode 100644 index 000000000..3583f25f2 --- /dev/null +++ b/tests/unit/providers/vector_io/test_opengauss.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import random +from unittest.mock import AsyncMock + +import numpy as np +import pytest +import pytest_asyncio + +from llama_stack.apis.inference import EmbeddingsResponse, Inference +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import Chunk, QueryChunksResponse +from llama_stack.providers.remote.vector_io.opengauss.config import ( + OpenGaussVectorIOConfig, +) +from llama_stack.providers.remote.vector_io.opengauss.opengauss import ( + OpenGaussIndex, + OpenGaussVectorIOAdapter, +) + +# Skip all tests in this file if the required environment variables are not set. 
pytestmark = pytest.mark.skipif(
    not all(
        os.getenv(var)
        for var in [
            "OPENGAUSS_HOST",
            "OPENGAUSS_PORT",
            "OPENGAUSS_DB",
            "OPENGAUSS_USER",
            "OPENGAUSS_PASSWORD",
        ]
    ),
    reason="OpenGauss connection environment variables not set",
)


def _table_exists(conn, table_name: str) -> bool:
    """Return whether ``table_name`` is listed in ``pg_tables`` for schema ``public``."""
    with conn.cursor() as cur:
        cur.execute(
            "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND tablename = %s);",
            (table_name,),
        )
        return bool(cur.fetchone()[0])


@pytest.fixture(scope="session")
def embedding_dimension() -> int:
    return 128


@pytest.fixture
def sample_chunks():
    """Provides a list of sample chunks for testing."""
    return [
        Chunk(
            content="The sky is blue.",
            metadata={"document_id": "doc1", "topic": "nature"},
        ),
        Chunk(
            content="An apple a day keeps the doctor away.",
            metadata={"document_id": "doc2", "topic": "health"},
        ),
        Chunk(
            content="Quantum computing is a new frontier.",
            metadata={"document_id": "doc3", "topic": "technology"},
        ),
    ]


@pytest.fixture
def sample_embeddings(embedding_dimension, sample_chunks):
    """Provides a deterministic set of embeddings for the sample chunks."""
    # A fixed seed makes every run produce the same vectors.
    rng = np.random.default_rng(42)
    return rng.random((len(sample_chunks), embedding_dimension), dtype=np.float32)


@pytest.fixture
def mock_inference_api(sample_embeddings):
    """Mocks the inference API; every call returns the same seeded embeddings."""
    mock_api = AsyncMock(spec=Inference)
    mock_api.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings.tolist()))
    return mock_api


@pytest.fixture
def vector_db(embedding_dimension):
    """Provides a sample VectorDB object for registration."""
    return VectorDB(
        identifier=f"test_db_{random.randint(1, 10000)}",
        embedding_model="test_embedding_model",
        embedding_dimension=embedding_dimension,
        provider_id="opengauss",
    )


@pytest_asyncio.fixture
async def opengauss_connection():
    """Creates and manages a connection to the OpenGauss database."""
    import psycopg2

    conn = psycopg2.connect(
        host=os.getenv("OPENGAUSS_HOST"),
        port=int(os.getenv("OPENGAUSS_PORT")),
        database=os.getenv("OPENGAUSS_DB"),
        user=os.getenv("OPENGAUSS_USER"),
        password=os.getenv("OPENGAUSS_PASSWORD"),
    )
    conn.autocommit = True
    yield conn
    conn.close()


@pytest_asyncio.fixture
async def opengauss_index(opengauss_connection, vector_db):
    """Fixture to create and clean up an OpenGaussIndex instance."""
    index = OpenGaussIndex(vector_db, vector_db.embedding_dimension, opengauss_connection)
    yield index
    await index.delete()


@pytest_asyncio.fixture
async def opengauss_adapter(mock_inference_api):
    """Fixture to set up and tear down the OpenGaussVectorIOAdapter."""
    config = OpenGaussVectorIOConfig(
        host=os.getenv("OPENGAUSS_HOST"),
        port=int(os.getenv("OPENGAUSS_PORT")),
        db=os.getenv("OPENGAUSS_DB"),
        user=os.getenv("OPENGAUSS_USER"),
        password=os.getenv("OPENGAUSS_PASSWORD"),
    )
    adapter = OpenGaussVectorIOAdapter(config, mock_inference_api)
    await adapter.initialize()
    yield adapter
    if adapter.conn and not adapter.conn.closed:
        # Best-effort cleanup: keep going so the connection is always closed.
        for db_id in list(adapter.cache.keys()):
            try:
                await adapter.unregister_vector_db(db_id)
            except Exception as e:
                print(f"Error during cleanup of {db_id}: {e}")
        await adapter.shutdown()


@pytest.mark.asyncio
class TestOpenGaussIndex:
    async def test_add_and_query_vector(self, opengauss_index, sample_chunks, sample_embeddings):
        """Test adding chunks with embeddings and querying for the most similar one."""
        await opengauss_index.add_chunks(sample_chunks, sample_embeddings)

        # Query with the embedding of the first chunk
        query_embedding = sample_embeddings[0]
        response = await opengauss_index.query_vector(query_embedding, k=1, score_threshold=0.0)

        assert isinstance(response, QueryChunksResponse)
        assert len(response.chunks) == 1
        assert response.chunks[0].content == sample_chunks[0].content
        # The distance to itself should be 0, resulting in infinite score.
        # NOTE(review): this relies on the server returning exactly 0 for a
        # vector compared with itself; confirm that floating-point error never
        # yields a tiny non-zero distance.
        assert response.scores[0] == float("inf")


@pytest.mark.asyncio
class TestOpenGaussVectorIOAdapter:
    async def test_initialization(self, opengauss_adapter):
        """Test that the adapter initializes and connects to the database."""
        assert opengauss_adapter.conn is not None
        assert not opengauss_adapter.conn.closed

    async def test_register_and_unregister_vector_db(self, opengauss_adapter, vector_db):
        """Test the registration and unregistration of a vector database."""
        await opengauss_adapter.register_vector_db(vector_db)
        assert vector_db.identifier in opengauss_adapter.cache

        table_name = opengauss_adapter.cache[vector_db.identifier].index.table_name
        assert _table_exists(opengauss_adapter.conn, table_name)

        await opengauss_adapter.unregister_vector_db(vector_db.identifier)
        assert vector_db.identifier not in opengauss_adapter.cache
        assert not _table_exists(opengauss_adapter.conn, table_name)

    async def test_insert_and_query_chunks(self, opengauss_adapter, vector_db, sample_chunks):
        """Test the full adapter flow of inserting and querying."""
        await opengauss_adapter.register_vector_db(vector_db)
        await opengauss_adapter.insert_chunks(vector_db.identifier, sample_chunks)

        # Query for something semantically similar to the first chunk
        response = await opengauss_adapter.query_chunks(vector_db.identifier, "What color is the sky?")

        assert isinstance(response, QueryChunksResponse)
        assert len(response.chunks) > 0
        assert response.chunks[0].content == "The sky is blue."

    async def test_not_implemented_errors(self, opengauss_adapter):
        """Test that unsupported methods raise NotImplementedError."""
        with pytest.raises(NotImplementedError):
            await opengauss_adapter.openai_create_vector_store(name="test")

    async def test_adapter_end_to_end_query(self, opengauss_adapter, vector_db, sample_chunks):
        """
        Tests the full adapter flow: text query -> embedding generation -> vector search.
        """
        # 1. Register the DB and insert chunks. The adapter will use the mocked
        #    inference_api to generate embeddings for these chunks.
        await opengauss_adapter.register_vector_db(vector_db)
        await opengauss_adapter.insert_chunks(vector_db.identifier, sample_chunks)

        # 2. The user query is a text string.
        query_text = "What is the color of the sky?"

        # 3. The adapter will now internally call the (mocked) inference_api
        #    to get an embedding for the query_text.
        response = await opengauss_adapter.query_chunks(vector_db.identifier, query_text)

        # 4. Assertions
        assert isinstance(response, QueryChunksResponse)
        assert len(response.chunks) > 0

        # The mocked inference_api returns the same seeded embeddings on every
        # call, so they carry no meaning for query_text. This test therefore
        # only checks that the pipeline completes and returns one of the
        # inserted chunks; it does not check ranking.
        assert response.chunks[0].content in [c.content for c in sample_chunks]