chore: Adding unit tests for OpenAI vector stores and migrating SQLite-vec registry to kvstore (#2665)

# What does this PR do?

This PR refactors the VectorIO backend logic for `sqlite-vec` and
adds unit tests and fixtures to make it easy to test both `sqlite-vec`
and `milvus`.

Key changes:
- `sqlite-vec` migrated to the `kvstore` registry (see the config sketch below)
- added an in-memory cache to `sqlite-vec` to be consistent with `milvus`
- default fixtures moved to `conftest.py`
- removed redundant tests from `sqlite-vec`
- made `test_vector_io_openai_vector_stores.py` more easily extensible
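
With the registry migration, `sqlite-vec` run configs now carry a `kvstore` block. A minimal sketch of the new shape (paths are illustrative, mirroring the sample configurations in the diff below):

```yaml
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db
kvstore:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db
```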


## Test Plan
Added unit tests exercising the inline vector IO providers.
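
For example, the new tests can be run directly (this invocation mirrors the comment added to the test module below):

```
pytest tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py \
    -v -s --tb=short --disable-warnings --asyncio-mode=auto
```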

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Arceo committed 2025-07-10 14:22:13 -04:00 (via GitHub)
parent b18f4d1ccf
commit 6a6b66ae4f
12 changed files with 422 additions and 424 deletions

@@ -11,7 +11,7 @@ Please refer to the remote provider documentation.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `db_path` | `<class 'str'>` | No | PydanticUndefined | |
-| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) |
 | `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |

 ## Sample Configuration

@@ -205,12 +205,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
-| `db_path` | `<class 'str'>` | No | PydanticUndefined | |
+| `db_path` | `<class 'str'>` | No | PydanticUndefined | Path to the SQLite database file |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) |

 ## Sample Configuration

 ```yaml
 db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db
+kvstore:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db
 ```

@@ -10,12 +10,16 @@ Please refer to the sqlite-vec provider documentation.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
-| `db_path` | `<class 'str'>` | No | PydanticUndefined | |
+| `db_path` | `<class 'str'>` | No | PydanticUndefined | Path to the SQLite database file |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) |

 ## Sample Configuration

 ```yaml
 db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db
+kvstore:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db
 ```

@@ -18,7 +18,7 @@ from llama_stack.schema_utils import json_schema_type
 @json_schema_type
 class MilvusVectorIOConfig(BaseModel):
     db_path: str
-    kvstore: KVStoreConfig
+    kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
     consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")

     @classmethod

@@ -6,14 +6,24 @@

 from typing import Any

-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)


 class SQLiteVectorIOConfig(BaseModel):
-    db_path: str
+    db_path: str = Field(description="Path to the SQLite database file")
+    kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")

     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
         return {
             "db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="sqlite_vec_registry.db",
+            ),
         }

@@ -24,6 +24,8 @@ from llama_stack.apis.vector_io import (
     VectorIO,
 )
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     RERANKER_TYPE_RRF,
@@ -40,6 +42,13 @@ KEYWORD_SEARCH = "keyword"
 HYBRID_SEARCH = "hybrid"
 SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}

+VERSION = "v3"
+VECTOR_DBS_PREFIX = f"vector_dbs:sqlite_vec:{VERSION}::"
+VECTOR_INDEX_PREFIX = f"vector_index:sqlite_vec:{VERSION}::"
+OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:sqlite_vec:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:sqlite_vec:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:sqlite_vec:{VERSION}::"
+

 def serialize_vector(vector: list[float]) -> bytes:
     """Serialize a list of floats into a compact binary representation."""
@@ -117,13 +126,14 @@ class SQLiteVecIndex(EmbeddingIndex):
     - An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search.
     """

-    def __init__(self, dimension: int, db_path: str, bank_id: str):
+    def __init__(self, dimension: int, db_path: str, bank_id: str, kvstore: KVStore | None = None):
         self.dimension = dimension
         self.db_path = db_path
         self.bank_id = bank_id
         self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
         self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
         self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
+        self.kvstore = kvstore

     @classmethod
     async def create(cls, dimension: int, db_path: str, bank_id: str):
@@ -425,27 +435,116 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
         self.files_api = files_api
         self.cache: dict[str, VectorDBWithIndex] = {}
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
+        self.kvstore: KVStore | None = None

     async def initialize(self) -> None:
-        def _setup_connection():
-            # Open a connection to the SQLite database (the file is specified in the config).
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+
+        start_key = VECTOR_DBS_PREFIX
+        end_key = f"{VECTOR_DBS_PREFIX}\xff"
+        stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
+
+        for db_json in stored_vector_dbs:
+            vector_db = VectorDB.model_validate_json(db_json)
+            index = await SQLiteVecIndex.create(
+                vector_db.embedding_dimension,
+                self.config.db_path,
+                vector_db.identifier,
+            )
+            self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
+
+        # load any existing OpenAI vector stores
+        self.openai_vector_stores = await self._load_openai_vector_stores()
+
+    async def shutdown(self) -> None:
+        # nothing to do since we don't maintain a persistent connection
+        pass
+
+    async def list_vector_dbs(self) -> list[VectorDB]:
+        return [v.vector_db for v in self.cache.values()]
+
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
+        index = await SQLiteVecIndex.create(
+            vector_db.embedding_dimension,
+            self.config.db_path,
+            vector_db.identifier,
+        )
+        self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
+
+    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
+        if vector_db_id in self.cache:
+            return self.cache[vector_db_id]
+
+        if self.vector_db_store is None:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
+
+        vector_db = self.vector_db_store.get_vector_db(vector_db_id)
+        if not vector_db:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
+
+        index = VectorDBWithIndex(
+            vector_db=vector_db,
+            index=SQLiteVecIndex(
+                dimension=vector_db.embedding_dimension,
+                db_path=self.config.db_path,
+                bank_id=vector_db.identifier,
+                kvstore=self.kvstore,
+            ),
+            inference_api=self.inference_api,
+        )
+        self.cache[vector_db_id] = index
+        return index
+
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        if vector_db_id not in self.cache:
+            logger.warning(f"Vector DB {vector_db_id} not found")
+            return
+        await self.cache[vector_db_id].index.delete()
+        del self.cache[vector_db_id]
+
+    # OpenAI Vector Store Mixin abstract method implementations
+    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        """Save vector store metadata to SQLite database."""
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+        self.openai_vector_stores[store_id] = store_info
+
+    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
+        """Load all vector store metadata from SQLite database."""
+        assert self.kvstore is not None
+        start_key = OPENAI_VECTOR_STORES_PREFIX
+        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
+        stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
+
+        stores = {}
+        for store_data in stored_openai_stores:
+            store_info = json.loads(store_data)
+            stores[store_info["id"]] = store_info
+        return stores
+
+    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        """Update vector store metadata in SQLite database."""
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+        self.openai_vector_stores[store_id] = store_info
+
+    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
+        """Delete vector store metadata from SQLite database."""
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.delete(key)
+        if store_id in self.openai_vector_stores:
+            del self.openai_vector_stores[store_id]
+
+    async def _save_openai_vector_store_file(
+        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+    ) -> None:
+        """Save vector store file metadata to SQLite database."""
+
+        def _create_or_store():
             connection = _create_sqlite_connection(self.config.db_path)
             cur = connection.cursor()
             try:
-                # Create a table to persist vector DB registrations.
-                cur.execute("""
-                    CREATE TABLE IF NOT EXISTS vector_dbs (
-                        id TEXT PRIMARY KEY,
-                        metadata TEXT
-                    );
-                """)
-                # Create a table to persist OpenAI vector stores.
-                cur.execute("""
-                    CREATE TABLE IF NOT EXISTS openai_vector_stores (
-                        id TEXT PRIMARY KEY,
-                        metadata TEXT
-                    );
-                """)
                 # Create a table to persist OpenAI vector store files.
                 cur.execute("""
                     CREATE TABLE IF NOT EXISTS openai_vector_store_files (
@@ -464,168 +563,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
                     );
                 """)
                 connection.commit()
-                # Load any existing vector DB registrations.
-                cur.execute("SELECT metadata FROM vector_dbs")
-                vector_db_rows = cur.fetchall()
-                return vector_db_rows
-            finally:
-                cur.close()
-                connection.close()
-
-        vector_db_rows = await asyncio.to_thread(_setup_connection)
-
-        # Load existing vector DBs
-        for row in vector_db_rows:
-            vector_db_data = row[0]
-            vector_db = VectorDB.model_validate_json(vector_db_data)
-            index = await SQLiteVecIndex.create(
-                vector_db.embedding_dimension,
-                self.config.db_path,
-                vector_db.identifier,
-            )
-            self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
-
-        # Load existing OpenAI vector stores using the mixin method
-        self.openai_vector_stores = await self._load_openai_vector_stores()
-
-    async def shutdown(self) -> None:
-        # nothing to do since we don't maintain a persistent connection
-        pass
-
-    async def register_vector_db(self, vector_db: VectorDB) -> None:
-        def _register_db():
-            connection = _create_sqlite_connection(self.config.db_path)
-            cur = connection.cursor()
-            try:
-                cur.execute(
-                    "INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
-                    (vector_db.identifier, vector_db.model_dump_json()),
-                )
-                connection.commit()
-            finally:
-                cur.close()
-                connection.close()
-
-        await asyncio.to_thread(_register_db)
-        index = await SQLiteVecIndex.create(
-            vector_db.embedding_dimension,
-            self.config.db_path,
-            vector_db.identifier,
-        )
-        self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
-
-    async def list_vector_dbs(self) -> list[VectorDB]:
-        return [v.vector_db for v in self.cache.values()]
-
-    async def unregister_vector_db(self, vector_db_id: str) -> None:
-        if vector_db_id not in self.cache:
-            logger.warning(f"Vector DB {vector_db_id} not found")
-            return
-        await self.cache[vector_db_id].index.delete()
-        del self.cache[vector_db_id]
-
-        def _delete_vector_db_from_registry():
-            connection = _create_sqlite_connection(self.config.db_path)
-            cur = connection.cursor()
-            try:
-                cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
-                connection.commit()
-            finally:
-                cur.close()
-                connection.close()
-
-        await asyncio.to_thread(_delete_vector_db_from_registry)
-
-    # OpenAI Vector Store Mixin abstract method implementations
-    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Save vector store metadata to SQLite database."""
-
-        def _store():
-            connection = _create_sqlite_connection(self.config.db_path)
-            cur = connection.cursor()
-            try:
-                cur.execute(
-                    "INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
-                    (store_id, json.dumps(store_info)),
-                )
-                connection.commit()
-            except Exception as e:
-                logger.error(f"Error saving openai vector store {store_id}: {e}")
-                raise
-            finally:
-                cur.close()
-                connection.close()
-
-        try:
-            await asyncio.to_thread(_store)
-        except Exception as e:
-            logger.error(f"Error saving openai vector store {store_id}: {e}")
-            raise
-
-    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
-        """Load all vector store metadata from SQLite database."""
-
-        def _load():
-            connection = _create_sqlite_connection(self.config.db_path)
-            cur = connection.cursor()
-            try:
-                cur.execute("SELECT metadata FROM openai_vector_stores")
-                rows = cur.fetchall()
-                return rows
-            finally:
-                cur.close()
-                connection.close()
-
-        rows = await asyncio.to_thread(_load)
-        stores = {}
-        for row in rows:
-            store_data = row[0]
-            store_info = json.loads(store_data)
-            stores[store_info["id"]] = store_info
-        return stores
-
-    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Update vector store metadata in SQLite database."""
-
-        def _update():
-            connection = _create_sqlite_connection(self.config.db_path)
-            cur = connection.cursor()
-            try:
-                cur.execute(
-                    "UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
-                    (json.dumps(store_info), store_id),
-                )
-                connection.commit()
-            finally:
-                cur.close()
-                connection.close()
-
-        await asyncio.to_thread(_update)
-
-    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
-        """Delete vector store metadata from SQLite database."""
-
-        def _delete():
-            connection = _create_sqlite_connection(self.config.db_path)
-            cur = connection.cursor()
-            try:
-                cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,))
-                connection.commit()
-            finally:
-                cur.close()
-                connection.close()
-
-        await asyncio.to_thread(_delete)
-
-    async def _save_openai_vector_store_file(
-        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
-    ) -> None:
-        """Save vector store file metadata to SQLite database."""
-
-        def _store():
-            connection = _create_sqlite_connection(self.config.db_path)
-            cur = connection.cursor()
-            try:
                 cur.execute(
                     "INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)",
                     (store_id, file_id, json.dumps(file_info)),
@@ -643,7 +580,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
             connection.close()

         try:
-            await asyncio.to_thread(_store)
+            await asyncio.to_thread(_create_or_store)
         except Exception as e:
             logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
             raise
@@ -722,6 +659,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
             cur.execute(
                 "DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id)
             )
+            cur.execute(
+                "DELETE FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
+                (store_id, file_id),
+            )
             connection.commit()
         finally:
             cur.close()
@@ -730,15 +671,17 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
         await asyncio.to_thread(_delete)

     async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-        if vector_db_id not in self.cache:
-            raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        if not index:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
         # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
         # and then call our index's add_chunks.
-        await self.cache[vector_db_id].insert_chunks(chunks)
+        await index.insert_chunks(chunks)

     async def query_chunks(
         self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
-        if vector_db_id not in self.cache:
-            raise ValueError(f"Vector DB {vector_db_id} not found")
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        if not index:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
-        return await self.cache[vector_db_id].query_chunks(query, params)
+        return await index.query_chunks(query, params)

@@ -61,6 +61,11 @@ class MilvusIndex(EmbeddingIndex):
         self.consistency_level = consistency_level
         self.kvstore = kvstore

+    async def initialize(self):
+        # MilvusIndex does not require explicit initialization
+        # TODO: could move collection creation into initialization but it is not really necessary
+        pass
+
     async def delete(self):
         if await asyncio.to_thread(self.client.has_collection, self.collection_name):
             await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
@@ -199,6 +204,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         if vector_db_id in self.cache:
             return self.cache[vector_db_id]

+        if self.vector_db_store is None:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
+
         vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
         if not vector_db:
             raise ValueError(f"Vector DB {vector_db_id} not found")

@@ -39,6 +39,9 @@ providers:
     provider_type: inline::sqlite-vec
     config:
       db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec.db
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec_registry.db
   - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
     provider_type: remote::chromadb
     config:

@@ -144,6 +144,9 @@ providers:
     provider_type: inline::sqlite-vec
     config:
       db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
   - provider_id: ${env.ENABLE_MILVUS:=__disabled__}
     provider_type: inline::milvus
     config:

@@ -8,10 +8,18 @@ import random
 import numpy as np
 import pytest
+from pymilvus import MilvusClient, connections

+from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, ChunkMetadata
+from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
+from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
+from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
+from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter

 EMBEDDING_DIMENSION = 384
+COLLECTION_PREFIX = "test_collection"
+MILVUS_ALIAS = "test_milvus"


 @pytest.fixture
@@ -50,7 +58,156 @@ def sample_chunks():
     return sample


+@pytest.fixture(scope="session")
+def sample_chunks_with_metadata():
+    """Generates chunks that force multiple batches for a single document to expose ID conflicts."""
+    n, k = 10, 3
+    sample = [
+        Chunk(
+            content=f"Sentence {i} from document {j}",
+            metadata={"document_id": f"document-{j}"},
+            chunk_metadata=ChunkMetadata(
+                document_id=f"document-{j}",
+                chunk_id=f"document-{j}-chunk-{i}",
+                source=f"example source-{j}-{i}",
+            ),
+        )
+        for j in range(k)
+        for i in range(n)
+    ]
+    return sample
+
+
 @pytest.fixture(scope="session")
 def sample_embeddings(sample_chunks):
     np.random.seed(42)
     return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
+
+
+@pytest.fixture(scope="session")
+def sample_embeddings_with_metadata(sample_chunks_with_metadata):
+    np.random.seed(42)
+    return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])
+
+
+@pytest.fixture(params=["milvus", "sqlite_vec"])
+def vector_provider(request):
+    return request.param
+
+
+@pytest.fixture(scope="session")
+def mock_inference_api(embedding_dimension):
+    class MockInferenceAPI:
+        async def embed_batch(self, texts: list[str]) -> list[list[float]]:
+            return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts]
+
+    return MockInferenceAPI()
+
+
+@pytest.fixture
+async def unique_kvstore_config(tmp_path_factory):
+    # Generate a unique filename for this test
+    unique_id = f"test_kv_{np.random.randint(1e6)}"
+    temp_dir = tmp_path_factory.getbasetemp()
+    db_path = str(temp_dir / f"{unique_id}.db")
+
+    return SqliteKVStoreConfig(db_path=db_path)
+
+
+@pytest.fixture(scope="session")
+def sqlite_vec_db_path(tmp_path_factory):
+    db_path = str(tmp_path_factory.getbasetemp() / "test.db")
+    return db_path
+
+
+@pytest.fixture
+async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
+    temp_dir = tmp_path_factory.getbasetemp()
+    db_path = str(temp_dir / f"test_sqlite_vec_{np.random.randint(1e6)}.db")
+    bank_id = f"sqlite_vec_bank_{np.random.randint(1e6)}"
+    index = SQLiteVecIndex(embedding_dimension, db_path, bank_id)
+    await index.initialize()
+    index.db_path = db_path
+    yield index
+    index.delete()
+
+
+@pytest.fixture
+async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
+    config = SQLiteVectorIOConfig(
+        db_path=sqlite_vec_db_path,
+        kvstore=SqliteKVStoreConfig(),
+    )
+    adapter = SQLiteVecVectorIOAdapter(
+        config=config,
+        inference_api=mock_inference_api,
+        files_api=None,
+    )
+    collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
+    await adapter.initialize()
+    await adapter.register_vector_db(
+        VectorDB(
+            identifier=collection_id,
+            provider_id="test_provider",
+            embedding_model="test_model",
+            embedding_dimension=embedding_dimension,
+        )
+    )
+    adapter.test_collection_id = collection_id
+    yield adapter
+    await adapter.shutdown()
+
+
+@pytest.fixture(scope="session")
+def milvus_vec_db_path(tmp_path_factory):
+    db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
+    return db_path
+
+
+@pytest.fixture
+async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
+    client = MilvusClient(milvus_vec_db_path)
+    name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
+    connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
+    index = MilvusIndex(client, name, consistency_level="Strong")
+    index.db_path = milvus_vec_db_path
+    yield index
+
+
+@pytest.fixture
+async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
+    config = MilvusVectorIOConfig(
+        db_path=milvus_vec_db_path,
+        kvstore=SqliteKVStoreConfig(),
+    )
+    adapter = MilvusVectorIOAdapter(
+        config=config,
+        inference_api=mock_inference_api,
+        files_api=None,
+    )
+    await adapter.initialize()
+    await adapter.register_vector_db(
+        VectorDB(
+            identifier=adapter.metadata_collection_name,
+            provider_id="test_provider",
+            embedding_model="test_model",
+            embedding_dimension=128,
+        )
+    )
+    yield adapter
+    await adapter.shutdown()
+
+
+@pytest.fixture
+def vector_io_adapter(vector_provider, request):
+    """Returns the appropriate vector IO adapter based on the provider parameter."""
+    if vector_provider == "milvus":
+        return request.getfixturevalue("milvus_vec_adapter")
+    else:
+        return request.getfixturevalue("sqlite_vec_adapter")
+
+
+@pytest.fixture
+def vector_index(vector_provider, request):
+    """Returns appropriate vector index based on provider parameter"""
+    return request.getfixturevalue(f"{vector_provider}_vec_index")

@@ -34,7 +34,7 @@ def loop():
     return asyncio.new_event_loop()


-@pytest_asyncio.fixture(scope="session", autouse=True)
+@pytest_asyncio.fixture
 async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
     temp_dir = tmp_path_factory.getbasetemp()
     db_path = str(temp_dir / "test_sqlite.db")
@@ -44,38 +44,15 @@ async def sqlite_vec_index(embedding_dimension, tmp_path_factory):

 @pytest.mark.asyncio
-async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
-    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
-    connection = _create_sqlite_connection(sqlite_vec_index.db_path)
-    cur = connection.cursor()
-    cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
-    count = cur.fetchone()[0]
-    assert count == len(sample_chunks)
-    cur.close()
-    connection.close()
-
-
-@pytest.mark.asyncio
-async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
-    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
-    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
-    response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0)
-    assert isinstance(response, QueryChunksResponse)
-    assert len(response.chunks) == 2
-
-
-@pytest.mark.xfail(reason="Chunk Metadata not yet supported for SQLite-vec", strict=True)
-async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks, sample_embeddings):
-    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
-    query_embedding = sample_embeddings[0]
-    response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0)
-    assert response.chunks[-1].chunk_metadata == sample_chunks[-1].chunk_metadata
+async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata):
+    await sqlite_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata)
+    response = await sqlite_vec_index.query_vector(sample_embeddings_with_metadata[-1], k=2, score_threshold=0.0)
+    assert response.chunks[0].chunk_metadata == sample_chunks_with_metadata[-1].chunk_metadata


 @pytest.mark.asyncio
 async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
     await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)

     query_string = "Sentence 5"
     response = await sqlite_vec_index.query_keyword(k=3, score_threshold=0.0, query_string=query_string)
@@ -148,7 +125,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
     assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"


-@pytest_asyncio.fixture(scope="session")
+@pytest.fixture(scope="session")
 async def sqlite_vec_adapter(sqlite_connection):
     config = type("Config", (object,), {"db_path": ":memory:"})  # Mock config with in-memory database
     adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)


@@ -4,253 +4,142 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import asyncio
+import json
 import time
 from unittest.mock import AsyncMock

 import numpy as np
 import pytest
-import pytest_asyncio
-from pymilvus import Collection, MilvusClient, connections

 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
-from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
-from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX, MilvusIndex, MilvusVectorIOAdapter
-from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX

-# TODO: Refactor these to be for inline vector-io providers
-MILVUS_ALIAS = "test_milvus"
-COLLECTION_PREFIX = "test_collection"
-
-
-@pytest.fixture(scope="session")
-def loop():
-    return asyncio.new_event_loop()
-
-
-@pytest.fixture(scope="session")
-def mock_inference_api(embedding_dimension):
-    class MockInferenceAPI:
-        async def embed_batch(self, texts: list[str]) -> list[list[float]]:
-            return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts]
-
-    return MockInferenceAPI()
-
-
-@pytest_asyncio.fixture
-async def unique_kvstore_config(tmp_path_factory):
-    # Generate a unique filename for this test
-    unique_id = f"test_kv_{np.random.randint(1e6)}"
-    temp_dir = tmp_path_factory.getbasetemp()
-    db_path = str(temp_dir / f"{unique_id}.db")
-
-    return SqliteKVStoreConfig(db_path=db_path)
-
-
-@pytest_asyncio.fixture(scope="session", autouse=True)
-async def milvus_vec_index(embedding_dimension, tmp_path_factory):
-    temp_dir = tmp_path_factory.getbasetemp()
-    db_path = str(temp_dir / "test_milvus.db")
-    client = MilvusClient(db_path)
-    name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
-    connections.connect(alias=MILVUS_ALIAS, uri=db_path)
-    index = MilvusIndex(client, name, consistency_level="Strong")
-    index.db_path = db_path
-    yield index
-
-
-@pytest_asyncio.fixture(scope="session")
-async def milvus_vec_adapter(milvus_vec_index, mock_inference_api):
-    config = MilvusVectorIOConfig(
-        db_path=milvus_vec_index.db_path,
-        kvstore=SqliteKVStoreConfig(),
-    )
-    adapter = MilvusVectorIOAdapter(
-        config=config,
-        inference_api=mock_inference_api,
-        files_api=None,
-    )
-    await adapter.initialize()
-    await adapter.register_vector_db(
-        VectorDB(
-            identifier=adapter.metadata_collection_name,
-            provider_id="test_provider",
-            embedding_model="test_model",
-            embedding_dimension=128,
-        )
-    )
-    yield adapter
-    await adapter.shutdown()
+# This test is a unit test for the inline VectorIO providers. This should only contain
+# tests which are specific to this class. More general (API-level) tests should be placed in
+# tests/integration/vector_io/
+#
+# How to run this test:
+#
+# pytest tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py \
+#     -v -s --tb=short --disable-warnings --asyncio-mode=auto


 @pytest.mark.asyncio
-async def test_cache_contains_initial_collection(milvus_vec_adapter):
-    coll_name = milvus_vec_adapter.metadata_collection_name
-    assert coll_name in milvus_vec_adapter.cache
+async def test_initialize_index(vector_index):
+    await vector_index.initialize()


 @pytest.mark.asyncio
-async def test_add_chunks(milvus_vec_index, sample_chunks, sample_embeddings):
-    await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
-    resp = await milvus_vec_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
+async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
+    vector_index.delete()
+    vector_index.initialize()
+    await vector_index.add_chunks(sample_chunks, sample_embeddings)
+    resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
     assert resp.chunks[0].content == sample_chunks[0].content
+    vector_index.delete()


 @pytest.mark.asyncio
-async def test_query_chunks_vector(milvus_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
-    await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
-    query_emb = np.random.rand(embedding_dimension).astype(np.float32)
-    resp = await milvus_vec_index.query_vector(query_emb, k=2, score_threshold=0.0)
-    assert isinstance(resp, QueryChunksResponse)
-    assert len(resp.chunks) == 2
-
-
-@pytest.mark.asyncio
-async def test_chunk_id_conflict(milvus_vec_index, sample_chunks, embedding_dimension):
+async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
     embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
-    await milvus_vec_index.add_chunks(sample_chunks, embeddings)
-    coll = Collection(milvus_vec_index.collection_name, using=MILVUS_ALIAS)
-    ids = coll.query(expr="id >= 0", output_fields=["id"], timeout=30)
-    flat_ids = [i["id"] for i in ids]
-    assert len(flat_ids) == len(set(flat_ids))
+    await vector_index.add_chunks(sample_chunks, embeddings)
+    resp = await vector_index.query_vector(
+        np.random.rand(embedding_dimension).astype(np.float32),
+        k=len(sample_chunks),
+        score_threshold=-1,
+    )
+    contents = [chunk.content for chunk in resp.chunks]
+    assert len(contents) == len(set(contents))


 @pytest.mark.asyncio
-async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_config):
-    kvstore = await kvstore_impl(unique_kvstore_config)
-    vector_db = VectorDB(
-        identifier="test_db",
-        provider_id="test_provider",
-        embedding_model="test_model",
-        embedding_dimension=128,
-        metadata={"test_key": "test_value"},
-    )
-    test_vector_db_data = vector_db.model_dump_json()
-    await kvstore.set(f"{VECTOR_DBS_PREFIX}test_db", test_vector_db_data)
-
-    tmp_milvus_vec_adapter = MilvusVectorIOAdapter(
-        config=MilvusVectorIOConfig(
-            db_path=milvus_vec_index.db_path,
-            kvstore=unique_kvstore_config,
-        ),
-        inference_api=None,
-        files_api=None,
-    )
-    await tmp_milvus_vec_adapter.initialize()
-
-    vector_db = VectorDB(
-        identifier="test_db",
-        provider_id="test_provider",
-        embedding_model="test_model",
-        embedding_dimension=128,
-    )
-    test_vector_db_data = vector_db.model_dump_json()
-    await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data)
-
-    assert milvus_vec_index.client is not None
-    assert isinstance(milvus_vec_index.client, MilvusClient)
-    assert tmp_milvus_vec_adapter.cache is not None
-    # registering a vector won't update the cache or openai_vector_store collection name
-    assert (
-        tmp_milvus_vec_adapter.metadata_collection_name not in tmp_milvus_vec_adapter.cache
-        or tmp_milvus_vec_adapter.openai_vector_stores
-    )
-
-
-@pytest.mark.asyncio
-async def test_persistence_across_adapter_restarts(
-    tmp_path, milvus_vec_index, mock_inference_api, unique_kvstore_config
-):
-    adapter1 = MilvusVectorIOAdapter(
-        config=MilvusVectorIOConfig(db_path=milvus_vec_index.db_path, kvstore=unique_kvstore_config),
-        inference_api=mock_inference_api,
-        files_api=None,
-    )
-    await adapter1.initialize()
+async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
+    key = f"{VECTOR_DBS_PREFIX}db1"
     dummy = VectorDB(
         identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
     )
-    await adapter1.register_vector_db(dummy)
-    await adapter1.shutdown()
-
-    await adapter1.initialize()
-    assert "foo_db" in adapter1.cache
-    await adapter1.shutdown()
+    await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))
+
+    await vector_io_adapter.initialize()


 @pytest.mark.asyncio
-async def test_register_and_unregister_vector_db(milvus_vec_adapter):
-    try:
-        connections.disconnect(MILVUS_ALIAS)
-    except Exception as _:
-        pass
-
-    connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_adapter.config.db_path)
-
+async def test_persistence_across_adapter_restarts(vector_io_adapter):
+    await vector_io_adapter.initialize()
+    dummy = VectorDB(
+        identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
+    )
+    await vector_io_adapter.register_vector_db(dummy)
+    await vector_io_adapter.shutdown()
+
+    await vector_io_adapter.initialize()
+    assert "foo_db" in vector_io_adapter.cache
+    await vector_io_adapter.shutdown()
+
+
+@pytest.mark.asyncio
+async def test_register_and_unregister_vector_db(vector_io_adapter):
     unique_id = f"foo_db_{np.random.randint(1e6)}"
     dummy = VectorDB(
         identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
     )
-    await milvus_vec_adapter.register_vector_db(dummy)
-    assert dummy.identifier in milvus_vec_adapter.cache
-
-    if dummy.identifier in milvus_vec_adapter.cache:
-        index = milvus_vec_adapter.cache[dummy.identifier].index
-        if hasattr(index, "client") and hasattr(index.client, "_using"):
-            index.client._using = MILVUS_ALIAS
-
-    await milvus_vec_adapter.unregister_vector_db(dummy.identifier)
-    assert dummy.identifier not in milvus_vec_adapter.cache
+    await vector_io_adapter.register_vector_db(dummy)
+    assert dummy.identifier in vector_io_adapter.cache
+    await vector_io_adapter.unregister_vector_db(dummy.identifier)
+    assert dummy.identifier not in vector_io_adapter.cache


 @pytest.mark.asyncio
-async def test_query_unregistered_raises(milvus_vec_adapter):
+async def test_query_unregistered_raises(vector_io_adapter):
     fake_emb = np.zeros(8, dtype=np.float32)
-    with pytest.raises(AttributeError):
-        await milvus_vec_adapter.query_chunks("no_such_db", fake_emb)
+    with pytest.raises(ValueError):
+        await vector_io_adapter.query_chunks("no_such_db", fake_emb)


 @pytest.mark.asyncio
-async def test_insert_chunks_calls_underlying_index(milvus_vec_adapter):
+async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
     fake_index = AsyncMock()
-    milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
+    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)

     chunks = ["chunk1", "chunk2"]
-    await milvus_vec_adapter.insert_chunks("db1", chunks)
+    await vector_io_adapter.insert_chunks("db1", chunks)

     fake_index.insert_chunks.assert_awaited_once_with(chunks)


 @pytest.mark.asyncio
-async def test_insert_chunks_missing_db_raises(milvus_vec_adapter):
-    milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
+async def test_insert_chunks_missing_db_raises(vector_io_adapter):
+    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)

     with pytest.raises(ValueError):
-        await milvus_vec_adapter.insert_chunks("db_not_exist", [])
+        await vector_io_adapter.insert_chunks("db_not_exist", [])


 @pytest.mark.asyncio
-async def test_query_chunks_calls_underlying_index_and_returns(milvus_vec_adapter):
+async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
     expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
     fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
-    milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
+    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)

-    response = await milvus_vec_adapter.query_chunks("db1", "my_query", {"param": 1})
+    response = await vector_io_adapter.query_chunks("db1", "my_query", {"param": 1})

     fake_index.query_chunks.assert_awaited_once_with("my_query", {"param": 1})
     assert response is expected


 @pytest.mark.asyncio
-async def test_query_chunks_missing_db_raises(milvus_vec_adapter):
-    milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
+async def test_query_chunks_missing_db_raises(vector_io_adapter):
+    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)

     with pytest.raises(ValueError):
-        await milvus_vec_adapter.query_chunks("db_missing", "q", None)
+        await vector_io_adapter.query_chunks("db_missing", "q", None)


 @pytest.mark.asyncio
-async def test_save_openai_vector_store(milvus_vec_adapter):
+async def test_save_openai_vector_store(vector_io_adapter):
     store_id = "vs_1234"
     openai_vector_store = {
         "id": store_id,
@@ -260,14 +149,14 @@ async def test_save_openai_vector_store(milvus_vec_adapter):
         "embedding_model": "test_model",
     }

-    await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
+    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)

-    assert openai_vector_store["id"] in milvus_vec_adapter.openai_vector_stores
-    assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
+    assert openai_vector_store["id"] in vector_io_adapter.openai_vector_stores
+    assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store


 @pytest.mark.asyncio
-async def test_update_openai_vector_store(milvus_vec_adapter):
+async def test_update_openai_vector_store(vector_io_adapter):
     store_id = "vs_1234"
     openai_vector_store = {
         "id": store_id,
@@ -277,14 +166,14 @@ async def test_update_openai_vector_store(milvus_vec_adapter):
         "embedding_model": "test_model",
     }

-    await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
+    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
     openai_vector_store["description"] = "Updated description"
-    await milvus_vec_adapter._update_openai_vector_store(store_id, openai_vector_store)
-    assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
+    await vector_io_adapter._update_openai_vector_store(store_id, openai_vector_store)
+    assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store


 @pytest.mark.asyncio
-async def test_delete_openai_vector_store(milvus_vec_adapter):
+async def test_delete_openai_vector_store(vector_io_adapter):
     store_id = "vs_1234"
     openai_vector_store = {
         "id": store_id,
@@ -294,13 +183,13 @@ async def test_delete_openai_vector_store(milvus_vec_adapter):
         "embedding_model": "test_model",
     }

-    await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
-    await milvus_vec_adapter._delete_openai_vector_store_from_storage(store_id)
+    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
+    await vector_io_adapter._delete_openai_vector_store_from_storage(store_id)

-    assert openai_vector_store["id"] not in milvus_vec_adapter.openai_vector_stores
+    assert openai_vector_store["id"] not in vector_io_adapter.openai_vector_stores


 @pytest.mark.asyncio
-async def test_load_openai_vector_stores(milvus_vec_adapter):
+async def test_load_openai_vector_stores(vector_io_adapter):
     store_id = "vs_1234"
     openai_vector_store = {
         "id": store_id,
@@ -310,13 +199,13 @@ async def test_load_openai_vector_stores(milvus_vec_adapter):
         "embedding_model": "test_model",
     }

-    await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
-    loaded_stores = await milvus_vec_adapter._load_openai_vector_stores()
+    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
+    loaded_stores = await vector_io_adapter._load_openai_vector_stores()

     assert loaded_stores[store_id] == openai_vector_store


 @pytest.mark.asyncio
-async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
+async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
     store_id = "vs_1234"
     file_id = "file_1234"
@@ -334,11 +223,11 @@ async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factor
     ]

     # validating we don't raise an exception
-    await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
+    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)


 @pytest.mark.asyncio
-async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
+async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
     store_id = "vs_1234"
     file_id = "file_1234"
@@ -355,24 +244,24 @@ async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_fact
         {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
     ]

-    await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
+    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)

     updated_file_info = file_info.copy()
     updated_file_info["filename"] = "updated_test_file.txt"

-    await milvus_vec_adapter._update_openai_vector_store_file(
+    await vector_io_adapter._update_openai_vector_store_file(
         store_id,
         file_id,
         updated_file_info,
     )

-    loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file(store_id, file_id)
+    loaded_contents = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id)
     assert loaded_contents == updated_file_info
     assert loaded_contents != file_info


 @pytest.mark.asyncio
-async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_path_factory):
+async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_path_factory):
     store_id = "vs_1234"
     file_id = "file_1234"
@@ -389,14 +278,14 @@ async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_pa
         {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
     ]

-    await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
+    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)

-    loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id)
+    loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
     assert loaded_contents == file_contents


 @pytest.mark.asyncio
-async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter, tmp_path_factory):
+async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, tmp_path_factory):
     store_id = "vs_1234"
     file_id = "file_1234"
@@ -413,8 +302,8 @@ async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter,
         {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
     ]

-    await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
-    await milvus_vec_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)
+    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
+    await vector_io_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)

-    loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id)
+    loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
     assert loaded_contents == []