mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-13 16:46:09 +00:00
# What does this PR do? This PR refactors the VectorIO backend logic for `sqlite-vec` and adds unit tests and fixtures to make it easy to test both `sqlite-vec` and `milvus`. Key changes: - `sqlite-vec` migrated to `kvstore` registry - added in-memory cache for `sqlite-vec` to be consistent with `milvus` - default fixtures moved to `conftest.py` - removed redundant tests from `sqlite-vec` - made `test_vector_io_openai_vector_stores.py` more easily extensible ## Test Plan Unit tests added testing inline providers. --------- Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
213 lines
6.6 KiB
Python
213 lines
6.6 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import random
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from pymilvus import MilvusClient, connections
|
|
|
|
from llama_stack.apis.vector_dbs import VectorDB
|
|
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
|
|
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
|
|
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
|
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
|
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
|
|
|
|
# Length of every embedding vector produced by the fixtures in this module.
EMBEDDING_DIMENSION: int = 384
# Prefix of the per-test collection names built by the Milvus index fixture.
COLLECTION_PREFIX: str = "test_collection"
# Alias passed to pymilvus ``connections.connect`` by the Milvus index fixture.
MILVUS_ALIAS: str = "test_milvus"
|
|
|
|
|
|
@pytest.fixture
def vector_db_id() -> str:
    """Return a vector DB identifier with a random numeric suffix in [1, 100]."""
    suffix = random.randint(1, 100)
    return f"test-vector-db-{suffix}"
|
|
|
|
|
|
@pytest.fixture(scope="session")
def embedding_dimension() -> int:
    """Expose the module-wide embedding size (``EMBEDDING_DIMENSION``) as a fixture."""
    dimension = EMBEDDING_DIMENSION
    return dimension
|
|
|
|
|
|
@pytest.fixture(scope="session")
def sample_chunks():
    """Build the shared chunk corpus used across the vector IO tests.

    The list holds ``k`` documents of ``n`` sentences each that carry only a
    plain ``metadata`` dict, followed by ``k`` more documents of ``n``
    sentences each that carry only a ``ChunkMetadata`` object. Several chunks
    per document are generated so that ID handling is exercised.
    """
    n, k = 10, 3
    sample = []

    # Documents 0..k-1: identified through the free-form metadata dict.
    for j in range(k):
        for i in range(n):
            sample.append(
                Chunk(
                    content=f"Sentence {i} from document {j}",
                    metadata={"document_id": f"document-{j}"},
                )
            )

    # Documents k..2k-1: identified through structured ChunkMetadata.
    # NOTE(review): chunk_id is built from ``j`` while document_id uses
    # ``j + k`` -- kept exactly as in the original; confirm this is intended.
    for j in range(k):
        for i in range(n):
            sample.append(
                Chunk(
                    content=f"Sentence {i} from document {j + k}",
                    chunk_metadata=ChunkMetadata(
                        document_id=f"document-{j + k}",
                        chunk_id=f"document-{j}-chunk-{i}",
                        source=f"example source-{j + k}-{i}",
                    ),
                )
            )

    return sample
|
|
|
|
|
|
@pytest.fixture(scope="session")
def sample_chunks_with_metadata():
    """Build ``k`` documents of ``n`` chunks that carry both metadata forms.

    Every chunk has a plain ``metadata`` dict and a ``ChunkMetadata`` object
    whose ``document_id``, ``chunk_id`` and ``source`` are derived from the
    document and sentence indices.
    """
    n, k = 10, 3
    sample = []
    for j in range(k):
        for i in range(n):
            structured = ChunkMetadata(
                document_id=f"document-{j}",
                chunk_id=f"document-{j}-chunk-{i}",
                source=f"example source-{j}-{i}",
            )
            sample.append(
                Chunk(
                    content=f"Sentence {i} from document {j}",
                    metadata={"document_id": f"document-{j}"},
                    chunk_metadata=structured,
                )
            )
    return sample
|
|
|
|
|
|
@pytest.fixture(scope="session")
def sample_embeddings(sample_chunks):
    """Return one float32 random vector of ``EMBEDDING_DIMENSION`` per sample chunk.

    NumPy's global RNG is reseeded with 42 first, so the values are
    reproducible from run to run.
    """
    np.random.seed(42)
    rows = []
    for _ in sample_chunks:
        rows.append(np.random.rand(EMBEDDING_DIMENSION).astype(np.float32))
    return np.array(rows)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def sample_embeddings_with_metadata(sample_chunks_with_metadata):
    """Return one float32 random vector per chunk of ``sample_chunks_with_metadata``.

    NumPy's global RNG is reseeded with 42 before drawing, making the
    resulting array reproducible.
    """
    np.random.seed(42)
    rows = []
    for _ in sample_chunks_with_metadata:
        rows.append(np.random.rand(EMBEDDING_DIMENSION).astype(np.float32))
    return np.array(rows)
|
|
|
|
|
|
@pytest.fixture(params=["milvus", "sqlite_vec"])
def vector_provider(request):
    """Parametrize dependent tests/fixtures over each provider name."""
    provider_name = request.param
    return provider_name
|
|
|
|
|
|
@pytest.fixture(scope="session")
def mock_inference_api(embedding_dimension):
    """Return a stand-in inference API whose ``embed_batch`` fabricates embeddings.

    Each input text is answered with a random float32 vector of
    ``embedding_dimension`` entries, converted to a plain Python list.
    """

    class MockInferenceAPI:
        # The text contents are ignored; only the number of texts matters.
        async def embed_batch(self, texts: list[str]) -> list[list[float]]:
            vectors = []
            for _ in texts:
                vector = np.random.rand(embedding_dimension).astype(np.float32)
                vectors.append(vector.tolist())
            return vectors

    return MockInferenceAPI()
|
|
|
|
|
|
@pytest.fixture
async def unique_kvstore_config(tmp_path_factory):
    """Return a ``SqliteKVStoreConfig`` pointing at a randomly named DB file.

    The file name is ``test_kv_<random int>.db`` and lives in pytest's base
    temporary directory.
    """
    suffix = np.random.randint(1e6)
    db_file = tmp_path_factory.getbasetemp() / f"test_kv_{suffix}.db"
    return SqliteKVStoreConfig(db_path=str(db_file))
|
|
|
|
|
|
@pytest.fixture(scope="session")
def sqlite_vec_db_path(tmp_path_factory):
    """Return the path (as ``str``) of ``test.db`` inside pytest's base temp directory."""
    base_dir = tmp_path_factory.getbasetemp()
    return str(base_dir / "test.db")
|
|
|
|
|
|
@pytest.fixture
async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
    """Yield an initialized ``SQLiteVecIndex`` backed by a randomly named DB file.

    The database file is created under pytest's base temporary directory and
    the bank id carries a random suffix. The path is also stored on the index
    as ``index.db_path`` for tests that need it. After the test, the index's
    ``delete()`` is invoked to clean up.
    """
    # Local import: only the teardown below needs it.
    import inspect

    temp_dir = tmp_path_factory.getbasetemp()
    db_path = str(temp_dir / f"test_sqlite_vec_{np.random.randint(1e6)}.db")
    bank_id = f"sqlite_vec_bank_{np.random.randint(1e6)}"
    index = SQLiteVecIndex(embedding_dimension, db_path, bank_id)
    await index.initialize()
    index.db_path = db_path
    yield index

    # initialize() is awaited above, so delete() is presumably a coroutine
    # function as well (TODO confirm against SQLiteVecIndex). Calling it
    # without awaiting would only create a coroutine object and never run the
    # cleanup, so await the result whenever it is awaitable. A synchronous
    # delete() has already done its work by this point.
    pending = index.delete()
    if inspect.isawaitable(pending):
        await pending
|
|
|
|
|
|
@pytest.fixture
async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
    """Yield an initialized ``SQLiteVecVectorIOAdapter`` with one registered vector DB.

    The registered DB's identifier (``sqlite_test_collection_<random int>``)
    is stored on the adapter as ``adapter.test_collection_id``. The adapter is
    shut down once the test finishes.
    """
    adapter = SQLiteVecVectorIOAdapter(
        config=SQLiteVectorIOConfig(
            db_path=sqlite_vec_db_path,
            kvstore=SqliteKVStoreConfig(),
        ),
        inference_api=mock_inference_api,
        files_api=None,
    )
    collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
    await adapter.initialize()

    vector_db = VectorDB(
        identifier=collection_id,
        provider_id="test_provider",
        embedding_model="test_model",
        embedding_dimension=embedding_dimension,
    )
    await adapter.register_vector_db(vector_db)
    adapter.test_collection_id = collection_id

    yield adapter
    await adapter.shutdown()
|
|
|
|
|
|
@pytest.fixture(scope="session")
def milvus_vec_db_path(tmp_path_factory):
    """Return the path (as ``str``) of ``test_milvus.db`` inside pytest's base temp directory."""
    base_dir = tmp_path_factory.getbasetemp()
    return str(base_dir / "test_milvus.db")
|
|
|
|
|
|
@pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
    """Yield a ``MilvusIndex`` over a randomly named collection in the test DB file.

    A ``MilvusClient`` is opened on ``milvus_vec_db_path`` and a pymilvus
    connection is registered under ``MILVUS_ALIAS``. The DB path is also
    stored on the index as ``index.db_path``.

    ``embedding_dimension`` is requested but not used in the body, and nothing
    runs after the ``yield`` (no explicit teardown).
    """
    uri = milvus_vec_db_path
    client = MilvusClient(uri)
    collection_name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
    connections.connect(alias=MILVUS_ALIAS, uri=uri)
    index = MilvusIndex(client, collection_name, consistency_level="Strong")
    index.db_path = uri
    yield index
|
|
|
|
|
|
@pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
    """Yield an initialized ``MilvusVectorIOAdapter`` with one registered vector DB.

    The vector DB is registered under the adapter's own
    ``metadata_collection_name``. The adapter is shut down once the test
    finishes.
    """
    adapter = MilvusVectorIOAdapter(
        config=MilvusVectorIOConfig(
            db_path=milvus_vec_db_path,
            kvstore=SqliteKVStoreConfig(),
        ),
        inference_api=mock_inference_api,
        files_api=None,
    )
    await adapter.initialize()

    # NOTE(review): the dimension here is a literal 128, whereas the module
    # constant EMBEDDING_DIMENSION is 384 -- kept as in the original; confirm
    # the mismatch is intentional.
    vector_db = VectorDB(
        identifier=adapter.metadata_collection_name,
        provider_id="test_provider",
        embedding_model="test_model",
        embedding_dimension=128,
    )
    await adapter.register_vector_db(vector_db)

    yield adapter
    await adapter.shutdown()
|
|
|
|
|
|
@pytest.fixture
def vector_io_adapter(vector_provider, request):
    """Resolve the adapter fixture matching ``vector_provider``.

    ``"milvus"`` selects ``milvus_vec_adapter``; any other value selects
    ``sqlite_vec_adapter``.
    """
    fixture_name = "milvus_vec_adapter" if vector_provider == "milvus" else "sqlite_vec_adapter"
    return request.getfixturevalue(fixture_name)
|
|
|
|
|
|
@pytest.fixture
def vector_index(vector_provider, request):
    """Resolve the index fixture named ``<vector_provider>_vec_index``."""
    fixture_name = f"{vector_provider}_vec_index"
    return request.getfixturevalue(fixture_name)
|