Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-21 12:09:40 +00:00)
chore: Adding unit tests for OpenAI vector stores and migrating SQLite-vec registry to kvstore (#2665)
# What does this PR do?

This PR refactors the VectorIO backend logic for `sqlite-vec` and adds unit tests and fixtures to make it easy to test both `sqlite-vec` and `milvus`.

Key changes:
- `sqlite-vec` migrated to the `kvstore` registry
- added an in-memory cache to `sqlite-vec` to be consistent with `milvus`
- default fixtures moved to `conftest.py`
- removed redundant tests from `sqlite-vec`
- made `test_vector_io_openai_vector_stores.py` more easily extensible

## Test Plan

Unit tests added testing the inline providers.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Parent: b18f4d1ccf
Commit: 6a6b66ae4f

12 changed files with 422 additions and 424 deletions
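For context on the registry migration described above, here is a minimal, self-contained sketch of the pattern the new tests exercise. This is illustrative only, not llama-stack's `kvstore_impl`: the `SqliteKVStore` and `Adapter` classes, their method names, and the prefix value are assumptions. The idea is that registrations are serialized under a common key prefix, and `initialize()` rebuilds the adapter's in-memory cache from those keys — which is what `test_persistence_across_adapter_restarts` in the diff below verifies.

```python
# Illustrative sketch of a kvstore-backed vector DB registry (assumed names).
import json
import sqlite3

VECTOR_DBS_PREFIX = "vector_dbs:"  # assumed prefix under which registrations live


class SqliteKVStore:
    """Tiny illustrative KV store on SQLite; the real kvstore_impl is richer."""

    def __init__(self, db_path: str):
        self.conn = sqlite3.connect(db_path)
        self.conn.execute("CREATE TABLE IF NOT EXISTS kv (key TEXT PRIMARY KEY, value TEXT)")

    def set(self, key: str, value: str) -> None:
        self.conn.execute("INSERT OR REPLACE INTO kv VALUES (?, ?)", (key, value))
        self.conn.commit()

    def values_in_range(self, start: str, end: str) -> list[str]:
        rows = self.conn.execute("SELECT value FROM kv WHERE key >= ? AND key < ?", (start, end))
        return [r[0] for r in rows]


class Adapter:
    def __init__(self, kvstore: SqliteKVStore):
        self.kvstore = kvstore
        self.cache: dict[str, dict] = {}

    def register_vector_db(self, vector_db: dict) -> None:
        # Persist the registration, then mirror it in the in-memory cache.
        self.kvstore.set(f"{VECTOR_DBS_PREFIX}{vector_db['identifier']}", json.dumps(vector_db))
        self.cache[vector_db["identifier"]] = vector_db

    def initialize(self) -> None:
        # Rebuild the in-memory cache from persisted registrations.
        for raw in self.kvstore.values_in_range(VECTOR_DBS_PREFIX, VECTOR_DBS_PREFIX + "\xff"):
            vector_db = json.loads(raw)
            self.cache[vector_db["identifier"]] = vector_db


if __name__ == "__main__":
    store = SqliteKVStore(":memory:")
    adapter = Adapter(store)
    adapter.register_vector_db({"identifier": "foo_db", "embedding_dimension": 128})

    # A "restarted" adapter sharing the same store rebuilds its cache.
    restarted = Adapter(store)
    restarted.initialize()
    assert "foo_db" in restarted.cache
```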
**Shared test fixtures** (per the PR description, consolidated into `tests/unit/providers/vector_io/conftest.py`):

```diff
@@ -8,10 +8,18 @@ import random
 import numpy as np
 import pytest
+from pymilvus import MilvusClient, connections
 
+from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, ChunkMetadata
+from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
+from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
+from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
+from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
 
 EMBEDDING_DIMENSION = 384
+COLLECTION_PREFIX = "test_collection"
+MILVUS_ALIAS = "test_milvus"
 
 
 @pytest.fixture
@@ -50,7 +58,156 @@ def sample_chunks():
     return sample
 
 
+@pytest.fixture(scope="session")
+def sample_chunks_with_metadata():
+    """Generates chunks that force multiple batches for a single document to expose ID conflicts."""
+    n, k = 10, 3
+    sample = [
+        Chunk(
+            content=f"Sentence {i} from document {j}",
+            metadata={"document_id": f"document-{j}"},
+            chunk_metadata=ChunkMetadata(
+                document_id=f"document-{j}",
+                chunk_id=f"document-{j}-chunk-{i}",
+                source=f"example source-{j}-{i}",
+            ),
+        )
+        for j in range(k)
+        for i in range(n)
+    ]
+    return sample
+
+
 @pytest.fixture(scope="session")
 def sample_embeddings(sample_chunks):
     np.random.seed(42)
     return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
+
+
+@pytest.fixture(scope="session")
+def sample_embeddings_with_metadata(sample_chunks_with_metadata):
+    np.random.seed(42)
+    return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])
+
+
+@pytest.fixture(params=["milvus", "sqlite_vec"])
+def vector_provider(request):
+    return request.param
+
+
+@pytest.fixture(scope="session")
+def mock_inference_api(embedding_dimension):
+    class MockInferenceAPI:
+        async def embed_batch(self, texts: list[str]) -> list[list[float]]:
+            return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts]
+
+    return MockInferenceAPI()
+
+
+@pytest.fixture
+async def unique_kvstore_config(tmp_path_factory):
+    # Generate a unique filename for this test
+    unique_id = f"test_kv_{np.random.randint(1e6)}"
+    temp_dir = tmp_path_factory.getbasetemp()
+    db_path = str(temp_dir / f"{unique_id}.db")
+
+    return SqliteKVStoreConfig(db_path=db_path)
+
+
+@pytest.fixture(scope="session")
+def sqlite_vec_db_path(tmp_path_factory):
+    db_path = str(tmp_path_factory.getbasetemp() / "test.db")
+    return db_path
+
+
+@pytest.fixture
+async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
+    temp_dir = tmp_path_factory.getbasetemp()
+    db_path = str(temp_dir / f"test_sqlite_vec_{np.random.randint(1e6)}.db")
+    bank_id = f"sqlite_vec_bank_{np.random.randint(1e6)}"
+    index = SQLiteVecIndex(embedding_dimension, db_path, bank_id)
+    await index.initialize()
+    index.db_path = db_path
+    yield index
+    index.delete()
+
+
+@pytest.fixture
+async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
+    config = SQLiteVectorIOConfig(
+        db_path=sqlite_vec_db_path,
+        kvstore=SqliteKVStoreConfig(),
+    )
+    adapter = SQLiteVecVectorIOAdapter(
+        config=config,
+        inference_api=mock_inference_api,
+        files_api=None,
+    )
+    collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
+    await adapter.initialize()
+    await adapter.register_vector_db(
+        VectorDB(
+            identifier=collection_id,
+            provider_id="test_provider",
+            embedding_model="test_model",
+            embedding_dimension=embedding_dimension,
+        )
+    )
+    adapter.test_collection_id = collection_id
+    yield adapter
+    await adapter.shutdown()
+
+
+@pytest.fixture(scope="session")
+def milvus_vec_db_path(tmp_path_factory):
+    db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
+    return db_path
+
+
+@pytest.fixture
+async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
+    client = MilvusClient(milvus_vec_db_path)
+    name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
+    connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
+    index = MilvusIndex(client, name, consistency_level="Strong")
+    index.db_path = milvus_vec_db_path
+    yield index
+
+
+@pytest.fixture
+async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
+    config = MilvusVectorIOConfig(
+        db_path=milvus_vec_db_path,
+        kvstore=SqliteKVStoreConfig(),
+    )
+    adapter = MilvusVectorIOAdapter(
+        config=config,
+        inference_api=mock_inference_api,
+        files_api=None,
+    )
+    await adapter.initialize()
+    await adapter.register_vector_db(
+        VectorDB(
+            identifier=adapter.metadata_collection_name,
+            provider_id="test_provider",
+            embedding_model="test_model",
+            embedding_dimension=128,
+        )
+    )
+    yield adapter
+    await adapter.shutdown()
+
+
+@pytest.fixture
+def vector_io_adapter(vector_provider, request):
+    """Returns the appropriate vector IO adapter based on the provider parameter."""
+    if vector_provider == "milvus":
+        return request.getfixturevalue("milvus_vec_adapter")
+    else:
+        return request.getfixturevalue("sqlite_vec_adapter")
+
+
+@pytest.fixture
+def vector_index(vector_provider, request):
+    """Returns appropriate vector index based on provider parameter"""
+    return request.getfixturevalue(f"{vector_provider}_vec_index")
```
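A note on the dispatch fixtures at the end of the block above: because `vector_provider` is parametrized, every test that requests `vector_io_adapter` or `vector_index` is collected once per provider, and `request.getfixturevalue` resolves the matching adapter lazily. A minimal consumer looks like this (the test name here is hypothetical, for illustration):

```python
import pytest


@pytest.mark.asyncio
async def test_runs_once_per_provider(vector_io_adapter):
    # Collected twice: once with the Milvus adapter and once with the
    # SQLite-vec adapter, via the parametrized vector_provider fixture.
    assert vector_io_adapter is not None
```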
**SQLite-vec provider tests** — redundant cases removed and the chunk-metadata test reworked against the new fixtures:

```diff
@@ -34,7 +34,7 @@ def loop():
     return asyncio.new_event_loop()
 
 
-@pytest_asyncio.fixture(scope="session", autouse=True)
+@pytest_asyncio.fixture
 async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
     temp_dir = tmp_path_factory.getbasetemp()
     db_path = str(temp_dir / "test_sqlite.db")
@@ -44,38 +44,15 @@ async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
 
 
-@pytest.mark.asyncio
-async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
-    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
-    connection = _create_sqlite_connection(sqlite_vec_index.db_path)
-    cur = connection.cursor()
-    cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
-    count = cur.fetchone()[0]
-    assert count == len(sample_chunks)
-    cur.close()
-    connection.close()
-
-
-@pytest.mark.asyncio
-async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
-    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
-    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
-    response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0)
-    assert isinstance(response, QueryChunksResponse)
-    assert len(response.chunks) == 2
-
-
 @pytest.mark.xfail(reason="Chunk Metadata not yet supported for SQLite-vec", strict=True)
-async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks, sample_embeddings):
-    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
-    query_embedding = sample_embeddings[0]
-    response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0)
-    assert response.chunks[-1].chunk_metadata == sample_chunks[-1].chunk_metadata
+async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata):
+    await sqlite_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata)
+    response = await sqlite_vec_index.query_vector(sample_embeddings_with_metadata[-1], k=2, score_threshold=0.0)
+    assert response.chunks[0].chunk_metadata == sample_chunks_with_metadata[-1].chunk_metadata
 
 
 @pytest.mark.asyncio
 async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
     await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
 
     query_string = "Sentence 5"
     response = await sqlite_vec_index.query_keyword(k=3, score_threshold=0.0, query_string=query_string)
@@ -148,7 +125,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
     assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
 
 
-@pytest_asyncio.fixture(scope="session")
+@pytest.fixture(scope="session")
 async def sqlite_vec_adapter(sqlite_connection):
     config = type("Config", (object,), {"db_path": ":memory:"})  # Mock config with in-memory database
     adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
```
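One detail in the hunk above worth calling out: with `strict=True`, an xfail test that unexpectedly passes is reported as a failure, so the marker doubles as a tripwire for the day SQLite-vec gains chunk-metadata support. A self-contained illustration of that pytest behavior (the test name is hypothetical):

```python
import pytest


@pytest.mark.xfail(reason="feature not implemented yet", strict=True)
def test_placeholder_feature():
    # Fails today, so the suite stays green; once this starts passing,
    # strict=True turns the unexpected pass (XPASS) into a hard failure,
    # forcing the marker to be removed.
    raise NotImplementedError
```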
**OpenAI vector store tests** (`tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py`, per the run instructions in the file header) — generalized from Milvus-only to both inline providers:

```diff
@@ -4,253 +4,142 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import asyncio
 import json
 import time
 from unittest.mock import AsyncMock
 
 import numpy as np
 import pytest
-import pytest_asyncio
-from pymilvus import Collection, MilvusClient, connections
 
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
-from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
-from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX, MilvusIndex, MilvusVectorIOAdapter
-from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX
 
-# TODO: Refactor these to be for inline vector-io providers
-MILVUS_ALIAS = "test_milvus"
-COLLECTION_PREFIX = "test_collection"
-
-
-@pytest.fixture(scope="session")
-def loop():
-    return asyncio.new_event_loop()
-
-
-@pytest.fixture(scope="session")
-def mock_inference_api(embedding_dimension):
-    class MockInferenceAPI:
-        async def embed_batch(self, texts: list[str]) -> list[list[float]]:
-            return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts]
-
-    return MockInferenceAPI()
-
-
-@pytest_asyncio.fixture
-async def unique_kvstore_config(tmp_path_factory):
-    # Generate a unique filename for this test
-    unique_id = f"test_kv_{np.random.randint(1e6)}"
-    temp_dir = tmp_path_factory.getbasetemp()
-    db_path = str(temp_dir / f"{unique_id}.db")
-
-    return SqliteKVStoreConfig(db_path=db_path)
-
-
-@pytest_asyncio.fixture(scope="session", autouse=True)
-async def milvus_vec_index(embedding_dimension, tmp_path_factory):
-    temp_dir = tmp_path_factory.getbasetemp()
-    db_path = str(temp_dir / "test_milvus.db")
-    client = MilvusClient(db_path)
-    name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
-    connections.connect(alias=MILVUS_ALIAS, uri=db_path)
-    index = MilvusIndex(client, name, consistency_level="Strong")
-    index.db_path = db_path
-    yield index
-
-
-@pytest_asyncio.fixture(scope="session")
-async def milvus_vec_adapter(milvus_vec_index, mock_inference_api):
-    config = MilvusVectorIOConfig(
-        db_path=milvus_vec_index.db_path,
-        kvstore=SqliteKVStoreConfig(),
-    )
-    adapter = MilvusVectorIOAdapter(
-        config=config,
-        inference_api=mock_inference_api,
-        files_api=None,
-    )
-    await adapter.initialize()
-    await adapter.register_vector_db(
-        VectorDB(
-            identifier=adapter.metadata_collection_name,
-            provider_id="test_provider",
-            embedding_model="test_model",
-            embedding_dimension=128,
-        )
-    )
-    yield adapter
-    await adapter.shutdown()
+# This test is a unit test for the inline VectorIO providers. This should only contain
+# tests which are specific to this class. More general (API-level) tests should be placed in
+# tests/integration/vector_io/
+#
+# How to run this test:
+#
+# pytest tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py \
+#     -v -s --tb=short --disable-warnings --asyncio-mode=auto
 
 
 @pytest.mark.asyncio
-async def test_cache_contains_initial_collection(milvus_vec_adapter):
-    coll_name = milvus_vec_adapter.metadata_collection_name
-    assert coll_name in milvus_vec_adapter.cache
+async def test_initialize_index(vector_index):
+    await vector_index.initialize()
 
 
 @pytest.mark.asyncio
-async def test_add_chunks(milvus_vec_index, sample_chunks, sample_embeddings):
-    await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
-    resp = await milvus_vec_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
+async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
+    vector_index.delete()
+    vector_index.initialize()
+    await vector_index.add_chunks(sample_chunks, sample_embeddings)
+    resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
     assert resp.chunks[0].content == sample_chunks[0].content
+    vector_index.delete()
 
 
-@pytest.mark.asyncio
-async def test_query_chunks_vector(milvus_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
-    await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
-    query_emb = np.random.rand(embedding_dimension).astype(np.float32)
-    resp = await milvus_vec_index.query_vector(query_emb, k=2, score_threshold=0.0)
-    assert isinstance(resp, QueryChunksResponse)
-    assert len(resp.chunks) == 2
-
-
 @pytest.mark.asyncio
-async def test_chunk_id_conflict(milvus_vec_index, sample_chunks, embedding_dimension):
+async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
     embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
-    await milvus_vec_index.add_chunks(sample_chunks, embeddings)
-    coll = Collection(milvus_vec_index.collection_name, using=MILVUS_ALIAS)
-    ids = coll.query(expr="id >= 0", output_fields=["id"], timeout=30)
-    flat_ids = [i["id"] for i in ids]
-    assert len(flat_ids) == len(set(flat_ids))
+    await vector_index.add_chunks(sample_chunks, embeddings)
+    resp = await vector_index.query_vector(
+        np.random.rand(embedding_dimension).astype(np.float32),
+        k=len(sample_chunks),
+        score_threshold=-1,
+    )
+
+    contents = [chunk.content for chunk in resp.chunks]
+    assert len(contents) == len(set(contents))
 
 
 @pytest.mark.asyncio
-async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_config):
-    kvstore = await kvstore_impl(unique_kvstore_config)
-    vector_db = VectorDB(
-        identifier="test_db",
-        provider_id="test_provider",
-        embedding_model="test_model",
-        embedding_dimension=128,
-        metadata={"test_key": "test_value"},
-    )
-    test_vector_db_data = vector_db.model_dump_json()
-    await kvstore.set(f"{VECTOR_DBS_PREFIX}test_db", test_vector_db_data)
-    tmp_milvus_vec_adapter = MilvusVectorIOAdapter(
-        config=MilvusVectorIOConfig(
-            db_path=milvus_vec_index.db_path,
-            kvstore=unique_kvstore_config,
-        ),
-        inference_api=None,
-        files_api=None,
-    )
-    await tmp_milvus_vec_adapter.initialize()
-
-    vector_db = VectorDB(
-        identifier="test_db",
-        provider_id="test_provider",
-        embedding_model="test_model",
-        embedding_dimension=128,
-    )
-    test_vector_db_data = vector_db.model_dump_json()
-    await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data)
-
-    assert milvus_vec_index.client is not None
-    assert isinstance(milvus_vec_index.client, MilvusClient)
-    assert tmp_milvus_vec_adapter.cache is not None
-    # registering a vector won't update the cache or openai_vector_store collection name
-    assert (
-        tmp_milvus_vec_adapter.metadata_collection_name not in tmp_milvus_vec_adapter.cache
-        or tmp_milvus_vec_adapter.openai_vector_stores
-    )
-
-
-@pytest.mark.asyncio
-async def test_persistence_across_adapter_restarts(
-    tmp_path, milvus_vec_index, mock_inference_api, unique_kvstore_config
-):
-    adapter1 = MilvusVectorIOAdapter(
-        config=MilvusVectorIOConfig(db_path=milvus_vec_index.db_path, kvstore=unique_kvstore_config),
-        inference_api=mock_inference_api,
-        files_api=None,
-    )
-    await adapter1.initialize()
+async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
+    key = f"{VECTOR_DBS_PREFIX}db1"
     dummy = VectorDB(
         identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
     )
-    await adapter1.register_vector_db(dummy)
-    await adapter1.shutdown()
+    await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))
 
-    await adapter1.initialize()
-    assert "foo_db" in adapter1.cache
-    await adapter1.shutdown()
+    await vector_io_adapter.initialize()
 
 
 @pytest.mark.asyncio
-async def test_register_and_unregister_vector_db(milvus_vec_adapter):
-    try:
-        connections.disconnect(MILVUS_ALIAS)
-    except Exception as _:
-        pass
+async def test_persistence_across_adapter_restarts(vector_io_adapter):
+    await vector_io_adapter.initialize()
+    dummy = VectorDB(
+        identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
+    )
+    await vector_io_adapter.register_vector_db(dummy)
+    await vector_io_adapter.shutdown()
 
-    connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_adapter.config.db_path)
+    await vector_io_adapter.initialize()
+    assert "foo_db" in vector_io_adapter.cache
+    await vector_io_adapter.shutdown()
 
 
 @pytest.mark.asyncio
+async def test_register_and_unregister_vector_db(vector_io_adapter):
     unique_id = f"foo_db_{np.random.randint(1e6)}"
     dummy = VectorDB(
         identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
     )
-
-    await milvus_vec_adapter.register_vector_db(dummy)
-    assert dummy.identifier in milvus_vec_adapter.cache
-
-    if dummy.identifier in milvus_vec_adapter.cache:
-        index = milvus_vec_adapter.cache[dummy.identifier].index
-        if hasattr(index, "client") and hasattr(index.client, "_using"):
-            index.client._using = MILVUS_ALIAS
-
-    await milvus_vec_adapter.unregister_vector_db(dummy.identifier)
-    assert dummy.identifier not in milvus_vec_adapter.cache
+    await vector_io_adapter.register_vector_db(dummy)
+    assert dummy.identifier in vector_io_adapter.cache
+    await vector_io_adapter.unregister_vector_db(dummy.identifier)
+    assert dummy.identifier not in vector_io_adapter.cache
 
 
 @pytest.mark.asyncio
-async def test_query_unregistered_raises(milvus_vec_adapter):
+async def test_query_unregistered_raises(vector_io_adapter):
     fake_emb = np.zeros(8, dtype=np.float32)
-    with pytest.raises(AttributeError):
-        await milvus_vec_adapter.query_chunks("no_such_db", fake_emb)
+    with pytest.raises(ValueError):
+        await vector_io_adapter.query_chunks("no_such_db", fake_emb)
 
 
 @pytest.mark.asyncio
-async def test_insert_chunks_calls_underlying_index(milvus_vec_adapter):
+async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
    fake_index = AsyncMock()
-    milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
+    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
 
     chunks = ["chunk1", "chunk2"]
-    await milvus_vec_adapter.insert_chunks("db1", chunks)
+    await vector_io_adapter.insert_chunks("db1", chunks)
 
     fake_index.insert_chunks.assert_awaited_once_with(chunks)
 
 
 @pytest.mark.asyncio
-async def test_insert_chunks_missing_db_raises(milvus_vec_adapter):
-    milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
+async def test_insert_chunks_missing_db_raises(vector_io_adapter):
+    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
 
     with pytest.raises(ValueError):
-        await milvus_vec_adapter.insert_chunks("db_not_exist", [])
+        await vector_io_adapter.insert_chunks("db_not_exist", [])
 
 
 @pytest.mark.asyncio
-async def test_query_chunks_calls_underlying_index_and_returns(milvus_vec_adapter):
+async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
     expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
     fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
-    milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
+    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
 
-    response = await milvus_vec_adapter.query_chunks("db1", "my_query", {"param": 1})
+    response = await vector_io_adapter.query_chunks("db1", "my_query", {"param": 1})
 
     fake_index.query_chunks.assert_awaited_once_with("my_query", {"param": 1})
     assert response is expected
 
 
 @pytest.mark.asyncio
-async def test_query_chunks_missing_db_raises(milvus_vec_adapter):
-    milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
+async def test_query_chunks_missing_db_raises(vector_io_adapter):
+    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
 
     with pytest.raises(ValueError):
-        await milvus_vec_adapter.query_chunks("db_missing", "q", None)
+        await vector_io_adapter.query_chunks("db_missing", "q", None)
 
 
 @pytest.mark.asyncio
-async def test_save_openai_vector_store(milvus_vec_adapter):
+async def test_save_openai_vector_store(vector_io_adapter):
     store_id = "vs_1234"
     openai_vector_store = {
         "id": store_id,
@@ -260,14 +149,14 @@ async def test_save_openai_vector_store(milvus_vec_adapter):
         "embedding_model": "test_model",
     }
 
-    await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
+    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
 
-    assert openai_vector_store["id"] in milvus_vec_adapter.openai_vector_stores
-    assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
+    assert openai_vector_store["id"] in vector_io_adapter.openai_vector_stores
+    assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
 
 
 @pytest.mark.asyncio
-async def test_update_openai_vector_store(milvus_vec_adapter):
+async def test_update_openai_vector_store(vector_io_adapter):
     store_id = "vs_1234"
     openai_vector_store = {
         "id": store_id,
@@ -277,14 +166,14 @@ async def test_update_openai_vector_store(milvus_vec_adapter):
         "embedding_model": "test_model",
     }
 
-    await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
+    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
     openai_vector_store["description"] = "Updated description"
-    await milvus_vec_adapter._update_openai_vector_store(store_id, openai_vector_store)
-    assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
+    await vector_io_adapter._update_openai_vector_store(store_id, openai_vector_store)
+    assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
 
 
 @pytest.mark.asyncio
-async def test_delete_openai_vector_store(milvus_vec_adapter):
+async def test_delete_openai_vector_store(vector_io_adapter):
     store_id = "vs_1234"
     openai_vector_store = {
         "id": store_id,
@@ -294,13 +183,13 @@ async def test_delete_openai_vector_store(milvus_vec_adapter):
         "embedding_model": "test_model",
     }
 
-    await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
-    await milvus_vec_adapter._delete_openai_vector_store_from_storage(store_id)
-    assert openai_vector_store["id"] not in milvus_vec_adapter.openai_vector_stores
+    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
+    await vector_io_adapter._delete_openai_vector_store_from_storage(store_id)
+    assert openai_vector_store["id"] not in vector_io_adapter.openai_vector_stores
 
 
 @pytest.mark.asyncio
-async def test_load_openai_vector_stores(milvus_vec_adapter):
+async def test_load_openai_vector_stores(vector_io_adapter):
     store_id = "vs_1234"
     openai_vector_store = {
         "id": store_id,
@@ -310,13 +199,13 @@ async def test_load_openai_vector_stores(milvus_vec_adapter):
         "embedding_model": "test_model",
     }
 
-    await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
-    loaded_stores = await milvus_vec_adapter._load_openai_vector_stores()
+    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
+    loaded_stores = await vector_io_adapter._load_openai_vector_stores()
     assert loaded_stores[store_id] == openai_vector_store
 
 
 @pytest.mark.asyncio
-async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
+async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
     store_id = "vs_1234"
     file_id = "file_1234"
 
@@ -334,11 +223,11 @@ async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
     ]
 
     # validating we don't raise an exception
-    await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
+    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
 
 
 @pytest.mark.asyncio
-async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
+async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
     store_id = "vs_1234"
     file_id = "file_1234"
 
@@ -355,24 +244,24 @@ async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
         {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
     ]
 
-    await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
+    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
 
     updated_file_info = file_info.copy()
     updated_file_info["filename"] = "updated_test_file.txt"
 
-    await milvus_vec_adapter._update_openai_vector_store_file(
+    await vector_io_adapter._update_openai_vector_store_file(
         store_id,
         file_id,
         updated_file_info,
     )
 
-    loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file(store_id, file_id)
+    loaded_contents = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id)
     assert loaded_contents == updated_file_info
     assert loaded_contents != file_info
 
 
 @pytest.mark.asyncio
-async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_path_factory):
+async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_path_factory):
     store_id = "vs_1234"
     file_id = "file_1234"
 
@@ -389,14 +278,14 @@ async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_path_factory):
         {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
     ]
 
-    await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
+    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
 
-    loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id)
+    loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
     assert loaded_contents == file_contents
 
 
 @pytest.mark.asyncio
-async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter, tmp_path_factory):
+async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, tmp_path_factory):
     store_id = "vs_1234"
     file_id = "file_1234"
 
@@ -413,8 +302,8 @@ async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter, tmp_path_factory):
         {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
     ]
 
-    await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
-    await milvus_vec_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)
+    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
+    await vector_io_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)
 
-    loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id)
+    loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
     assert loaded_contents == []
```
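The mock-based tests in the file above share one pattern worth noting: the adapter's internal index lookup is replaced with an `AsyncMock` so delegation can be asserted without a real backend. A standalone sketch of that technique (the `FakeAdapter` class here is illustrative, not a llama-stack type):

```python
import asyncio
from unittest.mock import AsyncMock


class FakeAdapter:
    """Stand-in for a VectorIO adapter: resolves an index, then delegates."""

    async def insert_chunks(self, db_id, chunks):
        index = await self._get_and_cache_vector_db_index(db_id)
        if index is None:
            raise ValueError(f"unknown vector db: {db_id}")
        await index.insert_chunks(chunks)


async def main() -> None:
    adapter = FakeAdapter()
    fake_index = AsyncMock()
    # Patch the internal lookup so no real vector store is needed.
    adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)

    await adapter.insert_chunks("db1", ["chunk1", "chunk2"])
    # AsyncMock records awaited calls, so delegation is directly assertable.
    fake_index.insert_chunks.assert_awaited_once_with(["chunk1", "chunk2"])


asyncio.run(main())
```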