feat(vector-io): add OpenGauss vector database provider

Implement OpenGauss vector database integration for Llama Stack with the following features:
- Add OpenGaussVectorIOAdapter for vector storage and retrieval
- Support native vector similarity search operations
- Provide configuration template for easy setup
- Add comprehensive unit tests
- Align with the current Llama Stack provider architecture, including KVStore and OpenAIVectorStoreMixin

The implementation allows Llama Stack users to leverage OpenGauss as an
enterprise-grade vector database for RAG applications.
qifengle 2025-07-14 16:50:29 +08:00
parent eb07a0f86a
commit 35a0a6cb7b
14 changed files with 802 additions and 15 deletions


@ -27,7 +27,7 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant", "remote::opengauss"]
        python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
      fail-fast: false # we want to run all tests regardless of failure
@ -89,6 +89,30 @@ jobs:
          PGPASSWORD=llamastack psql -h localhost -U llamastack -d llamastack \
            -c "CREATE EXTENSION IF NOT EXISTS vector;"
      - name: Start OpenGauss DB
        if: matrix.vector-io-provider == 'remote::opengauss'
        run: |
          docker run -d \
            --name opengauss \
            -e GS_PASSWORD=Enmo@123 \
            -e GS_DB=llamastack \
            -e GS_USER=llamastack \
            -p 5432:5432 \
            enmotech/opengauss:latest
      - name: Wait for OpenGauss to be ready
        if: matrix.vector-io-provider == 'remote::opengauss'
        run: |
          echo "Waiting for OpenGauss to be ready..."
          for i in {1..30}; do
            if docker exec opengauss gsql -d llamastack -U llamastack -W Enmo@123 -c "SELECT version();" > /dev/null 2>&1; then
              echo "OpenGauss is ready!"
              exit 0
            fi
            echo "Not ready yet... ($i)"
            sleep 2
          done
          # Fail the step instead of falling through when the database never comes up.
          echo "OpenGauss did not become ready in time" >&2
          exit 1
      - name: Setup Qdrant
        if: matrix.vector-io-provider == 'remote::qdrant'
        run: |
@ -166,6 +190,12 @@ jobs:
          QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }}
          ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }}
          WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
          ENABLE_OPENGAUSS: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'true' || '' }}
          OPENGAUSS_HOST: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'localhost' || '' }}
          OPENGAUSS_PORT: ${{ matrix.vector-io-provider == 'remote::opengauss' && '5432' || '' }}
          OPENGAUSS_DB: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'llamastack' || '' }}
          OPENGAUSS_USER: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'llamastack' || '' }}
          OPENGAUSS_PASSWORD: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'Enmo@123' || '' }}
        run: |
          uv run --no-sync \
            pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \

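The workflow above waits for the database by polling it a fixed number of times. A minimal Python sketch of the same bounded-retry pattern follows; `wait_until_ready` and its parameters are hypothetical names chosen for illustration, not part of this commit. This variant raises once the attempts are exhausted rather than continuing silently.

```python
import time
from typing import Callable


def wait_until_ready(
    check: Callable[[], bool],
    attempts: int = 30,
    delay_seconds: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,
) -> int:
    """Poll `check` until it returns True; return the attempt number that succeeded."""
    for attempt in range(1, attempts + 1):
        if check():
            return attempt
        sleep(delay_seconds)
    raise TimeoutError(f"service not ready after {attempts} attempts")


# Stand-in probe that succeeds on the third call. A real probe would run a
# query such as "SELECT version();" against the database (assumption).
responses = iter([False, False, True])
succeeded_on = wait_until_ready(lambda: next(responses), sleep=lambda _: None)
```

Passing `sleep` as a parameter keeps the helper testable without real delays.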
docs/source/distributions/k8s-benchmark/benchmark.py Normal file → Executable file

@ -4,12 +4,12 @@
Agents API for creating and interacting with agentic systems.

Main functionalities provided by this API:
- Create agents with specific instructions and ability to use tools.
- Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn".
- Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details).
- Agents can be provided with various shields (see the Safety API for more details).
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.

This section contains documentation for all available providers for the **agents** API.


@ -4,11 +4,11 @@
Protocol for batch processing API operations.

The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.

Note: This API is currently under active development and may undergo changes.

This section contains documentation for all available providers for the **batches** API.


@ -4,9 +4,9 @@
Llama Stack Inference API for generating completions, chat completions, and embeddings.

This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.

This section contains documentation for all available providers for the **inference** API.


@ -18,6 +18,7 @@ inline_sqlite-vec
inline_sqlite_vec
remote_chromadb
remote_milvus
remote_opengauss
remote_pgvector
remote_qdrant
remote_weaviate


@ -0,0 +1,58 @@
# remote::opengauss
## Description
[OpenGauss](https://opengauss.org/en/) is an open-source relational database with native vector support.
This provider connects Llama Stack to a remote OpenGauss server, where vectors are stored in
database tables and retrieved with vector similarity search.
## Features
- Easy to use
- Fully integrated with Llama Stack
## Usage
To use OpenGauss in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use OpenGauss.
3. Start storing and querying vectors.
## Installation
You can pull the OpenGauss image with Docker:
```bash
docker pull opengauss/opengauss:latest
```
## Documentation
See [OpenGauss' documentation](https://docs.opengauss.org/en/docs/5.0.0/docs/GettingStarted/understanding-opengauss.html) for more details about OpenGauss in general.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `host` | `str \| None` | No | localhost | Hostname of the OpenGauss server |
| `port` | `int \| None` | No | 5432 | Port of the OpenGauss server |
| `db` | `str \| None` | No | postgres | Name of the database to connect to |
| `user` | `str \| None` | No | postgres | Database user |
| `password` | `str \| None` | No | mysecretpassword | Password for the database user |
| `kvstore` | `RedisKVStoreConfig \| SqliteKVStoreConfig \| PostgresKVStoreConfig \| MongoDBKVStoreConfig \| None` | No | | Config for KV store backend (SQLite only for now) |
## Sample Configuration
```yaml
host: ${env.OPENGAUSS_HOST:=localhost}
port: ${env.OPENGAUSS_PORT:=5432}
db: ${env.OPENGAUSS_DB}
user: ${env.OPENGAUSS_USER}
password: ${env.OPENGAUSS_PASSWORD}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/opengauss_registry.db
```
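
The sample configuration relies on two placeholder forms: `${env.NAME}` for a required variable and `${env.NAME:=default}` for one with a fallback. The sketch below illustrates how such placeholders could be resolved. `resolve_env_placeholders` is a hypothetical helper written for this example and is not Llama Stack's actual resolver, whose behavior may differ in detail.

```python
import re

# Matches ${env.NAME} and ${env.NAME:=default}. Simplified, hypothetical
# stand-in for the substitution applied to run configs.
_PLACEHOLDER = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*)(?::=([^}]*))?\}")


def resolve_env_placeholders(value: str, env: dict[str, str]) -> str:
    """Replace each placeholder with env[NAME], else its default, else fail."""

    def _substitute(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        if name in env:
            return env[name]
        if default is not None:
            return default
        raise KeyError(f"environment variable {name} is required but not set")

    return _PLACEHOLDER.sub(_substitute, value)


env = {"OPENGAUSS_DB": "llamastack"}
host = resolve_env_placeholders("${env.OPENGAUSS_HOST:=localhost}", env)
db = resolve_env_placeholders("${env.OPENGAUSS_DB}", env)
db_path = resolve_env_placeholders(
    "${env.SQLITE_STORE_DIR:=~/.llama/dummy}/opengauss_registry.db", env
)
```

Under this reading, `db`, `user`, and `password` have no fallback in the sample, so they would need to be set in the environment before the stack starts.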


@ -426,6 +426,44 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
        api_dependencies=[Api.inference],
        optional_api_dependencies=[Api.files],
    ),
    remote_provider_spec(
        Api.vector_io,
        AdapterSpec(
            adapter_type="opengauss",
            pip_packages=["psycopg2-binary"],
            module="llama_stack.providers.remote.vector_io.opengauss",
            config_class="llama_stack.providers.remote.vector_io.opengauss.OpenGaussVectorIOConfig",
            description="""
[OpenGauss](https://opengauss.org/en/) is an open-source relational database with native vector support.
This provider connects Llama Stack to a remote OpenGauss server, where vectors are stored in
database tables and retrieved with vector similarity search.
## Features
- Easy to use
- Fully integrated with Llama Stack
## Usage
To use OpenGauss in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use OpenGauss.
3. Start storing and querying vectors.
## Installation
You can pull the OpenGauss image with Docker:
```bash
docker pull opengauss/opengauss:latest
```
## Documentation
See [OpenGauss' documentation](https://docs.opengauss.org/en/docs/5.0.0/docs/GettingStarted/understanding-opengauss.html) for more details about OpenGauss in general.
""",
        ),
        api_dependencies=[Api.inference],
        # The adapter factory reads deps.get(Api.files), so declare it as optional here.
        optional_api_dependencies=[Api.files],
    ),
    remote_provider_spec(
        Api.vector_io,
        AdapterSpec(


@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.datatypes import Api

from .config import OpenGaussVectorIOConfig


async def get_adapter_impl(config: OpenGaussVectorIOConfig, deps):
    from .opengauss import OpenGaussVectorIOAdapter

    files_api = deps.get(Api.files)
    impl = OpenGaussVectorIOAdapter(config, deps[Api.inference], files_api)
    await impl.initialize()
    return impl


@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.providers.utils.kvstore.config import (
    KVStoreConfig,
    SqliteKVStoreConfig,
)
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class OpenGaussVectorIOConfig(BaseModel):
    host: str | None = Field(default="localhost")
    port: int | None = Field(default=5432)
    db: str | None = Field(default="postgres")
    user: str | None = Field(default="postgres")
    password: str | None = Field(default="mysecretpassword")
    kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)

    @classmethod
    def sample_run_config(
        cls,
        __distro_dir__: str,
        host: str = "${env.OPENGAUSS_HOST:=localhost}",
        port: str = "${env.OPENGAUSS_PORT:=5432}",
        db: str = "${env.OPENGAUSS_DB}",
        user: str = "${env.OPENGAUSS_USER}",
        password: str = "${env.OPENGAUSS_PASSWORD}",
        **kwargs: Any,
    ) -> dict[str, Any]:
        return {
            "host": host,
            "port": port,
            "db": db,
            "user": user,
            "password": password,
            "kvstore": SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=__distro_dir__,
                db_name="opengauss_registry.db",
            ),
        }

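The config class names its database field `db`, while the adapter passes it to `psycopg2.connect()` as `database`. The sketch below shows that mapping in isolation. `ConnectionSettings` and `to_connect_kwargs` are hypothetical stand-ins written for this example (a plain dataclass rather than the real pydantic model), and no database connection is opened.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ConnectionSettings:
    """Hypothetical stand-in mirroring the fields of OpenGaussVectorIOConfig."""

    host: str = "localhost"
    port: int = 5432
    db: str = "postgres"
    user: str = "postgres"
    password: str = "mysecretpassword"


def to_connect_kwargs(settings: ConnectionSettings) -> dict[str, Any]:
    """Build keyword arguments in the shape the adapter passes to psycopg2.connect()."""
    return {
        "host": settings.host,
        "port": settings.port,
        "database": settings.db,  # config field `db` becomes connect keyword `database`
        "user": settings.user,
        "password": settings.password,
    }


kwargs = to_connect_kwargs(ConnectionSettings(db="llamastack", user="llamastack"))
```

Keeping the rename in one place avoids unpacking a dict keyed by `db` straight into the connect call, which would fail because `db` is not a recognized connection keyword.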

@ -0,0 +1,286 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
from typing import Any

import psycopg2
from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import Json, execute_values
from pydantic import BaseModel, TypeAdapter

from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    VectorIO,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
    ChunkForDeletion,
    EmbeddingIndex,
    VectorDBWithIndex,
)

from .config import OpenGaussVectorIOConfig

log = logging.getLogger(__name__)

VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:opengauss:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:opengauss:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:opengauss:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:opengauss:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:opengauss:{VERSION}::"


def upsert_models(conn, keys_models: list[tuple[str, BaseModel]]):
    with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
        query = sql.SQL(
            """
            MERGE INTO metadata_store AS target
            USING (VALUES %s) AS source (key, data)
            ON (target.key = source.key)
            WHEN MATCHED THEN
                UPDATE SET data = source.data
            WHEN NOT MATCHED THEN
                INSERT (key, data) VALUES (source.key, source.data);
            """
        )

        values = [(key, Json(model.model_dump())) for key, model in keys_models]
        execute_values(cur, query, values, template="(%s, %s::JSONB)")


def load_models(cur, cls):
    cur.execute("SELECT key, data FROM metadata_store")
    rows = cur.fetchall()
    return [TypeAdapter(cls).validate_python(row["data"]) for row in rows]


class OpenGaussIndex(EmbeddingIndex):
    def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None):
        self.conn = conn
        with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            # Hyphens are not valid in unquoted SQL identifiers, so map them to underscores.
            sanitized_identifier = vector_db.identifier.replace("-", "_")
            self.table_name = f"vector_store_{sanitized_identifier}"
            self.kvstore = kvstore

            cur.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {self.table_name} (
                    id TEXT PRIMARY KEY,
                    document JSONB,
                    embedding vector({dimension})
                )
            """
            )

    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
        assert len(chunks) == len(embeddings), (
            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
        )

        values = []
        for i, chunk in enumerate(chunks):
            values.append(
                (
                    f"{chunk.chunk_id}",
                    Json(chunk.model_dump()),
                    embeddings[i].tolist(),
                )
            )

        # Upsert: update the row when the chunk id already exists, insert it otherwise.
        query = sql.SQL(
            f"""
            MERGE INTO {self.table_name} AS target
            USING (VALUES %s) AS source (id, document, embedding)
            ON (target.id = source.id)
            WHEN MATCHED THEN
                UPDATE SET document = source.document, embedding = source.embedding
            WHEN NOT MATCHED THEN
                INSERT (id, document, embedding) VALUES (source.id, source.document, source.embedding);
            """
        )
        with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            execute_values(cur, query, values, template="(%s, %s::JSONB, %s::VECTOR)")

    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            cur.execute(
                f"""
                SELECT document, embedding <=> %s::VECTOR AS distance
                FROM {self.table_name}
                ORDER BY distance
                LIMIT %s
            """,
                (embedding.tolist(), k),
            )
            results = cur.fetchall()

            chunks = []
            scores = []
            for doc, dist in results:
                # Convert distance to a score: the smaller the distance, the larger the score.
                score = 1.0 / float(dist) if dist != 0 else float("inf")
                if score < score_threshold:
                    continue
                chunks.append(Chunk(**doc))
                scores.append(score)

            return QueryChunksResponse(chunks=chunks, scores=scores)

    async def query_keyword(
        self,
        query_string: str,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
        raise NotImplementedError("Keyword search is not supported in OpenGauss")

    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        raise NotImplementedError("Hybrid search is not supported in OpenGauss")

    async def delete(self):
        with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

    async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
        """Remove chunks from the OpenGauss table."""
        chunk_ids = [c.chunk_id for c in chunks_for_deletion]
        with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))


class OpenGaussVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
    def __init__(
        self,
        config: OpenGaussVectorIOConfig,
        inference_api: Any,
        files_api: Files | None = None,
    ) -> None:
        self.config = config
        self.inference_api = inference_api
        self.conn = None
        self.cache: dict[str, VectorDBWithIndex] = {}
        self.files_api = files_api
        self.kvstore: KVStore | None = None
        self.vector_db_store = None
        self.openai_vector_store: dict[str, dict[str, Any]] = {}
        self.metadata_collection_name = "openai_vector_stores_metadata"

    async def initialize(self) -> None:
        log.info(f"Initializing OpenGauss memory adapter with config: {self.config}")
        if self.config.kvstore is not None:
            self.kvstore = await kvstore_impl(self.config.kvstore)
        await self.initialize_openai_vector_stores()

        try:
            self.conn = psycopg2.connect(
                host=self.config.host,
                port=self.config.port,
                database=self.config.db,
                user=self.config.user,
                password=self.config.password,
            )
            if self.conn:
                self.conn.autocommit = True
                with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
                    cur.execute("SELECT version();")
                    version = cur.fetchone()[0]
                    log.info(f"OpenGauss server version: {version}")
                    log.info("Assuming native vector support is enabled in this OpenGauss instance.")

                    cur.execute(
                        """
                        CREATE TABLE IF NOT EXISTS metadata_store (
                            key TEXT PRIMARY KEY,
                            data JSONB
                        )
                    """
                    )
        except Exception as e:
            log.exception("Could not connect to OpenGauss database server")
            raise RuntimeError("Could not connect to OpenGauss database server") from e

    async def shutdown(self) -> None:
        if self.conn is not None:
            self.conn.close()
            log.info("Connection to OpenGauss database server closed")

    async def register_vector_db(self, vector_db: VectorDB) -> None:
        assert self.kvstore is not None
        upsert_models(self.conn, [(vector_db.identifier, vector_db)])

        index = VectorDBWithIndex(
            vector_db,
            index=OpenGaussIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore),
            inference_api=self.inference_api,
        )
        self.cache[vector_db.identifier] = index

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        if vector_db_id in self.cache:
            await self.cache[vector_db_id].index.delete()
            del self.cache[vector_db_id]

        assert self.kvstore is not None
        await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        await index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        return await index.query_chunks(query, params)

    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]

        if self.vector_db_store is None:
            raise RuntimeError("Vector DB store not initialized")
        vector_db = self.vector_db_store.get_vector_db(vector_db_id)
        if vector_db is None:
            raise VectorStoreNotFoundError(vector_db_id)

        index = OpenGaussIndex(vector_db, vector_db.embedding_dimension, self.conn)
        self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
        return self.cache[vector_db_id]

    async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
        """Delete chunks from an OpenGauss vector store."""
        index = await self._get_and_cache_vector_db_index(store_id)
        if not index:
            raise VectorStoreNotFoundError(store_id)

        await index.index.delete_chunks(chunks_for_deletion)

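`query_vector` turns each returned distance into a score with `1 / distance`, maps a distance of exactly zero to infinity, and drops rows whose score falls below `score_threshold`. The sketch below isolates that conversion so its behavior is easy to inspect. `distances_to_scores` is a hypothetical helper written for this example and works on plain numbers rather than database rows.

```python
import math


def distances_to_scores(distances: list[float], score_threshold: float) -> list[tuple[int, float]]:
    """Return (row index, score) pairs for rows whose score meets the threshold."""
    kept = []
    for index, dist in enumerate(distances):
        # Same rule as the adapter: reciprocal distance, infinity for an exact match.
        score = 1.0 / float(dist) if dist != 0 else math.inf
        if score < score_threshold:
            continue
        kept.append((index, score))
    return kept


# Distances are assumed to arrive sorted ascending, as ORDER BY distance would give.
ranked = distances_to_scores([0.0, 0.25, 2.0], score_threshold=1.0)
```

With these inputs the last row is filtered out, since a distance of 2.0 yields a score of 0.5. The scores are unbounded above, so a threshold tuned for a provider that returns scores in a fixed range may not transfer directly.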

@ -27,6 +27,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
"inline::milvus",
"inline::chromadb",
"remote::pgvector",
"remote::opengauss",
"remote::chromadb",
"remote::qdrant",
"inline::qdrant",
@ -48,6 +49,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
"inline::chromadb",
"inline::qdrant",
"remote::pgvector",
"remote::opengauss",
"remote::chromadb",
"remote::weaviate",
"remote::qdrant",


@ -4,7 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import random
from unittest.mock import AsyncMock

import numpy as np
import pytest
@ -22,6 +24,8 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConf
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.remote.vector_io.opengauss.config import OpenGaussVectorIOConfig
from llama_stack.providers.remote.vector_io.opengauss.opengauss import OpenGaussIndex, OpenGaussVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter

EMBEDDING_DIMENSION = 384
@ -29,7 +33,7 @@ COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"


@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "opengauss"])
def vector_provider(request):
    return request.param
@ -333,6 +337,92 @@ async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
    await index.delete()


@pytest.fixture
def opengauss_vec_db_path():
    # Keys must be valid psycopg2.connect() keywords: "database" (or "dbname"), not "db".
    return {
        "host": "localhost",
        "port": 5432,
        "database": "test_db",
        "user": "test_user",
        "password": "test_password",
    }


@pytest.fixture
async def opengauss_vec_index(embedding_dimension, opengauss_vec_db_path):
    from unittest.mock import MagicMock

    # MagicMock rather than AsyncMock: OpenGaussIndex uses `with conn.cursor() as cur`,
    # and calling an attribute of an AsyncMock returns a coroutine, not a context manager.
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor

    vector_db = VectorDB(
        identifier=f"test_opengauss_db_{np.random.randint(1e6)}",
        provider_id="opengauss",
        embedding_model="test_model",
        embedding_dimension=embedding_dimension,
    )

    if all(
        os.getenv(var)
        for var in ["OPENGAUSS_HOST", "OPENGAUSS_PORT", "OPENGAUSS_DB", "OPENGAUSS_USER", "OPENGAUSS_PASSWORD"]
    ):
        import psycopg2

        real_conn = psycopg2.connect(**opengauss_vec_db_path)
        real_conn.autocommit = True
        index = OpenGaussIndex(vector_db, embedding_dimension, real_conn)
        yield index
        await index.delete()
        real_conn.close()
    else:
        index = OpenGaussIndex(vector_db, embedding_dimension, mock_conn)
        yield index


@pytest.fixture
async def opengauss_vec_adapter(mock_inference_api, embedding_dimension, tmp_path_factory):
    temp_dir = tmp_path_factory.getbasetemp()
    kv_db_path = str(temp_dir / f"opengauss_kv_{np.random.randint(1e6)}.db")

    config = OpenGaussVectorIOConfig(
        host=os.getenv("OPENGAUSS_HOST", "localhost"),
        port=int(os.getenv("OPENGAUSS_PORT", "5432")),
        db=os.getenv("OPENGAUSS_DB", "test_db"),
        user=os.getenv("OPENGAUSS_USER", "test_user"),
        password=os.getenv("OPENGAUSS_PASSWORD", "test_password"),
        kvstore=SqliteKVStoreConfig(db_path=kv_db_path),
    )

    if all(
        os.getenv(var)
        for var in ["OPENGAUSS_HOST", "OPENGAUSS_PORT", "OPENGAUSS_DB", "OPENGAUSS_USER", "OPENGAUSS_PASSWORD"]
    ):
        adapter = OpenGaussVectorIOAdapter(config, mock_inference_api)
        await adapter.initialize()

        collection_id = f"opengauss_test_collection_{np.random.randint(1e6)}"
        await adapter.register_vector_db(
            VectorDB(
                identifier=collection_id,
                provider_id="opengauss",
                embedding_model="test_model",
                embedding_dimension=embedding_dimension,
            )
        )
        adapter.test_collection_id = collection_id

        yield adapter

        try:
            await adapter.unregister_vector_db(collection_id)
        except Exception:
            pass
        await adapter.shutdown()

        if os.path.exists(kv_db_path):
            os.remove(kv_db_path)
    else:
        pytest.skip("OpenGauss connection not available for integration testing")


@pytest.fixture
def vector_io_adapter(vector_provider, request):
    """Returns the appropriate vector IO adapter based on the provider parameter."""
@ -342,6 +432,7 @@ def vector_io_adapter(vector_provider, request):
        "sqlite_vec": "sqlite_vec_adapter",
        "chroma": "chroma_vec_adapter",
        "qdrant": "qdrant_vec_adapter",
        "opengauss": "opengauss_vec_adapter",
    }
    return request.getfixturevalue(vector_provider_dict[vector_provider])

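The index fixture falls back to a mocked connection when no database is configured, and the index code opens cursors with `with conn.cursor(...) as cur`. The sketch below shows one way to mock that synchronous context-manager pattern with `unittest.mock.MagicMock`. `run_query` is a hypothetical function standing in for the code under test. `MagicMock` is used because calling an attribute of an `AsyncMock` returns a coroutine, which a plain `with` statement cannot enter.

```python
from unittest.mock import MagicMock


def run_query(conn) -> None:
    """Hypothetical code under test: opens a cursor the way the index does."""
    with conn.cursor() as cur:
        cur.execute("SELECT 1")


mock_cursor = MagicMock()
mock_conn = MagicMock()
# conn.cursor() returns a context manager whose __enter__ yields our cursor.
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor

run_query(mock_conn)

# Capture what the code under test actually received and executed.
with mock_conn.cursor() as entered_cursor:
    pass
executed_sql = mock_cursor.execute.call_args.args[0]
```

A mock like this only verifies that statements were issued. It says nothing about whether the SQL is valid for OpenGauss, which is what the environment-gated tests against a real server cover.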

@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import random
from unittest.mock import AsyncMock

import numpy as np
import pytest

from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.remote.vector_io.opengauss.config import (
    OpenGaussVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.opengauss.opengauss import (
    OpenGaussIndex,
    OpenGaussVectorIOAdapter,
)
from llama_stack.providers.utils.kvstore.config import (
    SqliteKVStoreConfig,
)

# Skip all tests in this file if the required environment variables are not set.
pytestmark = pytest.mark.skipif(
    not all(
        os.getenv(var)
        for var in [
            "OPENGAUSS_HOST",
            "OPENGAUSS_PORT",
            "OPENGAUSS_DB",
            "OPENGAUSS_USER",
            "OPENGAUSS_PASSWORD",
        ]
    ),
    reason="OpenGauss connection environment variables not set",
)


@pytest.fixture(scope="session")
def embedding_dimension() -> int:
    return 128


@pytest.fixture
def sample_chunks():
    """Provides a list of sample chunks for testing."""
    return [
        Chunk(
            content="The sky is blue.",
            metadata={"document_id": "doc1", "topic": "nature"},
        ),
        Chunk(
            content="An apple a day keeps the doctor away.",
            metadata={"document_id": "doc2", "topic": "health"},
        ),
        Chunk(
            content="Quantum computing is a new frontier.",
            metadata={"document_id": "doc3", "topic": "technology"},
        ),
    ]


@pytest.fixture
def sample_embeddings(embedding_dimension, sample_chunks):
    """Provides a deterministic set of embeddings for the sample chunks."""
    # Use a fixed seed for reproducibility
    rng = np.random.default_rng(42)
    return rng.random((len(sample_chunks), embedding_dimension), dtype=np.float32)


@pytest.fixture
def mock_inference_api(sample_embeddings):
    """Mocks the inference API to return dummy embeddings."""
    mock_api = AsyncMock(spec=Inference)
    mock_api.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings.tolist()))
    return mock_api


@pytest.fixture
def vector_db(embedding_dimension):
    """Provides a sample VectorDB object for registration."""
    return VectorDB(
        identifier=f"test_db_{random.randint(1, 10000)}",
        embedding_model="test_embedding_model",
        embedding_dimension=embedding_dimension,
        provider_id="opengauss",
    )


@pytest.fixture
async def opengauss_connection():
    """Creates and manages a connection to the OpenGauss database."""
    import psycopg2

    conn = psycopg2.connect(
        host=os.getenv("OPENGAUSS_HOST"),
        port=int(os.getenv("OPENGAUSS_PORT")),
        database=os.getenv("OPENGAUSS_DB"),
        user=os.getenv("OPENGAUSS_USER"),
        password=os.getenv("OPENGAUSS_PASSWORD"),
    )
    conn.autocommit = True
    yield conn
    conn.close()


@pytest.fixture
async def opengauss_index(opengauss_connection, vector_db):
    """Fixture to create and clean up an OpenGaussIndex instance."""
    index = OpenGaussIndex(vector_db, vector_db.embedding_dimension, opengauss_connection)
    yield index
    await index.delete()


@pytest.fixture
async def opengauss_adapter(mock_inference_api):
    """Fixture to set up and tear down the OpenGaussVectorIOAdapter."""
    config = OpenGaussVectorIOConfig(
        host=os.getenv("OPENGAUSS_HOST"),
        port=int(os.getenv("OPENGAUSS_PORT")),
        db=os.getenv("OPENGAUSS_DB"),
        user=os.getenv("OPENGAUSS_USER"),
        password=os.getenv("OPENGAUSS_PASSWORD"),
        # SqliteKVStoreConfig takes db_path, so the cleanup below removes the file it creates.
        kvstore=SqliteKVStoreConfig(db_path="opengauss_test.db"),
    )
    adapter = OpenGaussVectorIOAdapter(config, mock_inference_api)
    await adapter.initialize()
    yield adapter
    if adapter.conn and not adapter.conn.closed:
        for db_id in list(adapter.cache.keys()):
            try:
                await adapter.unregister_vector_db(db_id)
            except Exception as e:
                print(f"Error during cleanup of {db_id}: {e}")
        await adapter.shutdown()
    # Clean up the sqlite db file
    if os.path.exists("opengauss_test.db"):
        os.remove("opengauss_test.db")


class TestOpenGaussIndex:
    async def test_add_and_query_vector(self, opengauss_index, sample_chunks, sample_embeddings):
        """Test adding chunks with embeddings and querying for the most similar one."""
        await opengauss_index.add_chunks(sample_chunks, sample_embeddings)

        # Query with the embedding of the first chunk
        query_embedding = sample_embeddings[0]
        response = await opengauss_index.query_vector(query_embedding, k=1, score_threshold=0.0)

        assert isinstance(response, QueryChunksResponse)
        assert len(response.chunks) == 1
        assert response.chunks[0].content == sample_chunks[0].content
        # The distance to itself should be 0, resulting in infinite score
        assert response.scores[0] == float("inf")


class TestOpenGaussVectorIOAdapter:
    async def test_initialization(self, opengauss_adapter):
        """Test that the adapter initializes and connects to the database."""
        assert opengauss_adapter.conn is not None
        assert not opengauss_adapter.conn.closed

    async def test_register_and_unregister_vector_db(self, opengauss_adapter, vector_db):
        """Test the registration and unregistration of a vector database."""
        await opengauss_adapter.register_vector_db(vector_db)
        assert vector_db.identifier in opengauss_adapter.cache

        table_name = opengauss_adapter.cache[vector_db.identifier].index.table_name
        with opengauss_adapter.conn.cursor() as cur:
            cur.execute(
                "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND tablename = %s);",
                (table_name,),
            )
            assert cur.fetchone()[0]

        await opengauss_adapter.unregister_vector_db(vector_db.identifier)
        assert vector_db.identifier not in opengauss_adapter.cache

        with opengauss_adapter.conn.cursor() as cur:
            cur.execute(
                "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND tablename = %s);",
                (table_name,),
            )
            assert not cur.fetchone()[0]

    async def test_adapter_end_to_end_query(self, opengauss_adapter, vector_db, sample_chunks):
        """
        Tests the full adapter flow: text query -> embedding generation -> vector search.
        """
        # 1. Register the DB and insert chunks. The adapter will use the mocked
        #    inference_api to generate embeddings for these chunks.
        await opengauss_adapter.register_vector_db(vector_db)
        await opengauss_adapter.insert_chunks(vector_db.identifier, sample_chunks)

        # 2. The user query is a text string.
        query_text = "What is the color of the sky?"

        # 3. The adapter will now internally call the (mocked) inference_api
        #    to get an embedding for the query_text.
        response = await opengauss_adapter.query_chunks(vector_db.identifier, query_text)

        # 4. Assertions
        assert isinstance(response, QueryChunksResponse)
        assert len(response.chunks) > 0

        # The mocked inference_api returns fixed, seeded embeddings that are unrelated
        # to the text, so we can't know from the query which chunk is "closest".
        # With a real embedding model this assertion could be more specific.
        # Here we only confirm that the flow completes and returns stored data.
        assert response.chunks[0].content in [c.content for c in sample_chunks]
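
The index test expects a vector queried against itself to have distance 0 and therefore an infinite score. The sketch below computes cosine distance in plain Python to show why an exact comparison can be fragile. It assumes the `<=>` operator denotes cosine distance, as it does in pgvector. The source does not confirm whether OpenGauss returns exactly 0 for identical vectors, and `cosine_distance` is a hypothetical helper written for this example.

```python
import math


def cosine_distance(a: list[float], b: list[float]) -> float:
    """Return 1 - cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


vector = [0.1, 0.2, 0.3]
self_distance = cosine_distance(vector, vector)
orthogonal_distance = cosine_distance([1.0, 0.0], [0.0, 1.0])

# Comparing with a tolerance avoids depending on an exact 0.0 from floating-point math.
is_effectively_zero = math.isclose(self_distance, 0.0, abs_tol=1e-9)
```

If the server ever returns a tiny non-zero distance for an identical vector, `1 / distance` becomes a very large finite number instead of infinity, so asserting that the score exceeds a large bound may be more robust than asserting equality with `float("inf")`.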