mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-08 00:11:30 +00:00
refactoring test a little and removing print statement
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
fe152620fb
commit
5a6c95ecf9
2 changed files with 22 additions and 18 deletions
|
|
@ -77,7 +77,6 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
If any insert fails, the transaction is rolled back to maintain consistency.
|
||||
"""
|
||||
cur = self.connection.cursor()
|
||||
print(f"inserting {len(chunks)} chunks: {chunks}")
|
||||
try:
|
||||
# Start transaction
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
|
|
|
|||
|
|
@ -20,25 +20,30 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteV
|
|||
# pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
SQLITE_VEC_PROVIDER = "sqlite_vec"
|
||||
EMBEDDING_DIMENSION = 384
|
||||
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loop():
|
||||
return asyncio.get_event_loop()
|
||||
return asyncio.new_event_loop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def sqlite_connection(loop):
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.enable_load_extension(True)
|
||||
sqlite_vec.load(conn)
|
||||
|
||||
yield conn
|
||||
conn.close()
|
||||
try:
|
||||
conn.enable_load_extension(True)
|
||||
sqlite_vec.load(conn)
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
async def sqlite_vec_index(sqlite_connection):
|
||||
return await SQLiteVecIndex.create(dimension=384, connection=sqlite_connection, bank_id="test_bank")
|
||||
return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -60,8 +65,8 @@ def sample_embeddings():
|
|||
np.random.seed(42)
|
||||
return np.array(
|
||||
[
|
||||
np.random.rand(384).astype(np.float32),
|
||||
np.random.rand(384).astype(np.float32),
|
||||
np.random.rand(EMBEDDING_DIMENSION).astype(np.float32),
|
||||
np.random.rand(EMBEDDING_DIMENSION).astype(np.float32),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -78,13 +83,13 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
|||
@pytest.mark.asyncio
|
||||
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
query_embedding = np.random.rand(384).astype(np.float32)
|
||||
query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32)
|
||||
response = await sqlite_vec_index.query(query_embedding, k=1, score_threshold=0.0)
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) > 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope="session")
|
||||
async def sqlite_vec_adapter(sqlite_connection):
|
||||
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
|
||||
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
|
||||
|
|
@ -97,10 +102,10 @@ async def sqlite_vec_adapter(sqlite_connection):
|
|||
async def test_register_vector_db(sqlite_vec_adapter):
|
||||
vector_db = VectorDB(
|
||||
identifier="test_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
embedding_model=EMBEDDING_MODEL,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
metadata={},
|
||||
provider_id="sqlite_vec",
|
||||
provider_id=SQLITE_VEC_PROVIDER,
|
||||
)
|
||||
await sqlite_vec_adapter.register_vector_db(vector_db)
|
||||
vector_dbs = await sqlite_vec_adapter.list_vector_dbs()
|
||||
|
|
@ -111,10 +116,10 @@ async def test_register_vector_db(sqlite_vec_adapter):
|
|||
async def test_unregister_vector_db(sqlite_vec_adapter):
|
||||
vector_db = VectorDB(
|
||||
identifier="test_db",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
embedding_model=EMBEDDING_MODEL,
|
||||
embedding_dimension=EMBEDDING_DIMENSION,
|
||||
metadata={},
|
||||
provider_id="sqlite_vec",
|
||||
provider_id=SQLITE_VEC_PROVIDER,
|
||||
)
|
||||
await sqlite_vec_adapter.register_vector_db(vector_db)
|
||||
await sqlite_vec_adapter.unregister_vector_db("test_db")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue