mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-09 19:58:29 +00:00
Adding some additional unit tests
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
9ee59443aa
commit
fe152620fb
2 changed files with 128 additions and 5 deletions
|
@ -59,7 +59,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
# Create the virtual table for embeddings.
|
||||
cur.execute(f"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
|
||||
USING vec0(embedding FLOAT[{self.dimension}]);
|
||||
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
||||
""")
|
||||
self.connection.commit()
|
||||
|
||||
|
@ -77,6 +77,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
If any insert fails, the transaction is rolled back to maintain consistency.
|
||||
"""
|
||||
cur = self.connection.cursor()
|
||||
print(f"inserting {len(chunks)} chunks: {chunks}")
|
||||
try:
|
||||
# Start transaction
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
|
@ -91,9 +92,9 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
# Insert metadata (ON CONFLICT to avoid duplicates)
|
||||
cur.executemany(
|
||||
f"""
|
||||
INSERT INTO {self.metadata_table} (id, document)
|
||||
INSERT INTO {self.metadata_table} (id, chunk)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET document = excluded.document;
|
||||
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
|
||||
""",
|
||||
metadata_data,
|
||||
)
|
||||
|
@ -103,7 +104,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
for j, (chunk, emb) in enumerate(zip(batch_chunks, batch_embeddings, strict=True))
|
||||
]
|
||||
# Insert embeddings in batch
|
||||
cur.executemany(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?);", embedding_data)
|
||||
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
|
||||
self.connection.commit()
|
||||
|
||||
except sqlite3.Error as e:
|
||||
|
@ -124,7 +125,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
query_sql = f"""
|
||||
SELECT m.id, m.chunk, v.distance
|
||||
FROM {self.vector_table} AS v
|
||||
JOIN {self.metadata_table} AS m ON m.id = v.rowid
|
||||
JOIN {self.metadata_table} AS m ON m.id = v.id
|
||||
WHERE v.embedding MATCH ? AND k = ?
|
||||
ORDER BY v.distance;
|
||||
"""
|
||||
|
|
122
llama_stack/providers/tests/vector_io/test_sqlite_vec.py
Normal file
122
llama_stack/providers/tests/vector_io/test_sqlite_vec.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import sqlite3
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import sqlite_vec
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loop():
|
||||
return asyncio.get_event_loop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def sqlite_connection(loop):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.enable_load_extension(True)
|
||||
sqlite_vec.load(conn)
|
||||
|
||||
yield conn
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
async def sqlite_vec_index(sqlite_connection):
|
||||
return await SQLiteVecIndex.create(dimension=384, connection=sqlite_connection, bank_id="test_bank")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_chunks():
|
||||
return [
|
||||
Chunk(
|
||||
content="Python is a high-level programming language.",
|
||||
metadata={"category": "programming", "document_id": "doc 1"},
|
||||
),
|
||||
Chunk(
|
||||
content="Machine learning is a subset of artificial intelligence.",
|
||||
metadata={"category": "AI", "document_id": "doc 1"},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_embeddings():
|
||||
np.random.seed(42)
|
||||
return np.array(
|
||||
[
|
||||
np.random.rand(384).astype(np.float32),
|
||||
np.random.rand(384).astype(np.float32),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
cur = sqlite_vec_index.connection.cursor()
|
||||
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
|
||||
count = cur.fetchone()[0]
|
||||
assert count == len(sample_chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
query_embedding = np.random.rand(384).astype(np.float32)
|
||||
response = await sqlite_vec_index.query(query_embedding, k=1, score_threshold=0.0)
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) > 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def sqlite_vec_adapter(sqlite_connection):
|
||||
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
|
||||
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
|
||||
await adapter.initialize()
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_vector_db(sqlite_vec_adapter):
|
||||
vector_db = VectorDB(
|
||||
identifier="test_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
metadata={},
|
||||
provider_id="sqlite_vec",
|
||||
)
|
||||
await sqlite_vec_adapter.register_vector_db(vector_db)
|
||||
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
|
||||
assert any(db.identifier == "test_db" for db in vector_dbs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_vector_db(sqlite_vec_adapter):
|
||||
vector_db = VectorDB(
|
||||
identifier="test_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
metadata={},
|
||||
provider_id="sqlite_vec",
|
||||
)
|
||||
await sqlite_vec_adapter.register_vector_db(vector_db)
|
||||
await sqlite_vec_adapter.unregister_vector_db("test_db")
|
||||
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
|
||||
assert not any(db.identifier == "test_db" for db in vector_dbs)
|
Loading…
Add table
Add a link
Reference in a new issue