mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
feat(vector-io): add OpenGauss vector database provider
Implement OpenGauss vector database integration for Llama Stack with the following features: - Add OpenGaussVectorIOAdapter for vector storage and retrieval - Support native vector similarity search operations - Provide configuration template for easy setup - Add comprehensive unit tests - Align with the latest Llama Stack provider architecture, including KVStore and OpenAI Vector Store Mixin. The implementation allows Llama Stack users to leverage OpenGauss as an enterprise-grade vector database for RAG applications.
This commit is contained in:
parent
803114180b
commit
3e65c70b2d
10 changed files with 790 additions and 2 deletions
|
@ -24,7 +24,7 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"]
|
||||
vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant", "remote::opengauss"]
|
||||
python-version: ["3.12", "3.13"]
|
||||
fail-fast: false # we want to run all tests regardless of failure
|
||||
|
||||
|
@ -86,6 +86,30 @@ jobs:
|
|||
PGPASSWORD=llamastack psql -h localhost -U llamastack -d llamastack \
|
||||
-c "CREATE EXTENSION IF NOT EXISTS vector;"
|
||||
|
||||
- name: Start OpenGauss DB
|
||||
if: matrix.vector-io-provider == 'remote::opengauss'
|
||||
run: |
|
||||
docker run -d \
|
||||
--name opengauss \
|
||||
-e GS_PASSWORD=Enmo@123 \
|
||||
-e GS_DB=llamastack \
|
||||
-e GS_USER=llamastack \
|
||||
-p 5432:5432 \
|
||||
enmotech/opengauss:latest
|
||||
|
||||
- name: Wait for OpenGauss to be ready
|
||||
if: matrix.vector-io-provider == 'remote::opengauss'
|
||||
run: |
|
||||
echo "Waiting for OpenGauss to be ready..."
|
||||
for i in {1..30}; do
|
||||
if docker exec opengauss gsql -d llamastack -U llamastack -W Enmo@123 -c "SELECT version();" > /dev/null 2>&1; then
|
||||
echo "OpenGauss is ready!"
|
||||
break
|
||||
fi
|
||||
echo "Not ready yet... ($i)"
|
||||
sleep 2
|
||||
done
|
||||
|
||||
- name: Setup Qdrant
|
||||
if: matrix.vector-io-provider == 'remote::qdrant'
|
||||
run: |
|
||||
|
@ -163,6 +187,12 @@ jobs:
|
|||
QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }}
|
||||
ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }}
|
||||
WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
|
||||
ENABLE_OPENGAUSS: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'true' || '' }}
|
||||
OPENGAUSS_HOST: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'localhost' || '' }}
|
||||
OPENGAUSS_PORT: ${{ matrix.vector-io-provider == 'remote::opengauss' && '5432' || '' }}
|
||||
OPENGAUSS_DB: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'llamastack' || '' }}
|
||||
OPENGAUSS_USER: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'llamastack' || '' }}
|
||||
OPENGAUSS_PASSWORD: ${{ matrix.vector-io-provider == 'remote::opengauss' && 'Enmo@123' || '' }}
|
||||
run: |
|
||||
uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
|
||||
tests/integration/vector_io \
|
||||
|
|
|
@ -18,6 +18,7 @@ inline_sqlite-vec
|
|||
inline_sqlite_vec
|
||||
remote_chromadb
|
||||
remote_milvus
|
||||
remote_opengauss
|
||||
remote_pgvector
|
||||
remote_qdrant
|
||||
remote_weaviate
|
||||
|
|
58
docs/source/providers/vector_io/remote_opengauss.md
Normal file
58
docs/source/providers/vector_io/remote_opengauss.md
Normal file
|
@ -0,0 +1,58 @@
|
|||
# remote::opengauss
|
||||
|
||||
## Description
|
||||
|
||||
|
||||
[OpenGauss](https://opengauss.org/en/) is a remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly in memory.
|
||||
That means you'll get fast and efficient vector retrieval.
|
||||
|
||||
## Features
|
||||
|
||||
- Easy to use
|
||||
- Fully integrated with Llama Stack
|
||||
|
||||
## Usage
|
||||
|
||||
To use OpenGauss in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Install the necessary dependencies.
|
||||
2. Configure your Llama Stack project to use OpenGauss.
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
||||
You can install OpenGauss using docker:
|
||||
|
||||
```bash
|
||||
docker pull opengauss/opengauss:latest
|
||||
```
|
||||
## Documentation
|
||||
See [OpenGauss' documentation](https://docs.opengauss.org/en/docs/5.0.0/docs/GettingStarted/understanding-opengauss.html) for more details about OpenGauss in general.
|
||||
|
||||
|
||||
## Configuration
|
||||
|
||||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `host` | `str \| None` | No | localhost | |
|
||||
| `port` | `int \| None` | No | 5432 | |
|
||||
| `db` | `str \| None` | No | postgres | |
|
||||
| `user` | `str \| None` | No | postgres | |
|
||||
| `password` | `str \| None` | No | mysecretpassword | |
|
||||
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
host: ${env.OPENGAUSS_HOST:=localhost}
|
||||
port: ${env.OPENGAUSS_PORT:=5432}
|
||||
db: ${env.OPENGAUSS_DB}
|
||||
user: ${env.OPENGAUSS_USER}
|
||||
password: ${env.OPENGAUSS_PASSWORD}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/opengauss_registry.db
|
||||
|
||||
```
|
||||
|
|
@ -424,6 +424,44 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
|
|||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="opengauss",
|
||||
pip_packages=["psycopg2-binary"],
|
||||
module="llama_stack.providers.remote.vector_io.opengauss",
|
||||
config_class="llama_stack.providers.remote.vector_io.opengauss.OpenGaussVectorIOConfig",
|
||||
description="""
|
||||
[OpenGauss](https://opengauss.org/en/) is a remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly in memory.
|
||||
That means you'll get fast and efficient vector retrieval.
|
||||
|
||||
## Features
|
||||
|
||||
- Easy to use
|
||||
- Fully integrated with Llama Stack
|
||||
|
||||
## Usage
|
||||
|
||||
To use OpenGauss in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Install the necessary dependencies.
|
||||
2. Configure your Llama Stack project to use OpenGauss.
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
||||
You can install OpenGauss using docker:
|
||||
|
||||
```bash
|
||||
docker pull opengauss/opengauss:latest
|
||||
```
|
||||
## Documentation
|
||||
See [OpenGauss' documentation](https://docs.opengauss.org/en/docs/5.0.0/docs/GettingStarted/understanding-opengauss.html) for more details about OpenGauss in general.
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
|
|
18
llama_stack/providers/remote/vector_io/opengauss/__init__.py
Normal file
18
llama_stack/providers/remote/vector_io/opengauss/__init__.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import OpenGaussVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: OpenGaussVectorIOConfig, deps):
|
||||
from .opengauss import OpenGaussVectorIOAdapter
|
||||
|
||||
files_api = deps.get(Api.files)
|
||||
impl = OpenGaussVectorIOAdapter(config, deps[Api.inference], files_api)
|
||||
await impl.initialize()
|
||||
return impl
|
48
llama_stack/providers/remote/vector_io/opengauss/config.py
Normal file
48
llama_stack/providers/remote/vector_io/opengauss/config.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenGaussVectorIOConfig(BaseModel):
|
||||
host: str | None = Field(default="localhost")
|
||||
port: int | None = Field(default=5432)
|
||||
db: str | None = Field(default="postgres")
|
||||
user: str | None = Field(default="postgres")
|
||||
password: str | None = Field(default="mysecretpassword")
|
||||
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
__distro_dir__: str,
|
||||
host: str = "${env.OPENGAUSS_HOST:=localhost}",
|
||||
port: str = "${env.OPENGAUSS_PORT:=5432}",
|
||||
db: str = "${env.OPENGAUSS_DB}",
|
||||
user: str = "${env.OPENGAUSS_USER}",
|
||||
password: str = "${env.OPENGAUSS_PASSWORD}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"db": db,
|
||||
"user": user,
|
||||
"password": password,
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="opengauss_registry.db",
|
||||
),
|
||||
}
|
283
llama_stack/providers/remote/vector_io/opengauss/opengauss.py
Normal file
283
llama_stack/providers/remote/vector_io/opengauss/opengauss.py
Normal file
|
@ -0,0 +1,283 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
from numpy.typing import NDArray
|
||||
from psycopg2 import sql
|
||||
from psycopg2.extras import Json, execute_values
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files.files import Files
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
||||
from .config import OpenGaussVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:opengauss:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:opengauss:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:opengauss:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:opengauss:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:opengauss:{VERSION}::"
|
||||
|
||||
|
||||
def upsert_models(conn, keys_models: list[tuple[str, BaseModel]]):
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
query = sql.SQL(
|
||||
"""
|
||||
MERGE INTO metadata_store AS target
|
||||
USING (VALUES %s) AS source (key, data)
|
||||
ON (target.key = source.key)
|
||||
WHEN MATCHED THEN
|
||||
UPDATE SET data = source.data
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT (key, data) VALUES (source.key, source.data);
|
||||
"""
|
||||
)
|
||||
|
||||
values = [(key, Json(model.model_dump())) for key, model in keys_models]
|
||||
execute_values(cur, query, values, template="(%s, %s::JSONB)")
|
||||
|
||||
|
||||
def load_models(cur, cls):
|
||||
cur.execute("SELECT key, data FROM metadata_store")
|
||||
rows = cur.fetchall()
|
||||
return [TypeAdapter(cls).validate_python(row["data"]) for row in rows]
|
||||
|
||||
|
||||
class OpenGaussIndex(EmbeddingIndex):
|
||||
def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None):
|
||||
self.conn = conn
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
sanitized_identifier = vector_db.identifier.replace("-", "_")
|
||||
self.table_name = f"vector_store_{sanitized_identifier}"
|
||||
self.kvstore = kvstore
|
||||
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id TEXT PRIMARY KEY,
|
||||
document JSONB,
|
||||
embedding vector({dimension})
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
)
|
||||
|
||||
values = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
values.append(
|
||||
(
|
||||
f"{chunk.chunk_id}",
|
||||
Json(chunk.model_dump()),
|
||||
embeddings[i].tolist(),
|
||||
)
|
||||
)
|
||||
|
||||
query = sql.SQL(
|
||||
f"""
|
||||
MERGE INTO {self.table_name} AS target
|
||||
USING (VALUES %s) AS source (id, document, embedding)
|
||||
ON (target.id = source.id)
|
||||
WHEN MATCHED THEN
|
||||
UPDATE SET document = source.document, embedding = source.embedding
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT (id, document, embedding) VALUES (source.id, source.document, source.embedding);
|
||||
"""
|
||||
)
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
execute_values(cur, query, values, template="(%s, %s::JSONB, %s::VECTOR)")
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT document, embedding <=> %s::VECTOR AS distance
|
||||
FROM {self.table_name}
|
||||
ORDER BY distance
|
||||
LIMIT %s
|
||||
""",
|
||||
(embedding.tolist(), k),
|
||||
)
|
||||
results = cur.fetchall()
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc, dist in results:
|
||||
score = 1.0 / float(dist) if dist != 0 else float("inf")
|
||||
if score < score_threshold:
|
||||
continue
|
||||
chunks.append(Chunk(**doc))
|
||||
scores.append(score)
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in OpenGauss")
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in OpenGauss")
|
||||
|
||||
async def delete(self):
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,))
|
||||
|
||||
|
||||
class OpenGaussVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: OpenGaussVectorIOConfig,
|
||||
inference_api: Any,
|
||||
files_api: Files | None = None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.conn = None
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.files_api = files_api
|
||||
self.kvstore: KVStore | None = None
|
||||
self.vector_db_store = None
|
||||
self.openai_vector_store: dict[str, dict[str, Any]] = {}
|
||||
self.metadatadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing OpenGauss memory adapter with config: {self.config}")
|
||||
if self.config.kvstore is not None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
||||
try:
|
||||
self.conn = psycopg2.connect(
|
||||
host=self.config.host,
|
||||
port=self.config.port,
|
||||
database=self.config.db,
|
||||
user=self.config.user,
|
||||
password=self.config.password,
|
||||
)
|
||||
if self.conn:
|
||||
self.conn.autocommit = True
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute("SELECT version();")
|
||||
version = cur.fetchone()[0]
|
||||
log.info(f"OpenGauss server version: {version}")
|
||||
log.info("Assuming native vector support is enabled in this OpenGauss instance.")
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS metadata_store (
|
||||
key TEXT PRIMARY KEY,
|
||||
data JSONB
|
||||
)
|
||||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception("Could not connect to OpenGauss database server")
|
||||
raise RuntimeError("Could not connect to OpenGauss database server") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.conn is not None:
|
||||
self.conn.close()
|
||||
log.info("Connection to OpenGauss database server closed")
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
assert self.kvstore is not None
|
||||
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
index=OpenGaussIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id in self.cache:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
assert self.kvstore is not None
|
||||
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise RuntimeError("Vector DB store not initialized")
|
||||
|
||||
vector_db = self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if vector_db is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
index = OpenGaussIndex(vector_db, vector_db.embedding_dimension, self.conn)
|
||||
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
|
||||
|
||||
for chunk_id in chunk_ids:
|
||||
await index.index.delete_chunk(chunk_id)
|
|
@ -26,6 +26,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
|
|||
"inline::milvus",
|
||||
"inline::chromadb",
|
||||
"remote::pgvector",
|
||||
"remote::opengauss",
|
||||
"remote::chromadb",
|
||||
"remote::qdrant",
|
||||
"inline::qdrant",
|
||||
|
@ -47,6 +48,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
|
|||
"inline::chromadb",
|
||||
"inline::qdrant",
|
||||
"remote::pgvector",
|
||||
"remote::opengauss",
|
||||
"remote::chromadb",
|
||||
"remote::weaviate",
|
||||
"remote::qdrant",
|
||||
|
|
|
@ -4,7 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import random
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -22,6 +24,8 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConf
|
|||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.opengauss.config import OpenGaussVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.opengauss.opengauss import OpenGaussIndex, OpenGaussVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
|
||||
EMBEDDING_DIMENSION = 384
|
||||
|
@ -29,7 +33,7 @@ COLLECTION_PREFIX = "test_collection"
|
|||
MILVUS_ALIAS = "test_milvus"
|
||||
|
||||
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "opengauss"])
|
||||
def vector_provider(request):
|
||||
return request.param
|
||||
|
||||
|
@ -333,6 +337,92 @@ async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
|
|||
await index.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def opengauss_vec_db_path():
|
||||
return {
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"db": "test_db",
|
||||
"user": "test_user",
|
||||
"password": "test_password",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def opengauss_vec_index(embedding_dimension, opengauss_vec_db_path):
|
||||
mock_conn = AsyncMock()
|
||||
mock_cursor = AsyncMock()
|
||||
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier=f"test_opengauss_db_{np.random.randint(1e6)}",
|
||||
provider_id="opengauss",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
|
||||
if all(
|
||||
os.getenv(var)
|
||||
for var in ["OPENGAUSS_HOST", "OPENGAUSS_PORT", "OPENGAUSS_DB", "OPENGAUSS_USER", "OPENGAUSS_PASSWORD"]
|
||||
):
|
||||
import psycopg2
|
||||
|
||||
real_conn = psycopg2.connect(**opengauss_vec_db_path)
|
||||
real_conn.autocommit = True
|
||||
index = OpenGaussIndex(vector_db, embedding_dimension, real_conn)
|
||||
yield index
|
||||
await index.delete()
|
||||
real_conn.close()
|
||||
else:
|
||||
index = OpenGaussIndex(vector_db, embedding_dimension, mock_conn)
|
||||
yield index
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def opengauss_vec_adapter(mock_inference_api, embedding_dimension, tmp_path_factory):
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
kv_db_path = str(temp_dir / f"opengauss_kv_{np.random.randint(1e6)}.db")
|
||||
|
||||
config = OpenGaussVectorIOConfig(
|
||||
host=os.getenv("OPENGAUSS_HOST", "localhost"),
|
||||
port=int(os.getenv("OPENGAUSS_PORT", "5432")),
|
||||
db=os.getenv("OPENGAUSS_DB", "test_db"),
|
||||
user=os.getenv("OPENGAUSS_USER", "test_user"),
|
||||
password=os.getenv("OPENGAUSS_PASSWORD", "test_password"),
|
||||
kvstore=SqliteKVStoreConfig(db_path=kv_db_path),
|
||||
)
|
||||
|
||||
if all(
|
||||
os.getenv(var)
|
||||
for var in ["OPENGAUSS_HOST", "OPENGAUSS_PORT", "OPENGAUSS_DB", "OPENGAUSS_USER", "OPENGAUSS_PASSWORD"]
|
||||
):
|
||||
adapter = OpenGaussVectorIOAdapter(config, mock_inference_api)
|
||||
await adapter.initialize()
|
||||
|
||||
collection_id = f"opengauss_test_collection_{np.random.randint(1e6)}"
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=collection_id,
|
||||
provider_id="opengauss",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
adapter.test_collection_id = collection_id
|
||||
yield adapter
|
||||
|
||||
try:
|
||||
await adapter.unregister_vector_db(collection_id)
|
||||
except Exception:
|
||||
pass
|
||||
await adapter.shutdown()
|
||||
|
||||
if os.path.exists(kv_db_path):
|
||||
os.remove(kv_db_path)
|
||||
else:
|
||||
pytest.skip("OpenGauss connection not available for integration testing")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_io_adapter(vector_provider, request):
|
||||
"""Returns the appropriate vector IO adapter based on the provider parameter."""
|
||||
|
@ -342,6 +432,7 @@ def vector_io_adapter(vector_provider, request):
|
|||
"sqlite_vec": "sqlite_vec_adapter",
|
||||
"chroma": "chroma_vec_adapter",
|
||||
"qdrant": "qdrant_vec_adapter",
|
||||
"opengauss": "opengauss_vec_adapter",
|
||||
}
|
||||
return request.getfixturevalue(vector_provider_dict[vector_provider])
|
||||
|
||||
|
|
219
tests/unit/providers/vector_io/test_opengauss.py
Normal file
219
tests/unit/providers/vector_io/test_opengauss.py
Normal file
|
@ -0,0 +1,219 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import random
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.inference import EmbeddingsResponse, Inference
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.remote.vector_io.opengauss.config import (
|
||||
OpenGaussVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.opengauss.opengauss import (
|
||||
OpenGaussIndex,
|
||||
OpenGaussVectorIOAdapter,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
|
||||
# Skip all tests in this file if the required environment variables are not set.
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not all(
|
||||
os.getenv(var)
|
||||
for var in [
|
||||
"OPENGAUSS_HOST",
|
||||
"OPENGAUSS_PORT",
|
||||
"OPENGAUSS_DB",
|
||||
"OPENGAUSS_USER",
|
||||
"OPENGAUSS_PASSWORD",
|
||||
]
|
||||
),
|
||||
reason="OpenGauss connection environment variables not set",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def embedding_dimension() -> int:
|
||||
return 128
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_chunks():
|
||||
"""Provides a list of sample chunks for testing."""
|
||||
return [
|
||||
Chunk(
|
||||
content="The sky is blue.",
|
||||
metadata={"document_id": "doc1", "topic": "nature"},
|
||||
),
|
||||
Chunk(
|
||||
content="An apple a day keeps the doctor away.",
|
||||
metadata={"document_id": "doc2", "topic": "health"},
|
||||
),
|
||||
Chunk(
|
||||
content="Quantum computing is a new frontier.",
|
||||
metadata={"document_id": "doc3", "topic": "technology"},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_embeddings(embedding_dimension, sample_chunks):
|
||||
"""Provides a deterministic set of embeddings for the sample chunks."""
|
||||
# Use a fixed seed for reproducibility
|
||||
rng = np.random.default_rng(42)
|
||||
return rng.random((len(sample_chunks), embedding_dimension), dtype=np.float32)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_inference_api(sample_embeddings):
|
||||
"""Mocks the inference API to return dummy embeddings."""
|
||||
mock_api = AsyncMock(spec=Inference)
|
||||
mock_api.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings.tolist()))
|
||||
return mock_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_db(embedding_dimension):
|
||||
"""Provides a sample VectorDB object for registration."""
|
||||
return VectorDB(
|
||||
identifier=f"test_db_{random.randint(1, 10000)}",
|
||||
embedding_model="test_embedding_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id="opengauss",
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def opengauss_connection():
|
||||
"""Creates and manages a connection to the OpenGauss database."""
|
||||
import psycopg2
|
||||
|
||||
conn = psycopg2.connect(
|
||||
host=os.getenv("OPENGAUSS_HOST"),
|
||||
port=int(os.getenv("OPENGAUSS_PORT")),
|
||||
database=os.getenv("OPENGAUSS_DB"),
|
||||
user=os.getenv("OPENGAUSS_USER"),
|
||||
password=os.getenv("OPENGAUSS_PASSWORD"),
|
||||
)
|
||||
conn.autocommit = True
|
||||
yield conn
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def opengauss_index(opengauss_connection, vector_db):
|
||||
"""Fixture to create and clean up an OpenGaussIndex instance."""
|
||||
index = OpenGaussIndex(vector_db, vector_db.embedding_dimension, opengauss_connection)
|
||||
yield index
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def opengauss_adapter(mock_inference_api):
|
||||
"""Fixture to set up and tear down the OpenGaussVectorIOAdapter."""
|
||||
config = OpenGaussVectorIOConfig(
|
||||
host=os.getenv("OPENGAUSS_HOST"),
|
||||
port=int(os.getenv("OPENGAUSS_PORT")),
|
||||
db=os.getenv("OPENGAUSS_DB"),
|
||||
user=os.getenv("OPENGAUSS_USER"),
|
||||
password=os.getenv("OPENGAUSS_PASSWORD"),
|
||||
kvstore=SqliteKVStoreConfig(db_name="opengauss_test.db"),
|
||||
)
|
||||
adapter = OpenGaussVectorIOAdapter(config, mock_inference_api)
|
||||
await adapter.initialize()
|
||||
yield adapter
|
||||
if adapter.conn and not adapter.conn.closed:
|
||||
for db_id in list(adapter.cache.keys()):
|
||||
try:
|
||||
await adapter.unregister_vector_db(db_id)
|
||||
except Exception as e:
|
||||
print(f"Error during cleanup of {db_id}: {e}")
|
||||
await adapter.shutdown()
|
||||
# Clean up the sqlite db file
|
||||
if os.path.exists("opengauss_test.db"):
|
||||
os.remove("opengauss_test.db")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestOpenGaussIndex:
|
||||
async def test_add_and_query_vector(self, opengauss_index, sample_chunks, sample_embeddings):
|
||||
"""Test adding chunks with embeddings and querying for the most similar one."""
|
||||
await opengauss_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Query with the embedding of the first chunk
|
||||
query_embedding = sample_embeddings[0]
|
||||
response = await opengauss_index.query_vector(query_embedding, k=1, score_threshold=0.0)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 1
|
||||
assert response.chunks[0].content == sample_chunks[0].content
|
||||
# The distance to itself should be 0, resulting in infinite score
|
||||
assert response.scores[0] == float("inf")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestOpenGaussVectorIOAdapter:
|
||||
async def test_initialization(self, opengauss_adapter):
|
||||
"""Test that the adapter initializes and connects to the database."""
|
||||
assert opengauss_adapter.conn is not None
|
||||
assert not opengauss_adapter.conn.closed
|
||||
|
||||
async def test_register_and_unregister_vector_db(self, opengauss_adapter, vector_db):
|
||||
"""Test the registration and unregistration of a vector database."""
|
||||
await opengauss_adapter.register_vector_db(vector_db)
|
||||
assert vector_db.identifier in opengauss_adapter.cache
|
||||
|
||||
table_name = opengauss_adapter.cache[vector_db.identifier].index.table_name
|
||||
with opengauss_adapter.conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND tablename = %s);",
|
||||
(table_name,),
|
||||
)
|
||||
assert cur.fetchone()[0]
|
||||
|
||||
await opengauss_adapter.unregister_vector_db(vector_db.identifier)
|
||||
assert vector_db.identifier not in opengauss_adapter.cache
|
||||
|
||||
with opengauss_adapter.conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND tablename = %s);",
|
||||
(table_name,),
|
||||
)
|
||||
assert not cur.fetchone()[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapter_end_to_end_query(self, opengauss_adapter, vector_db, sample_chunks):
|
||||
"""
|
||||
Tests the full adapter flow: text query -> embedding generation -> vector search.
|
||||
"""
|
||||
# 1. Register the DB and insert chunks. The adapter will use the mocked
|
||||
# inference_api to generate embeddings for these chunks.
|
||||
await opengauss_adapter.register_vector_db(vector_db)
|
||||
await opengauss_adapter.insert_chunks(vector_db.identifier, sample_chunks)
|
||||
|
||||
# 2. The user query is a text string.
|
||||
query_text = "What is the color of the sky?"
|
||||
|
||||
# 3. The adapter will now internally call the (mocked) inference_api
|
||||
# to get an embedding for the query_text.
|
||||
response = await opengauss_adapter.query_chunks(vector_db.identifier, query_text)
|
||||
|
||||
# 4. Assertions
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) > 0
|
||||
|
||||
# Because the mocked inference_api returns random embeddings, we can't
|
||||
# deterministically know which chunk is "closest". However, in a real
|
||||
# integration test with a real model, this assertion would be more specific.
|
||||
# For this unit test, we just confirm that the process completes and returns data.
|
||||
assert response.chunks[0].content in [c.content for c in sample_chunks]
|
Loading…
Add table
Add a link
Reference in a new issue