mirror of https://github.com/meta-llama/llama-stack.git
feat(vector-io): add OpenGauss vector database provider
Implement OpenGauss vector database integration for Llama Stack with the following features:

- Add OpenGaussVectorIOAdapter for vector storage and retrieval
- Support native vector similarity search operations
- Provide a configuration template for easy setup
- Add comprehensive unit tests
- Align with the latest Llama Stack provider architecture, including the KVStore and OpenAI Vector Store Mixin

The implementation allows Llama Stack users to leverage OpenGauss as an enterprise-grade vector database for RAG applications.
Parent: eb07a0f86a · Commit: 35a0a6cb7b
14 changed files with 802 additions and 15 deletions
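For orientation, below is a minimal sketch of how the new provider is driven, pieced together from the config fields and adapter calls that appear in the diffs that follow. The host, credentials, model name, and identifiers are placeholder values; this snippet is illustrative and is not part of the commit.

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.remote.vector_io.opengauss.config import OpenGaussVectorIOConfig
from llama_stack.providers.remote.vector_io.opengauss.opengauss import OpenGaussVectorIOAdapter
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


async def run_opengauss_example(inference_api):
    # Connection settings mirror the fields exercised by the tests below; adjust for your deployment.
    config = OpenGaussVectorIOConfig(
        host="localhost",
        port=5432,
        db="rag_db",
        user="gauss",
        password="secret",
        kvstore=SqliteKVStoreConfig(db_path="./opengauss_registry.db"),
    )
    adapter = OpenGaussVectorIOAdapter(config, inference_api)
    await adapter.initialize()

    # Register a vector store, then insert and query chunks through the adapter.
    await adapter.register_vector_db(
        VectorDB(
            identifier="docs",
            provider_id="opengauss",
            embedding_model="all-MiniLM-L6-v2",
            embedding_dimension=384,
        )
    )
    # chunks = [...]  # llama_stack.apis.vector_io.Chunk objects
    # await adapter.insert_chunks("docs", chunks)
    # response = await adapter.query_chunks("docs", "What is the color of the sky?")
    await adapter.shutdown()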
@@ -4,7 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import random
from unittest.mock import AsyncMock

import numpy as np
import pytest
@@ -22,6 +24,8 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConf
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.remote.vector_io.opengauss.config import OpenGaussVectorIOConfig
from llama_stack.providers.remote.vector_io.opengauss.opengauss import OpenGaussIndex, OpenGaussVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter

EMBEDDING_DIMENSION = 384
@@ -29,7 +33,7 @@ COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"


-@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
+@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "opengauss"])
def vector_provider(request):
    return request.param

@@ -333,6 +337,92 @@ async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
    await index.delete()


@pytest.fixture
def opengauss_vec_db_path():
    return {
        "host": "localhost",
        "port": 5432,
        "db": "test_db",
        "user": "test_user",
        "password": "test_password",
    }


@pytest.fixture
async def opengauss_vec_index(embedding_dimension, opengauss_vec_db_path):
    mock_conn = AsyncMock()
    mock_cursor = AsyncMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor

    vector_db = VectorDB(
        identifier=f"test_opengauss_db_{np.random.randint(1e6)}",
        provider_id="opengauss",
        embedding_model="test_model",
        embedding_dimension=embedding_dimension,
    )

    if all(
        os.getenv(var)
        for var in ["OPENGAUSS_HOST", "OPENGAUSS_PORT", "OPENGAUSS_DB", "OPENGAUSS_USER", "OPENGAUSS_PASSWORD"]
    ):
        import psycopg2

        real_conn = psycopg2.connect(**opengauss_vec_db_path)
        real_conn.autocommit = True
        index = OpenGaussIndex(vector_db, embedding_dimension, real_conn)
        yield index
        await index.delete()
        real_conn.close()
    else:
        index = OpenGaussIndex(vector_db, embedding_dimension, mock_conn)
        yield index


@pytest.fixture
async def opengauss_vec_adapter(mock_inference_api, embedding_dimension, tmp_path_factory):
    temp_dir = tmp_path_factory.getbasetemp()
    kv_db_path = str(temp_dir / f"opengauss_kv_{np.random.randint(1e6)}.db")

    config = OpenGaussVectorIOConfig(
        host=os.getenv("OPENGAUSS_HOST", "localhost"),
        port=int(os.getenv("OPENGAUSS_PORT", "5432")),
        db=os.getenv("OPENGAUSS_DB", "test_db"),
        user=os.getenv("OPENGAUSS_USER", "test_user"),
        password=os.getenv("OPENGAUSS_PASSWORD", "test_password"),
        kvstore=SqliteKVStoreConfig(db_path=kv_db_path),
    )

    if all(
        os.getenv(var)
        for var in ["OPENGAUSS_HOST", "OPENGAUSS_PORT", "OPENGAUSS_DB", "OPENGAUSS_USER", "OPENGAUSS_PASSWORD"]
    ):
        adapter = OpenGaussVectorIOAdapter(config, mock_inference_api)
        await adapter.initialize()

        collection_id = f"opengauss_test_collection_{np.random.randint(1e6)}"
        await adapter.register_vector_db(
            VectorDB(
                identifier=collection_id,
                provider_id="opengauss",
                embedding_model="test_model",
                embedding_dimension=embedding_dimension,
            )
        )
        adapter.test_collection_id = collection_id
        yield adapter

        try:
            await adapter.unregister_vector_db(collection_id)
        except Exception:
            pass
        await adapter.shutdown()

        if os.path.exists(kv_db_path):
            os.remove(kv_db_path)
    else:
        pytest.skip("OpenGauss connection not available for integration testing")


@pytest.fixture
def vector_io_adapter(vector_provider, request):
    """Returns the appropriate vector IO adapter based on the provider parameter."""
@@ -342,6 +432,7 @@ def vector_io_adapter(vector_provider, request):
        "sqlite_vec": "sqlite_vec_adapter",
        "chroma": "chroma_vec_adapter",
        "qdrant": "qdrant_vec_adapter",
        "opengauss": "opengauss_vec_adapter",
    }
    return request.getfixturevalue(vector_provider_dict[vector_provider])

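Because "opengauss" is now one of the vector_provider params and maps to the opengauss_vec_adapter fixture, every test written against the shared vector_io_adapter fixture above is also collected for the OpenGauss backend. A hypothetical illustration (the test name and body below are not part of the commit):

# Hypothetical: with vector_provider == "opengauss", the shared fixture resolves to
# opengauss_vec_adapter, so a provider-agnostic test like this runs against OpenGauss too.
async def test_adapter_smoke(vector_io_adapter):
    # The fixture has already called initialize() and registered a test collection.
    assert vector_io_adapter is not None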
tests/unit/providers/vector_io/test_opengauss.py (new file, 215 lines)
@@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import random
from unittest.mock import AsyncMock

import numpy as np
import pytest

from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.remote.vector_io.opengauss.config import (
    OpenGaussVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.opengauss.opengauss import (
    OpenGaussIndex,
    OpenGaussVectorIOAdapter,
)
from llama_stack.providers.utils.kvstore.config import (
    SqliteKVStoreConfig,
)

# Skip all tests in this file if the required environment variables are not set.
pytestmark = pytest.mark.skipif(
    not all(
        os.getenv(var)
        for var in [
            "OPENGAUSS_HOST",
            "OPENGAUSS_PORT",
            "OPENGAUSS_DB",
            "OPENGAUSS_USER",
            "OPENGAUSS_PASSWORD",
        ]
    ),
    reason="OpenGauss connection environment variables not set",
)

@pytest.fixture(scope="session")
def embedding_dimension() -> int:
    return 128


@pytest.fixture
def sample_chunks():
    """Provides a list of sample chunks for testing."""
    return [
        Chunk(
            content="The sky is blue.",
            metadata={"document_id": "doc1", "topic": "nature"},
        ),
        Chunk(
            content="An apple a day keeps the doctor away.",
            metadata={"document_id": "doc2", "topic": "health"},
        ),
        Chunk(
            content="Quantum computing is a new frontier.",
            metadata={"document_id": "doc3", "topic": "technology"},
        ),
    ]


@pytest.fixture
def sample_embeddings(embedding_dimension, sample_chunks):
    """Provides a deterministic set of embeddings for the sample chunks."""
    # Use a fixed seed for reproducibility
    rng = np.random.default_rng(42)
    return rng.random((len(sample_chunks), embedding_dimension), dtype=np.float32)

@pytest.fixture
def mock_inference_api(sample_embeddings):
    """Mocks the inference API to return dummy embeddings."""
    mock_api = AsyncMock(spec=Inference)
    mock_api.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings.tolist()))
    return mock_api


@pytest.fixture
def vector_db(embedding_dimension):
    """Provides a sample VectorDB object for registration."""
    return VectorDB(
        identifier=f"test_db_{random.randint(1, 10000)}",
        embedding_model="test_embedding_model",
        embedding_dimension=embedding_dimension,
        provider_id="opengauss",
    )

@pytest.fixture
async def opengauss_connection():
    """Creates and manages a connection to the OpenGauss database."""
    import psycopg2

    conn = psycopg2.connect(
        host=os.getenv("OPENGAUSS_HOST"),
        port=int(os.getenv("OPENGAUSS_PORT")),
        database=os.getenv("OPENGAUSS_DB"),
        user=os.getenv("OPENGAUSS_USER"),
        password=os.getenv("OPENGAUSS_PASSWORD"),
    )
    conn.autocommit = True
    yield conn
    conn.close()


@pytest.fixture
async def opengauss_index(opengauss_connection, vector_db):
    """Fixture to create and clean up an OpenGaussIndex instance."""
    index = OpenGaussIndex(vector_db, vector_db.embedding_dimension, opengauss_connection)
    yield index
    await index.delete()

@pytest.fixture
async def opengauss_adapter(mock_inference_api):
    """Fixture to set up and tear down the OpenGaussVectorIOAdapter."""
    config = OpenGaussVectorIOConfig(
        host=os.getenv("OPENGAUSS_HOST"),
        port=int(os.getenv("OPENGAUSS_PORT")),
        db=os.getenv("OPENGAUSS_DB"),
        user=os.getenv("OPENGAUSS_USER"),
        password=os.getenv("OPENGAUSS_PASSWORD"),
        kvstore=SqliteKVStoreConfig(db_path="opengauss_test.db"),
    )
    adapter = OpenGaussVectorIOAdapter(config, mock_inference_api)
    await adapter.initialize()
    yield adapter
    if adapter.conn and not adapter.conn.closed:
        for db_id in list(adapter.cache.keys()):
            try:
                await adapter.unregister_vector_db(db_id)
            except Exception as e:
                print(f"Error during cleanup of {db_id}: {e}")
        await adapter.shutdown()
    # Clean up the sqlite db file
    if os.path.exists("opengauss_test.db"):
        os.remove("opengauss_test.db")

class TestOpenGaussIndex:
    async def test_add_and_query_vector(self, opengauss_index, sample_chunks, sample_embeddings):
        """Test adding chunks with embeddings and querying for the most similar one."""
        await opengauss_index.add_chunks(sample_chunks, sample_embeddings)

        # Query with the embedding of the first chunk
        query_embedding = sample_embeddings[0]
        response = await opengauss_index.query_vector(query_embedding, k=1, score_threshold=0.0)

        assert isinstance(response, QueryChunksResponse)
        assert len(response.chunks) == 1
        assert response.chunks[0].content == sample_chunks[0].content
        # The distance to itself should be 0, resulting in infinite score
        assert response.scores[0] == float("inf")

class TestOpenGaussVectorIOAdapter:
    async def test_initialization(self, opengauss_adapter):
        """Test that the adapter initializes and connects to the database."""
        assert opengauss_adapter.conn is not None
        assert not opengauss_adapter.conn.closed

    async def test_register_and_unregister_vector_db(self, opengauss_adapter, vector_db):
        """Test the registration and unregistration of a vector database."""
        await opengauss_adapter.register_vector_db(vector_db)
        assert vector_db.identifier in opengauss_adapter.cache

        table_name = opengauss_adapter.cache[vector_db.identifier].index.table_name
        with opengauss_adapter.conn.cursor() as cur:
            cur.execute(
                "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND tablename = %s);",
                (table_name,),
            )
            assert cur.fetchone()[0]

        await opengauss_adapter.unregister_vector_db(vector_db.identifier)
        assert vector_db.identifier not in opengauss_adapter.cache

        with opengauss_adapter.conn.cursor() as cur:
            cur.execute(
                "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND tablename = %s);",
                (table_name,),
            )
            assert not cur.fetchone()[0]

    async def test_adapter_end_to_end_query(self, opengauss_adapter, vector_db, sample_chunks):
        """
        Tests the full adapter flow: text query -> embedding generation -> vector search.
        """
        # 1. Register the DB and insert chunks. The adapter will use the mocked
        #    inference_api to generate embeddings for these chunks.
        await opengauss_adapter.register_vector_db(vector_db)
        await opengauss_adapter.insert_chunks(vector_db.identifier, sample_chunks)

        # 2. The user query is a text string.
        query_text = "What is the color of the sky?"

        # 3. The adapter will now internally call the (mocked) inference_api
        #    to get an embedding for the query_text.
        response = await opengauss_adapter.query_chunks(vector_db.identifier, query_text)

        # 4. Assertions
        assert isinstance(response, QueryChunksResponse)
        assert len(response.chunks) > 0

        # Because the mocked inference_api returns random embeddings, we can't
        # deterministically know which chunk is "closest". However, in a real
        # integration test with a real model, this assertion would be more specific.
        # For this unit test, we just confirm that the process completes and returns data.
        assert response.chunks[0].content in [c.content for c in sample_chunks]
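The opengauss.py implementation itself is among the 14 changed files but is not shown in this excerpt. Judging only from what the tests above exercise (OpenGaussIndex.add_chunks and query_vector, a per-collection table_name, and a score of float("inf") when the query embedding exactly matches a stored vector, which suggests a 1/distance scoring rule), a rough sketch of the query path might look like the following. The SQL, the pgvector-style "<->" operator, and the column names are assumptions, not the committed code.

# Hypothetical sketch only -- NOT the committed OpenGaussIndex implementation.
# Assumes a table with (document, embedding) columns, a pgvector-compatible "<->"
# L2 distance operator, and the 1/distance scoring implied by the float("inf") assertion above.
import json

from llama_stack.apis.vector_io import Chunk, QueryChunksResponse


def query_vector_sketch(conn, table_name, embedding, k, score_threshold):
    with conn.cursor() as cur:
        cur.execute(
            f"SELECT document, embedding <-> %s::vector AS distance "
            f"FROM {table_name} ORDER BY distance ASC LIMIT %s",
            (str(list(embedding)), k),
        )
        rows = cur.fetchall()

    chunks, scores = [], []
    for document, distance in rows:
        score = float("inf") if distance == 0 else 1.0 / float(distance)
        if score < score_threshold:
            continue
        chunks.append(Chunk(**json.loads(document)))
        scores.append(score)
    return QueryChunksResponse(chunks=chunks, scores=scores)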