refactoring test a little and removing print statement

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-02-18 12:11:10 -05:00
parent fe152620fb
commit 5a6c95ecf9
2 changed files with 22 additions and 18 deletions

View file

@ -77,7 +77,6 @@ class SQLiteVecIndex(EmbeddingIndex):
If any insert fails, the transaction is rolled back to maintain consistency.
"""
cur = self.connection.cursor()
print(f"inserting {len(chunks)} chunks: {chunks}")
try:
# Start transaction
cur.execute("BEGIN TRANSACTION")

View file

@ -20,25 +20,30 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteV
# pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
SQLITE_VEC_PROVIDER = "sqlite_vec"
EMBEDDING_DIMENSION = 384
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
@pytest.fixture(scope="session")
def loop():
return asyncio.get_event_loop()
return asyncio.new_event_loop()
@pytest.fixture(scope="session", autouse=True)
def sqlite_connection(loop):
conn = sqlite3.connect(":memory:")
conn.enable_load_extension(True)
sqlite_vec.load(conn)
yield conn
conn.close()
try:
conn.enable_load_extension(True)
sqlite_vec.load(conn)
yield conn
finally:
conn.close()
@pytest.fixture(scope="session", autouse=True)
async def sqlite_vec_index(sqlite_connection):
return await SQLiteVecIndex.create(dimension=384, connection=sqlite_connection, bank_id="test_bank")
return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")
@pytest.fixture
@ -60,8 +65,8 @@ def sample_embeddings():
np.random.seed(42)
return np.array(
[
np.random.rand(384).astype(np.float32),
np.random.rand(384).astype(np.float32),
np.random.rand(EMBEDDING_DIMENSION).astype(np.float32),
np.random.rand(EMBEDDING_DIMENSION).astype(np.float32),
]
)
@ -78,13 +83,13 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
@pytest.mark.asyncio
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = np.random.rand(384).astype(np.float32)
query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
response = await sqlite_vec_index.query(query_embedding, k=1, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) > 0
@pytest.fixture
@pytest.fixture(scope="session")
async def sqlite_vec_adapter(sqlite_connection):
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
@ -97,10 +102,10 @@ async def sqlite_vec_adapter(sqlite_connection):
async def test_register_vector_db(sqlite_vec_adapter):
vector_db = VectorDB(
identifier="test_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
embedding_model=EMBEDDING_MODEL,
embedding_dimension=EMBEDDING_DIMENSION,
metadata={},
provider_id="sqlite_vec",
provider_id=SQLITE_VEC_PROVIDER,
)
await sqlite_vec_adapter.register_vector_db(vector_db)
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
@ -111,10 +116,10 @@ async def test_register_vector_db(sqlite_vec_adapter):
async def test_unregister_vector_db(sqlite_vec_adapter):
vector_db = VectorDB(
identifier="test_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
embedding_model=EMBEDDING_MODEL,
embedding_dimension=EMBEDDING_DIMENSION,
metadata={},
provider_id="sqlite_vec",
provider_id=SQLITE_VEC_PROVIDER,
)
await sqlite_vec_adapter.register_vector_db(vector_db)
await sqlite_vec_adapter.unregister_vector_db("test_db")