# What does this PR do?

Implements `openai/v1/vector_stores/{vector_store_id}/search` for the PGVector provider by adding vector similarity search, keyword search, and hybrid search to `PGVectorIndex`.

Closes #3006

## Test Plan

Run unit tests:

```
./scripts/unit-tests.sh
```

Run integration tests for OpenAI vector stores:

1. Export env vars:
   ```
   export ENABLE_PGVECTOR=true
   export PGVECTOR_HOST=localhost
   export PGVECTOR_PORT=5432
   export PGVECTOR_DB=llamastack
   export PGVECTOR_USER=llamastack
   export PGVECTOR_PASSWORD=llamastack
   ```
2. Create the DB:
   ```
   psql -h localhost -U postgres -c "CREATE ROLE llamastack LOGIN PASSWORD 'llamastack';"
   psql -h localhost -U postgres -c "CREATE DATABASE llamastack OWNER llamastack;"
   psql -h localhost -U llamastack -d llamastack -c "CREATE EXTENSION IF NOT EXISTS vector;"
   ```
3. Install sentence-transformers:
   ```
   uv pip install sentence-transformers
   ```
4. Run:
   ```
   uv run --group test pytest -s -v --stack-config="inference=inline::sentence-transformers,vector_io=remote::pgvector" --embedding-model sentence-transformers/all-MiniLM-L6-v2 tests/integration/vector_io/test_openai_vector_stores.py
   ```

Inspect PGVector vector stores (optional):

```
psql llamastack
psql (14.18 (Homebrew))
Type "help" for help.

llamastack=# \z
                                            Access privileges
 Schema |                         Name                         | Type  | Access privileges | Column privileges | Policies
--------+------------------------------------------------------+-------+-------------------+-------------------+----------
 public | llamastack_kvstore                                   | table |                   |                   |
 public | metadata_store                                       | table |                   |                   |
 public | vector_store_pgvector_main                           | table |                   |                   |
 public | vector_store_vs_1dfbc061_1f4d_4497_9165_ecba2622ba3a | table |                   |                   |
 public | vector_store_vs_2085a9fb_1822_4e42_a277_c6a685843fa7 | table |                   |                   |
 public | vector_store_vs_2b3dae46_38be_462a_afd6_37ee5fe661b1 | table |                   |                   |
 public | vector_store_vs_2f438de6_f606_4561_9d50_ef9160eb9060 | table |                   |                   |
 public | vector_store_vs_3eeca564_2580_4c68_bfea_83dc57e31214 | table |                   |                   |
 public | vector_store_vs_53942163_05f3_40e0_83c0_0997c64613da | table |                   |                   |
 public | vector_store_vs_545bac75_8950_4ff1_b084_e221192d4709 | table |                   |                   |
 public | vector_store_vs_688a37d8_35b2_4298_a035_bfedf5b21f86 | table |                   |                   |
 public | vector_store_vs_70624d9a_f6ac_4c42_b8ab_0649473c6600 | table |                   |                   |
 public | vector_store_vs_73fc1dd2_e942_4972_afb1_1e177b591ac2 | table |                   |                   |
 public | vector_store_vs_9d464949_d51f_49db_9f87_e033b8b84ac9 | table |                   |                   |
 public | vector_store_vs_a1e4d724_5162_4d6d_a6c0_bdafaf6b76ec | table |                   |                   |
 public | vector_store_vs_a328fb1b_1a21_480f_9624_ffaa60fb6672 | table |                   |                   |
 public | vector_store_vs_a8981bf0_2e66_4445_a267_a8fff442db53 | table |                   |                   |
 public | vector_store_vs_ccd4b6a4_1efd_4984_ad03_e7ff8eadb296 | table |                   |                   |
 public | vector_store_vs_cd6420a4_a1fc_4cec_948c_1413a26281c9 | table |                   |                   |
 public | vector_store_vs_cd709284_e5cf_4a88_aba5_dc76a35364bd | table |                   |                   |
 public | vector_store_vs_d7a4548e_fbc1_44d7_b2ec_b664417f2a46 | table |                   |                   |
 public | vector_store_vs_e7f73231_414c_4523_886c_d1174eee836e | table |                   |                   |
 public | vector_store_vs_ffd53588_819f_47e8_bb9d_954af6f7833d | table |                   |                   |
(23 rows)

llamastack=#
```
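For a quick manual smoke test of the new endpoint against a running stack, a minimal sketch (the server address, the store id, and the exact request fields such as `search_mode` are illustrative assumptions, not a verified contract):

```python
# Hypothetical smoke test for the search endpoint this PR implements.
# Assumed: a llama-stack server on localhost:8321 and an existing store id.
import requests

BASE_URL = "http://localhost:8321/v1/openai/v1"  # assumed address and route prefix
VECTOR_STORE_ID = "vs_REPLACE_ME"  # placeholder; use a real store id

for mode in ("vector", "keyword", "hybrid"):  # the three modes added for PGVectorIndex
    resp = requests.post(
        f"{BASE_URL}/vector_stores/{VECTOR_STORE_ID}/search",
        json={"query": "What is llama-stack?", "max_num_results": 3, "search_mode": mode},
    )
    resp.raise_for_status()
    print(mode, [hit.get("score") for hit in resp.json().get("data", [])])
```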
Co-authored-by: Francisco Arceo <arceofrancisco@gmail.com>
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import random
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
from chromadb import PersistentClient
from pymilvus import MilvusClient, connections

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter

EMBEDDING_DIMENSION = 384
COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"

@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
def vector_provider(request):
    return request.param


@pytest.fixture
def vector_db_id() -> str:
    return f"test-vector-db-{random.randint(1, 100)}"


@pytest.fixture(scope="session")
def embedding_dimension() -> int:
    return EMBEDDING_DIMENSION


@pytest.fixture(scope="session")
def sample_chunks():
    """Generates chunks that force multiple batches for a single document to expose ID conflicts."""
    n, k = 10, 3
    sample = [
        Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
        for j in range(k)
        for i in range(n)
    ]
    sample.extend(
        [
            Chunk(
                content=f"Sentence {i} from document {j + k}",
                chunk_metadata=ChunkMetadata(
                    document_id=f"document-{j + k}",
                    chunk_id=f"document-{j}-chunk-{i}",
                    source=f"example source-{j + k}-{i}",
                ),
            )
            for j in range(k)
            for i in range(n)
        ]
    )
    return sample

@pytest.fixture(scope="session")
def sample_chunks_with_metadata():
    """Generates chunks with both `metadata` and `chunk_metadata` populated on every chunk."""
    n, k = 10, 3
    sample = [
        Chunk(
            content=f"Sentence {i} from document {j}",
            metadata={"document_id": f"document-{j}"},
            chunk_metadata=ChunkMetadata(
                document_id=f"document-{j}",
                chunk_id=f"document-{j}-chunk-{i}",
                source=f"example source-{j}-{i}",
            ),
        )
        for j in range(k)
        for i in range(n)
    ]
    return sample

@pytest.fixture(scope="session")
def sample_embeddings(sample_chunks):
    np.random.seed(42)
    return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])


@pytest.fixture(scope="session")
def sample_embeddings_with_metadata(sample_chunks_with_metadata):
    np.random.seed(42)
    return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])


@pytest.fixture(scope="session")
def mock_inference_api(embedding_dimension):
    class MockInferenceAPI:
        async def embed_batch(self, texts: list[str]) -> list[list[float]]:
            # Embeddings are random and unseeded, so tests built on this mock
            # exercise plumbing rather than retrieval quality.
            return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts]

    return MockInferenceAPI()

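# Illustrative sketch (hypothetical helper, not a collected test): sample_chunks
# and sample_embeddings are index-aligned, so a low-level index can be populated
# and queried directly against any of the *_vec_index fixtures below.
async def _example_populate_and_query(sqlite_vec_vec_index, sample_chunks, sample_embeddings):
    await sqlite_vec_vec_index.add_chunks(sample_chunks, sample_embeddings)
    response = await sqlite_vec_vec_index.query_vector(sample_embeddings[0], k=3, score_threshold=-1.0)
    assert len(response.chunks) <= 3
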
@pytest.fixture
async def unique_kvstore_config(tmp_path_factory):
    # Generate a unique filename for this test
    unique_id = f"test_kv_{np.random.randint(1e6)}"
    temp_dir = tmp_path_factory.getbasetemp()
    db_path = str(temp_dir / f"{unique_id}.db")

    return SqliteKVStoreConfig(db_path=db_path)


@pytest.fixture(scope="session")
def sqlite_vec_db_path(tmp_path_factory):
    db_path = str(tmp_path_factory.getbasetemp() / "test_sqlite_vec.db")
    return db_path


@pytest.fixture
async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
    temp_dir = tmp_path_factory.getbasetemp()
    db_path = str(temp_dir / f"test_sqlite_vec_{np.random.randint(1e6)}.db")
    bank_id = f"sqlite_vec_bank_{np.random.randint(1e6)}"
    index = SQLiteVecIndex(embedding_dimension, db_path, bank_id)
    await index.initialize()
    index.db_path = db_path
    yield index
    await index.delete()

@pytest.fixture
async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
    config = SQLiteVectorIOConfig(
        db_path=sqlite_vec_db_path,
        kvstore=SqliteKVStoreConfig(),
    )
    adapter = SQLiteVecVectorIOAdapter(
        config=config,
        inference_api=mock_inference_api,
        files_api=None,
    )
    collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
    await adapter.initialize()
    await adapter.register_vector_db(
        VectorDB(
            identifier=collection_id,
            provider_id="test_provider",
            embedding_model="test_model",
            embedding_dimension=embedding_dimension,
        )
    )
    adapter.test_collection_id = collection_id
    yield adapter
    await adapter.shutdown()


@pytest.fixture(scope="session")
def milvus_vec_db_path(tmp_path_factory):
    db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
    return db_path


@pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
    client = MilvusClient(milvus_vec_db_path)
    name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
    connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
    index = MilvusIndex(client, name, consistency_level="Strong")
    index.db_path = milvus_vec_db_path
    yield index


@pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
    config = MilvusVectorIOConfig(
        db_path=milvus_vec_db_path,
        kvstore=SqliteKVStoreConfig(),
    )
    adapter = MilvusVectorIOAdapter(
        config=config,
        inference_api=mock_inference_api,
        files_api=None,
    )
    await adapter.initialize()
    await adapter.register_vector_db(
        VectorDB(
            identifier=adapter.metadata_collection_name,
            provider_id="test_provider",
            embedding_model="test_model",
            embedding_dimension=128,
        )
    )
    yield adapter
    await adapter.shutdown()

@pytest.fixture
def faiss_vec_db_path(tmp_path_factory):
    db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
    return db_path


@pytest.fixture
async def faiss_vec_index(embedding_dimension):
    index = FaissIndex(embedding_dimension)
    yield index
    await index.delete()


@pytest.fixture
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
    config = FaissVectorIOConfig(
        kvstore=unique_kvstore_config,
    )
    adapter = FaissVectorIOAdapter(
        config=config,
        inference_api=mock_inference_api,
        files_api=None,
    )
    await adapter.initialize()
    await adapter.register_vector_db(
        VectorDB(
            identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
            provider_id="test_provider",
            embedding_model="test_model",
            embedding_dimension=embedding_dimension,
        )
    )
    yield adapter
    await adapter.shutdown()

@pytest.fixture
def chroma_vec_db_path(tmp_path_factory):
    persist_dir = tmp_path_factory.mktemp(f"chroma_{np.random.randint(1e6)}")
    return str(persist_dir)


@pytest.fixture
async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
    client = PersistentClient(path=chroma_vec_db_path)
    name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
    collection = await maybe_await(client.get_or_create_collection(name))
    index = ChromaIndex(client=client, collection=collection)
    await index.initialize()
    yield index
    await index.delete()


@pytest.fixture
async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
    config = ChromaVectorIOConfig(
        db_path=chroma_vec_db_path,
        kvstore=SqliteKVStoreConfig(),
    )
    adapter = ChromaVectorIOAdapter(
        config=config,
        inference_api=mock_inference_api,
        files_api=None,
    )
    await adapter.initialize()
    await adapter.register_vector_db(
        VectorDB(
            identifier=f"chroma_test_collection_{random.randint(1, 1_000_000)}",
            provider_id="test_provider",
            embedding_model="test_model",
            embedding_dimension=embedding_dimension,
        )
    )
    yield adapter
    await adapter.shutdown()

@pytest.fixture
def qdrant_vec_db_path(tmp_path_factory):
    import uuid

    db_path = str(tmp_path_factory.getbasetemp() / f"test_qdrant_{uuid.uuid4()}.db")
    return db_path


@pytest.fixture
async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension):
    import uuid

    config = QdrantVectorIOConfig(
        db_path=qdrant_vec_db_path,
        kvstore=SqliteKVStoreConfig(),
    )
    adapter = QdrantVectorIOAdapter(
        config=config,
        inference_api=mock_inference_api,
        files_api=None,
    )
    collection_id = f"qdrant_test_collection_{uuid.uuid4()}"
    await adapter.initialize()
    await adapter.register_vector_db(
        VectorDB(
            identifier=collection_id,
            provider_id="test_provider",
            embedding_model="test_model",
            embedding_dimension=embedding_dimension,
        )
    )
    adapter.test_collection_id = collection_id
    yield adapter
    await adapter.shutdown()


@pytest.fixture
async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
    import uuid

    from qdrant_client import AsyncQdrantClient

    from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantIndex

    client = AsyncQdrantClient(path=qdrant_vec_db_path)
    collection_name = f"qdrant_test_collection_{uuid.uuid4()}"
    index = QdrantIndex(client, collection_name)
    yield index
    await index.delete()

@pytest.fixture
def mock_psycopg2_connection():
    connection = MagicMock()
    cursor = MagicMock()

    cursor.__enter__ = MagicMock(return_value=cursor)
    cursor.__exit__ = MagicMock()

    connection.cursor.return_value = cursor

    return connection, cursor

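# Illustrative usage sketch (hypothetical, not a collected test): the MagicMock
# cursor both supplies canned results and records every SQL statement the code
# under test executes, so tests can assert on the exact SQL issued.
def _example_mock_connection_usage(mock_psycopg2_connection):
    connection, cursor = mock_psycopg2_connection
    cursor.fetchall.return_value = []  # script whatever result set is expected

    with connection.cursor() as cur:  # __enter__ hands back the same mock cursor
        cur.execute("SELECT 1")

    cursor.execute.assert_called_once_with("SELECT 1")
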
@pytest.fixture
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
    connection, cursor = mock_psycopg2_connection

    vector_db = VectorDB(
        identifier="test-vector-db",
        embedding_model="test-model",
        embedding_dimension=embedding_dimension,
        provider_id="pgvector",
        provider_resource_id="pgvector:test-vector-db",
    )

    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
            index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
            index._test_chunks = []
            original_add_chunks = index.add_chunks

            async def mock_add_chunks(chunks, embeddings):
                index._test_chunks = list(chunks)
                await original_add_chunks(chunks, embeddings)

            index.add_chunks = mock_add_chunks

            async def mock_query_vector(embedding, k, score_threshold):
                chunks = index._test_chunks[:k] if hasattr(index, "_test_chunks") else []
                scores = [1.0] * len(chunks)
                return QueryChunksResponse(chunks=chunks, scores=scores)

            index.query_vector = mock_query_vector

            yield index

@pytest.fixture
async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
    config = PGVectorVectorIOConfig(
        host="localhost",
        port=5432,
        db="test_db",
        user="test_user",
        password="test_password",
        kvstore=SqliteKVStoreConfig(),
    )

    adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)

    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2.connect") as mock_connect:
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
        mock_cursor.__exit__ = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        mock_conn.autocommit = True
        mock_connect.return_value = mock_conn

        with patch(
            "llama_stack.providers.remote.vector_io.pgvector.pgvector.check_extension_version"
        ) as mock_check_version:
            mock_check_version.return_value = "0.5.1"

            with patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl:
                mock_kvstore = AsyncMock()
                mock_kvstore_impl.return_value = mock_kvstore

                with patch.object(adapter, "initialize_openai_vector_stores", new_callable=AsyncMock):
                    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.upsert_models"):
                        await adapter.initialize()
                        adapter.conn = mock_conn

                        async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
                            index = await adapter._get_and_cache_vector_db_index(vector_db_id)
                            if not index:
                                raise ValueError(f"Vector DB {vector_db_id} not found")
                            await index.insert_chunks(chunks)

                        adapter.insert_chunks = mock_insert_chunks

                        async def mock_query_chunks(vector_db_id, query, params=None):
                            index = await adapter._get_and_cache_vector_db_index(vector_db_id)
                            if not index:
                                raise ValueError(f"Vector DB {vector_db_id} not found")
                            return await index.query_chunks(query, params)

                        adapter.query_chunks = mock_query_chunks

                        test_vector_db = VectorDB(
                            identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
                            provider_id="test_provider",
                            embedding_model="test_model",
                            embedding_dimension=embedding_dimension,
                        )
                        await adapter.register_vector_db(test_vector_db)
                        adapter.test_collection_id = test_vector_db.identifier

                        yield adapter
                        await adapter.shutdown()

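# Illustrative sketch (hypothetical test): the PR this file supports adds
# vector, keyword, and hybrid search to PGVectorIndex. The "mode" params key
# is an assumption about how query_chunks selects a search mode, and the mock
# inference API must satisfy whatever embedding calls the adapter makes.
async def _example_pgvector_search_modes(pgvector_vec_adapter, sample_chunks):
    await pgvector_vec_adapter.insert_chunks(pgvector_vec_adapter.test_collection_id, sample_chunks)
    for mode in ("vector", "keyword", "hybrid"):
        response = await pgvector_vec_adapter.query_chunks(
            pgvector_vec_adapter.test_collection_id,
            "Sentence 5 from document 0",
            params={"mode": mode},
        )
        assert isinstance(response, QueryChunksResponse)
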
@pytest.fixture
def vector_io_adapter(vector_provider, request):
    # Note: "qdrant" is resolvable here even though it is not in the default
    # vector_provider param list above.
    vector_provider_dict = {
        "milvus": "milvus_vec_adapter",
        "faiss": "faiss_vec_adapter",
        "sqlite_vec": "sqlite_vec_adapter",
        "chroma": "chroma_vec_adapter",
        "qdrant": "qdrant_vec_adapter",
        "pgvector": "pgvector_vec_adapter",
    }
    return request.getfixturevalue(vector_provider_dict[vector_provider])


@pytest.fixture
def vector_index(vector_provider, request):
    """Returns the appropriate vector index fixture for the provider parameter."""
    return request.getfixturevalue(f"{vector_provider}_vec_index")
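
# Illustrative sketch (hypothetical test): vector_io_adapter resolves to a
# different provider fixture for each vector_provider param, so one
# provider-agnostic round trip covers every backend. Assumes the resolved
# fixture exposes test_collection_id (the sqlite_vec, qdrant, and pgvector
# fixtures above do; the others register randomly named collections).
async def _example_insert_then_query(vector_io_adapter, sample_chunks):
    await vector_io_adapter.insert_chunks(vector_io_adapter.test_collection_id, sample_chunks)
    response = await vector_io_adapter.query_chunks(vector_io_adapter.test_collection_id, "Sentence 3")
    assert len(response.chunks) > 0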