From 6a6b66ae4f965de4cd3cd71a4320e868fa777b95 Mon Sep 17 00:00:00 2001
From: Francisco Arceo
Date: Thu, 10 Jul 2025 14:22:13 -0400
Subject: [PATCH] chore: Adding unit tests for OpenAI vector stores and
 migrating SQLite-vec registry to kvstore (#2665)

# What does this PR do?
This PR refactors the VectorIO backend logic for `sqlite-vec` and adds unit
tests and fixtures to make it easy to test both `sqlite-vec` and `milvus`.

Key changes:
- `sqlite-vec` migrated to `kvstore` registry
- added in-memory cache for `sqlite-vec` to be consistent with `milvus`
- default fixtures moved to `conftest.py`
- removed redundant tests from `sqlite-vec`
- made `test_vector_io_openai_vector_stores.py` more easily extensible

## Test Plan
Unit tests added testing inline providers.

---------

Signed-off-by: Francisco Javier Arceo
---
 .../providers/vector_io/inline_milvus.md      |   2 +-
 .../providers/vector_io/inline_sqlite-vec.md  |   6 +-
 .../providers/vector_io/inline_sqlite_vec.md  |   6 +-
 .../inline/vector_io/milvus/config.py         |   2 +-
 .../inline/vector_io/sqlite_vec/config.py     |  14 +-
 .../inline/vector_io/sqlite_vec/sqlite_vec.py | 313 +++++++-----------
 .../remote/vector_io/milvus/milvus.py         |   8 +
 llama_stack/templates/open-benchmark/run.yaml |   3 +
 llama_stack/templates/starter/run.yaml        |   3 +
 tests/unit/providers/vector_io/conftest.py    | 157 +++++++++
 .../providers/vector_io/test_sqlite_vec.py    |  35 +-
 .../test_vector_io_openai_vector_stores.py    | 297 ++++++-----------
 12 files changed, 422 insertions(+), 424 deletions(-)

diff --git a/docs/source/providers/vector_io/inline_milvus.md b/docs/source/providers/vector_io/inline_milvus.md
index be7340c9d..3b3aad3fc 100644
--- a/docs/source/providers/vector_io/inline_milvus.md
+++ b/docs/source/providers/vector_io/inline_milvus.md
@@ -11,7 +11,7 @@ Please refer to the remote provider documentation.
| Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `db_path` | `` | No | PydanticUndefined | | -| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) | | `consistency_level` | `` | No | Strong | The consistency level of the Milvus server | ## Sample Configuration diff --git a/docs/source/providers/vector_io/inline_sqlite-vec.md b/docs/source/providers/vector_io/inline_sqlite-vec.md index fd3ec1dc4..ae7c45b21 100644 --- a/docs/source/providers/vector_io/inline_sqlite-vec.md +++ b/docs/source/providers/vector_io/inline_sqlite-vec.md @@ -205,12 +205,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `db_path` | `` | No | PydanticUndefined | | +| `db_path` | `` | No | PydanticUndefined | Path to the SQLite database file | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) | ## Sample Configuration ```yaml db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db ``` diff --git a/docs/source/providers/vector_io/inline_sqlite_vec.md b/docs/source/providers/vector_io/inline_sqlite_vec.md index e4b69c9ab..7e14bb8bd 100644 --- a/docs/source/providers/vector_io/inline_sqlite_vec.md +++ b/docs/source/providers/vector_io/inline_sqlite_vec.md @@ -10,12 +10,16 @@ Please refer to the sqlite-vec provider documentation. 
| Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `db_path` | `` | No | PydanticUndefined | | +| `db_path` | `` | No | PydanticUndefined | Path to the SQLite database file | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) | ## Sample Configuration ```yaml db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db ``` diff --git a/llama_stack/providers/inline/vector_io/milvus/config.py b/llama_stack/providers/inline/vector_io/milvus/config.py index a05ca1670..8cbd056be 100644 --- a/llama_stack/providers/inline/vector_io/milvus/config.py +++ b/llama_stack/providers/inline/vector_io/milvus/config.py @@ -18,7 +18,7 @@ from llama_stack.schema_utils import json_schema_type @json_schema_type class MilvusVectorIOConfig(BaseModel): db_path: str - kvstore: KVStoreConfig + kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)") consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong") @classmethod diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py index 4c57f4aba..525ed4b1f 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py @@ -6,14 +6,24 @@ from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) class SQLiteVectorIOConfig(BaseModel): - db_path: str + db_path: str = Field(description="Path to the SQLite database file") + kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)") @classmethod def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: return { "db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db", + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="sqlite_vec_registry.db", + ), } diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 7e977635a..6acd85c56 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -24,6 +24,8 @@ from llama_stack.apis.vector_io import ( VectorIO, ) from llama_stack.providers.datatypes import VectorDBsProtocolPrivate +from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( RERANKER_TYPE_RRF, @@ -40,6 +42,13 @@ KEYWORD_SEARCH = "keyword" HYBRID_SEARCH = "hybrid" SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH} +VERSION = "v3" +VECTOR_DBS_PREFIX = f"vector_dbs:sqlite_vec:{VERSION}::" +VECTOR_INDEX_PREFIX = f"vector_index:sqlite_vec:{VERSION}::" +OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:sqlite_vec:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_PREFIX = 
f"openai_vector_stores_files:sqlite_vec:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:sqlite_vec:{VERSION}::" + def serialize_vector(vector: list[float]) -> bytes: """Serialize a list of floats into a compact binary representation.""" @@ -117,13 +126,14 @@ class SQLiteVecIndex(EmbeddingIndex): - An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search. """ - def __init__(self, dimension: int, db_path: str, bank_id: str): + def __init__(self, dimension: int, db_path: str, bank_id: str, kvstore: KVStore | None = None): self.dimension = dimension self.db_path = db_path self.bank_id = bank_id self.metadata_table = f"chunks_{bank_id}".replace("-", "_") self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_") self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_") + self.kvstore = kvstore @classmethod async def create(cls, dimension: int, db_path: str, bank_id: str): @@ -425,27 +435,116 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc self.files_api = files_api self.cache: dict[str, VectorDBWithIndex] = {} self.openai_vector_stores: dict[str, dict[str, Any]] = {} + self.kvstore: KVStore | None = None async def initialize(self) -> None: - def _setup_connection(): - # Open a connection to the SQLite database (the file is specified in the config). + self.kvstore = await kvstore_impl(self.config.kvstore) + + start_key = VECTOR_DBS_PREFIX + end_key = f"{VECTOR_DBS_PREFIX}\xff" + stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) + for db_json in stored_vector_dbs: + vector_db = VectorDB.model_validate_json(db_json) + index = await SQLiteVecIndex.create( + vector_db.embedding_dimension, + self.config.db_path, + vector_db.identifier, + ) + self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) + + # load any existing OpenAI vector stores + self.openai_vector_stores = await self._load_openai_vector_stores() + + async def shutdown(self) -> None: + # nothing to do since we don't maintain a persistent connection + pass + + async def list_vector_dbs(self) -> list[VectorDB]: + return [v.vector_db for v in self.cache.values()] + + async def register_vector_db(self, vector_db: VectorDB) -> None: + index = await SQLiteVecIndex.create( + vector_db.embedding_dimension, + self.config.db_path, + vector_db.identifier, + ) + self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) + + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: + if vector_db_id in self.cache: + return self.cache[vector_db_id] + + if self.vector_db_store is None: + raise ValueError(f"Vector DB {vector_db_id} not found") + + vector_db = self.vector_db_store.get_vector_db(vector_db_id) + if not vector_db: + raise ValueError(f"Vector DB {vector_db_id} not found") + + index = VectorDBWithIndex( + vector_db=vector_db, + index=SQLiteVecIndex( + dimension=vector_db.embedding_dimension, + db_path=self.config.db_path, + bank_id=vector_db.identifier, + kvstore=self.kvstore, + ), + inference_api=self.inference_api, + ) + self.cache[vector_db_id] = index + return index + + async def unregister_vector_db(self, vector_db_id: str) -> None: + if vector_db_id not in self.cache: + logger.warning(f"Vector DB {vector_db_id} not found") + return + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] + + # OpenAI Vector Store Mixin abstract method implementations + async def 
_save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Save vector store metadata to SQLite database.""" + assert self.kvstore is not None + key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" + await self.kvstore.set(key=key, value=json.dumps(store_info)) + self.openai_vector_stores[store_id] = store_info + + async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: + """Load all vector store metadata from SQLite database.""" + assert self.kvstore is not None + start_key = OPENAI_VECTOR_STORES_PREFIX + end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff" + stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key) + stores = {} + for store_data in stored_openai_stores: + store_info = json.loads(store_data) + stores[store_info["id"]] = store_info + return stores + + async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Update vector store metadata in SQLite database.""" + assert self.kvstore is not None + key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" + await self.kvstore.set(key=key, value=json.dumps(store_info)) + self.openai_vector_stores[store_id] = store_info + + async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: + """Delete vector store metadata from SQLite database.""" + assert self.kvstore is not None + key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" + await self.kvstore.delete(key) + if store_id in self.openai_vector_stores: + del self.openai_vector_stores[store_id] + + async def _save_openai_vector_store_file( + self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]] + ) -> None: + """Save vector store file metadata to SQLite database.""" + + def _create_or_store(): connection = _create_sqlite_connection(self.config.db_path) cur = connection.cursor() try: - # Create a table to persist vector DB registrations. - cur.execute(""" - CREATE TABLE IF NOT EXISTS vector_dbs ( - id TEXT PRIMARY KEY, - metadata TEXT - ); - """) - # Create a table to persist OpenAI vector stores. - cur.execute(""" - CREATE TABLE IF NOT EXISTS openai_vector_stores ( - id TEXT PRIMARY KEY, - metadata TEXT - ); - """) # Create a table to persist OpenAI vector store files. cur.execute(""" CREATE TABLE IF NOT EXISTS openai_vector_store_files ( @@ -464,168 +563,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc ); """) connection.commit() - # Load any existing vector DB registrations. 
- cur.execute("SELECT metadata FROM vector_dbs") - vector_db_rows = cur.fetchall() - return vector_db_rows - finally: - cur.close() - connection.close() - - vector_db_rows = await asyncio.to_thread(_setup_connection) - - # Load existing vector DBs - for row in vector_db_rows: - vector_db_data = row[0] - vector_db = VectorDB.model_validate_json(vector_db_data) - index = await SQLiteVecIndex.create( - vector_db.embedding_dimension, - self.config.db_path, - vector_db.identifier, - ) - self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) - - # Load existing OpenAI vector stores using the mixin method - self.openai_vector_stores = await self._load_openai_vector_stores() - - async def shutdown(self) -> None: - # nothing to do since we don't maintain a persistent connection - pass - - async def register_vector_db(self, vector_db: VectorDB) -> None: - def _register_db(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute( - "INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)", - (vector_db.identifier, vector_db.model_dump_json()), - ) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_register_db) - index = await SQLiteVecIndex.create( - vector_db.embedding_dimension, - self.config.db_path, - vector_db.identifier, - ) - self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) - - async def list_vector_dbs(self) -> list[VectorDB]: - return [v.vector_db for v in self.cache.values()] - - async def unregister_vector_db(self, vector_db_id: str) -> None: - if vector_db_id not in self.cache: - logger.warning(f"Vector DB {vector_db_id} not found") - return - await self.cache[vector_db_id].index.delete() - del self.cache[vector_db_id] - - def _delete_vector_db_from_registry(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,)) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_delete_vector_db_from_registry) - - # OpenAI Vector Store Mixin abstract method implementations - async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: - """Save vector store metadata to SQLite database.""" - - def _store(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute( - "INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)", - (store_id, json.dumps(store_info)), - ) - connection.commit() - except Exception as e: - logger.error(f"Error saving openai vector store {store_id}: {e}") - raise - finally: - cur.close() - connection.close() - - try: - await asyncio.to_thread(_store) - except Exception as e: - logger.error(f"Error saving openai vector store {store_id}: {e}") - raise - - async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: - """Load all vector store metadata from SQLite database.""" - - def _load(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute("SELECT metadata FROM openai_vector_stores") - rows = cur.fetchall() - return rows - finally: - cur.close() - connection.close() - - rows = await asyncio.to_thread(_load) - stores = {} - for row in rows: - store_data = row[0] - store_info = json.loads(store_data) - stores[store_info["id"]] = store_info - return stores - - async def 
_update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: - """Update vector store metadata in SQLite database.""" - - def _update(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute( - "UPDATE openai_vector_stores SET metadata = ? WHERE id = ?", - (json.dumps(store_info), store_id), - ) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_update) - - async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: - """Delete vector store metadata from SQLite database.""" - - def _delete(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,)) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_delete) - - async def _save_openai_vector_store_file( - self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]] - ) -> None: - """Save vector store file metadata to SQLite database.""" - - def _store(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: cur.execute( "INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)", (store_id, file_id, json.dumps(file_info)), @@ -643,7 +580,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc connection.close() try: - await asyncio.to_thread(_store) + await asyncio.to_thread(_create_or_store) except Exception as e: logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}") raise @@ -722,6 +659,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc cur.execute( "DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id) ) + cur.execute( + "DELETE FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?", + (store_id, file_id), + ) connection.commit() finally: cur.close() @@ -730,15 +671,17 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc await asyncio.to_thread(_delete) async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: - if vector_db_id not in self.cache: - raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}") + index = await self._get_and_cache_vector_db_index(vector_db_id) + if not index: + raise ValueError(f"Vector DB {vector_db_id} not found") # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api # and then call our index's add_chunks. 
- await self.cache[vector_db_id].insert_chunks(chunks) + await index.insert_chunks(chunks) async def query_chunks( self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None ) -> QueryChunksResponse: - if vector_db_id not in self.cache: + index = await self._get_and_cache_vector_db_index(vector_db_id) + if not index: raise ValueError(f"Vector DB {vector_db_id} not found") - return await self.cache[vector_db_id].query_chunks(query, params) + return await index.query_chunks(query, params) diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 1f65e580e..a06130fd0 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -61,6 +61,11 @@ class MilvusIndex(EmbeddingIndex): self.consistency_level = consistency_level self.kvstore = kvstore + async def initialize(self): + # MilvusIndex does not require explicit initialization + # TODO: could move collection creation into initialization but it is not really necessary + pass + async def delete(self): if await asyncio.to_thread(self.client.has_collection, self.collection_name): await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name) @@ -199,6 +204,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP if vector_db_id in self.cache: return self.cache[vector_db_id] + if self.vector_db_store is None: + raise ValueError(f"Vector DB {vector_db_id} not found") + vector_db = await self.vector_db_store.get_vector_db(vector_db_id) if not vector_db: raise ValueError(f"Vector DB {vector_db_id} not found") diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 76c029864..0b368ebc9 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -39,6 +39,9 @@ providers: provider_type: inline::sqlite-vec config: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec.db + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec_registry.db - provider_id: ${env.ENABLE_CHROMADB:+chromadb} provider_type: remote::chromadb config: diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index b3dfe32d5..888a2c3bf 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -144,6 +144,9 @@ providers: provider_type: inline::sqlite-vec config: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db - provider_id: ${env.ENABLE_MILVUS:=__disabled__} provider_type: inline::milvus config: diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 5eaca8a25..4a9639326 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -8,10 +8,18 @@ import random import numpy as np import pytest +from pymilvus import MilvusClient, connections +from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, ChunkMetadata +from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig +from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig +from 
llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
+from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter

 EMBEDDING_DIMENSION = 384
+COLLECTION_PREFIX = "test_collection"
+MILVUS_ALIAS = "test_milvus"


 @pytest.fixture
@@ -50,7 +58,156 @@ def sample_chunks():
     return sample


+@pytest.fixture(scope="session")
+def sample_chunks_with_metadata():
+    """Generates chunks with document-level metadata and explicit per-chunk ChunkMetadata populated."""
+    n, k = 10, 3
+    sample = [
+        Chunk(
+            content=f"Sentence {i} from document {j}",
+            metadata={"document_id": f"document-{j}"},
+            chunk_metadata=ChunkMetadata(
+                document_id=f"document-{j}",
+                chunk_id=f"document-{j}-chunk-{i}",
+                source=f"example source-{j}-{i}",
+            ),
+        )
+        for j in range(k)
+        for i in range(n)
+    ]
+    return sample
+
+
 @pytest.fixture(scope="session")
 def sample_embeddings(sample_chunks):
     np.random.seed(42)
     return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
+
+
+@pytest.fixture(scope="session")
+def sample_embeddings_with_metadata(sample_chunks_with_metadata):
+    np.random.seed(42)
+    return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])
+
+
+@pytest.fixture(params=["milvus", "sqlite_vec"])
+def vector_provider(request):
+    return request.param
+
+
+@pytest.fixture(scope="session")
+def mock_inference_api(embedding_dimension):
+    class MockInferenceAPI:
+        async def embed_batch(self, texts: list[str]) -> list[list[float]]:
+            return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts]
+
+    return MockInferenceAPI()
+
+
+@pytest.fixture
+async def unique_kvstore_config(tmp_path_factory):
+    # Generate a unique filename for this test
+    unique_id = f"test_kv_{np.random.randint(1e6)}"
+    temp_dir = tmp_path_factory.getbasetemp()
+    db_path = str(temp_dir / f"{unique_id}.db")
+
+    return SqliteKVStoreConfig(db_path=db_path)
+
+
+@pytest.fixture(scope="session")
+def sqlite_vec_db_path(tmp_path_factory):
+    db_path = str(tmp_path_factory.getbasetemp() / "test.db")
+    return db_path
+
+
+@pytest.fixture
+async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
+    temp_dir = tmp_path_factory.getbasetemp()
+    db_path = str(temp_dir / f"test_sqlite_vec_{np.random.randint(1e6)}.db")
+    bank_id = f"sqlite_vec_bank_{np.random.randint(1e6)}"
+    index = SQLiteVecIndex(embedding_dimension, db_path, bank_id)
+    await index.initialize()
+    index.db_path = db_path
+    yield index
+    await index.delete()
+
+
+@pytest.fixture
+async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
+    config = SQLiteVectorIOConfig(
+        db_path=sqlite_vec_db_path,
+        kvstore=SqliteKVStoreConfig(),
+    )
+    adapter = SQLiteVecVectorIOAdapter(
+        config=config,
+        inference_api=mock_inference_api,
+        files_api=None,
+    )
+    collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
+    await adapter.initialize()
+    await adapter.register_vector_db(
+        VectorDB(
+            identifier=collection_id,
+            provider_id="test_provider",
+            embedding_model="test_model",
+            embedding_dimension=embedding_dimension,
+        )
+    )
+    adapter.test_collection_id = collection_id
+    yield adapter
+    await adapter.shutdown()
+
+
+@pytest.fixture(scope="session")
+def milvus_vec_db_path(tmp_path_factory):
+    db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
+    return db_path
+
+
+@pytest.fixture
+async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
+ client = MilvusClient(milvus_vec_db_path) + name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}" + connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path) + index = MilvusIndex(client, name, consistency_level="Strong") + index.db_path = milvus_vec_db_path + yield index + + +@pytest.fixture +async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api): + config = MilvusVectorIOConfig( + db_path=milvus_vec_db_path, + kvstore=SqliteKVStoreConfig(), + ) + adapter = MilvusVectorIOAdapter( + config=config, + inference_api=mock_inference_api, + files_api=None, + ) + await adapter.initialize() + await adapter.register_vector_db( + VectorDB( + identifier=adapter.metadata_collection_name, + provider_id="test_provider", + embedding_model="test_model", + embedding_dimension=128, + ) + ) + yield adapter + await adapter.shutdown() + + +@pytest.fixture +def vector_io_adapter(vector_provider, request): + """Returns the appropriate vector IO adapter based on the provider parameter.""" + if vector_provider == "milvus": + return request.getfixturevalue("milvus_vec_adapter") + else: + return request.getfixturevalue("sqlite_vec_adapter") + + +@pytest.fixture +def vector_index(vector_provider, request): + """Returns appropriate vector index based on provider parameter""" + return request.getfixturevalue(f"{vector_provider}_vec_index") diff --git a/tests/unit/providers/vector_io/test_sqlite_vec.py b/tests/unit/providers/vector_io/test_sqlite_vec.py index 5d9d92cf3..8579c31bb 100644 --- a/tests/unit/providers/vector_io/test_sqlite_vec.py +++ b/tests/unit/providers/vector_io/test_sqlite_vec.py @@ -34,7 +34,7 @@ def loop(): return asyncio.new_event_loop() -@pytest_asyncio.fixture(scope="session", autouse=True) +@pytest_asyncio.fixture async def sqlite_vec_index(embedding_dimension, tmp_path_factory): temp_dir = tmp_path_factory.getbasetemp() db_path = str(temp_dir / "test_sqlite.db") @@ -44,38 +44,15 @@ async def sqlite_vec_index(embedding_dimension, tmp_path_factory): @pytest.mark.asyncio -async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings): - await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2) - connection = _create_sqlite_connection(sqlite_vec_index.db_path) - cur = connection.cursor() - cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}") - count = cur.fetchone()[0] - assert count == len(sample_chunks) - cur.close() - connection.close() - - -@pytest.mark.asyncio -async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension): - await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) - query_embedding = np.random.rand(embedding_dimension).astype(np.float32) - response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0) - assert isinstance(response, QueryChunksResponse) - assert len(response.chunks) == 2 - - -@pytest.mark.xfail(reason="Chunk Metadata not yet supported for SQLite-vec", strict=True) -async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks, sample_embeddings): - await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) - query_embedding = sample_embeddings[0] - response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0) - assert response.chunks[-1].chunk_metadata == sample_chunks[-1].chunk_metadata +async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata): + await 
sqlite_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata) + response = await sqlite_vec_index.query_vector(sample_embeddings_with_metadata[-1], k=2, score_threshold=0.0) + assert response.chunks[0].chunk_metadata == sample_chunks_with_metadata[-1].chunk_metadata @pytest.mark.asyncio async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings): await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) - query_string = "Sentence 5" response = await sqlite_vec_index.query_keyword(k=3, score_threshold=0.0, query_string=query_string) @@ -148,7 +125,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!" -@pytest_asyncio.fixture(scope="session") +@pytest.fixture(scope="session") async def sqlite_vec_adapter(sqlite_connection): config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None) diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index 0a109e833..5f7926ce6 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -4,253 +4,142 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import asyncio +import json import time from unittest.mock import AsyncMock import numpy as np import pytest -import pytest_asyncio -from pymilvus import Collection, MilvusClient, connections from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse -from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig -from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX, MilvusIndex, MilvusVectorIOAdapter -from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX -# TODO: Refactor these to be for inline vector-io providers -MILVUS_ALIAS = "test_milvus" -COLLECTION_PREFIX = "test_collection" - - -@pytest.fixture(scope="session") -def loop(): - return asyncio.new_event_loop() - - -@pytest.fixture(scope="session") -def mock_inference_api(embedding_dimension): - class MockInferenceAPI: - async def embed_batch(self, texts: list[str]) -> list[list[float]]: - return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts] - - return MockInferenceAPI() - - -@pytest_asyncio.fixture -async def unique_kvstore_config(tmp_path_factory): - # Generate a unique filename for this test - unique_id = f"test_kv_{np.random.randint(1e6)}" - temp_dir = tmp_path_factory.getbasetemp() - db_path = str(temp_dir / f"{unique_id}.db") - - return SqliteKVStoreConfig(db_path=db_path) - - -@pytest_asyncio.fixture(scope="session", autouse=True) -async def milvus_vec_index(embedding_dimension, tmp_path_factory): - temp_dir = tmp_path_factory.getbasetemp() - db_path = str(temp_dir / "test_milvus.db") - client = MilvusClient(db_path) - name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}" - connections.connect(alias=MILVUS_ALIAS, uri=db_path) - index = MilvusIndex(client, name, consistency_level="Strong") - index.db_path = db_path - yield 
index
-
-
-@pytest_asyncio.fixture(scope="session")
-async def milvus_vec_adapter(milvus_vec_index, mock_inference_api):
-    config = MilvusVectorIOConfig(
-        db_path=milvus_vec_index.db_path,
-        kvstore=SqliteKVStoreConfig(),
-    )
-    adapter = MilvusVectorIOAdapter(
-        config=config,
-        inference_api=mock_inference_api,
-        files_api=None,
-    )
-    await adapter.initialize()
-    await adapter.register_vector_db(
-        VectorDB(
-            identifier=adapter.metadata_collection_name,
-            provider_id="test_provider",
-            embedding_model="test_model",
-            embedding_dimension=128,
-        )
-    )
-    yield adapter
-    await adapter.shutdown()
+# This test is a unit test for the inline VectorIO providers. This should only contain
+# tests which are specific to this class. More general (API-level) tests should be placed in
+# tests/integration/vector_io/
+#
+# How to run this test:
+#
+# pytest tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py \
+# -v -s --tb=short --disable-warnings --asyncio-mode=auto


 @pytest.mark.asyncio
-async def test_cache_contains_initial_collection(milvus_vec_adapter):
-    coll_name = milvus_vec_adapter.metadata_collection_name
-    assert coll_name in milvus_vec_adapter.cache
+async def test_initialize_index(vector_index):
+    await vector_index.initialize()


 @pytest.mark.asyncio
-async def test_add_chunks(milvus_vec_index, sample_chunks, sample_embeddings):
-    await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
-    resp = await milvus_vec_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
+async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
+    await vector_index.delete()
+    await vector_index.initialize()
+    await vector_index.add_chunks(sample_chunks, sample_embeddings)
+    resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
     assert resp.chunks[0].content == sample_chunks[0].content
+    await vector_index.delete()


 @pytest.mark.asyncio
-async def test_query_chunks_vector(milvus_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
-    await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
-    query_emb = np.random.rand(embedding_dimension).astype(np.float32)
-    resp = await milvus_vec_index.query_vector(query_emb, k=2, score_threshold=0.0)
-    assert isinstance(resp, QueryChunksResponse)
-    assert len(resp.chunks) == 2
-
-
-@pytest.mark.asyncio
-async def test_chunk_id_conflict(milvus_vec_index, sample_chunks, embedding_dimension):
+async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
     embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
-    await milvus_vec_index.add_chunks(sample_chunks, embeddings)
-    coll = Collection(milvus_vec_index.collection_name, using=MILVUS_ALIAS)
-    ids = coll.query(expr="id >= 0", output_fields=["id"], timeout=30)
-    flat_ids = [i["id"] for i in ids]
-    assert len(flat_ids) == len(set(flat_ids))
+    await vector_index.add_chunks(sample_chunks, embeddings)
+    resp = await vector_index.query_vector(
+        np.random.rand(embedding_dimension).astype(np.float32),
+        k=len(sample_chunks),
+        score_threshold=-1,
+    )
+
+    contents = [chunk.content for chunk in resp.chunks]
+    assert len(contents) == len(set(contents))


 @pytest.mark.asyncio
-async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_config):
-    kvstore = await kvstore_impl(unique_kvstore_config)
-    vector_db = VectorDB(
-        identifier="test_db",
-        provider_id="test_provider",
-        embedding_model="test_model",
-        embedding_dimension=128,
-        metadata={"test_key": 
"test_value"}, - ) - test_vector_db_data = vector_db.model_dump_json() - await kvstore.set(f"{VECTOR_DBS_PREFIX}test_db", test_vector_db_data) - tmp_milvus_vec_adapter = MilvusVectorIOAdapter( - config=MilvusVectorIOConfig( - db_path=milvus_vec_index.db_path, - kvstore=unique_kvstore_config, - ), - inference_api=None, - files_api=None, - ) - await tmp_milvus_vec_adapter.initialize() - - vector_db = VectorDB( - identifier="test_db", - provider_id="test_provider", - embedding_model="test_model", - embedding_dimension=128, - ) - test_vector_db_data = vector_db.model_dump_json() - await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data) - - assert milvus_vec_index.client is not None - assert isinstance(milvus_vec_index.client, MilvusClient) - assert tmp_milvus_vec_adapter.cache is not None - # registering a vector won't update the cache or openai_vector_store collection name - assert ( - tmp_milvus_vec_adapter.metadata_collection_name not in tmp_milvus_vec_adapter.cache - or tmp_milvus_vec_adapter.openai_vector_stores - ) - - -@pytest.mark.asyncio -async def test_persistence_across_adapter_restarts( - tmp_path, milvus_vec_index, mock_inference_api, unique_kvstore_config -): - adapter1 = MilvusVectorIOAdapter( - config=MilvusVectorIOConfig(db_path=milvus_vec_index.db_path, kvstore=unique_kvstore_config), - inference_api=mock_inference_api, - files_api=None, - ) - await adapter1.initialize() +async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter): + key = f"{VECTOR_DBS_PREFIX}db1" dummy = VectorDB( identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128 ) - await adapter1.register_vector_db(dummy) - await adapter1.shutdown() + await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump())) - await adapter1.initialize() - assert "foo_db" in adapter1.cache - await adapter1.shutdown() + await vector_io_adapter.initialize() @pytest.mark.asyncio -async def test_register_and_unregister_vector_db(milvus_vec_adapter): - try: - connections.disconnect(MILVUS_ALIAS) - except Exception as _: - pass +async def test_persistence_across_adapter_restarts(vector_io_adapter): + await vector_io_adapter.initialize() + dummy = VectorDB( + identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128 + ) + await vector_io_adapter.register_vector_db(dummy) + await vector_io_adapter.shutdown() - connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_adapter.config.db_path) + await vector_io_adapter.initialize() + assert "foo_db" in vector_io_adapter.cache + await vector_io_adapter.shutdown() + + +@pytest.mark.asyncio +async def test_register_and_unregister_vector_db(vector_io_adapter): unique_id = f"foo_db_{np.random.randint(1e6)}" dummy = VectorDB( identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128 ) - await milvus_vec_adapter.register_vector_db(dummy) - assert dummy.identifier in milvus_vec_adapter.cache - - if dummy.identifier in milvus_vec_adapter.cache: - index = milvus_vec_adapter.cache[dummy.identifier].index - if hasattr(index, "client") and hasattr(index.client, "_using"): - index.client._using = MILVUS_ALIAS - - await milvus_vec_adapter.unregister_vector_db(dummy.identifier) - assert dummy.identifier not in milvus_vec_adapter.cache + await vector_io_adapter.register_vector_db(dummy) + assert dummy.identifier in vector_io_adapter.cache + await 
vector_io_adapter.unregister_vector_db(dummy.identifier) + assert dummy.identifier not in vector_io_adapter.cache @pytest.mark.asyncio -async def test_query_unregistered_raises(milvus_vec_adapter): +async def test_query_unregistered_raises(vector_io_adapter): fake_emb = np.zeros(8, dtype=np.float32) - with pytest.raises(AttributeError): - await milvus_vec_adapter.query_chunks("no_such_db", fake_emb) + with pytest.raises(ValueError): + await vector_io_adapter.query_chunks("no_such_db", fake_emb) @pytest.mark.asyncio -async def test_insert_chunks_calls_underlying_index(milvus_vec_adapter): +async def test_insert_chunks_calls_underlying_index(vector_io_adapter): fake_index = AsyncMock() - milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index) + vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index) chunks = ["chunk1", "chunk2"] - await milvus_vec_adapter.insert_chunks("db1", chunks) + await vector_io_adapter.insert_chunks("db1", chunks) fake_index.insert_chunks.assert_awaited_once_with(chunks) @pytest.mark.asyncio -async def test_insert_chunks_missing_db_raises(milvus_vec_adapter): - milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None) +async def test_insert_chunks_missing_db_raises(vector_io_adapter): + vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None) with pytest.raises(ValueError): - await milvus_vec_adapter.insert_chunks("db_not_exist", []) + await vector_io_adapter.insert_chunks("db_not_exist", []) @pytest.mark.asyncio -async def test_query_chunks_calls_underlying_index_and_returns(milvus_vec_adapter): +async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter): expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1]) fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected)) - milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index) + vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index) - response = await milvus_vec_adapter.query_chunks("db1", "my_query", {"param": 1}) + response = await vector_io_adapter.query_chunks("db1", "my_query", {"param": 1}) fake_index.query_chunks.assert_awaited_once_with("my_query", {"param": 1}) assert response is expected @pytest.mark.asyncio -async def test_query_chunks_missing_db_raises(milvus_vec_adapter): - milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None) +async def test_query_chunks_missing_db_raises(vector_io_adapter): + vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None) with pytest.raises(ValueError): - await milvus_vec_adapter.query_chunks("db_missing", "q", None) + await vector_io_adapter.query_chunks("db_missing", "q", None) @pytest.mark.asyncio -async def test_save_openai_vector_store(milvus_vec_adapter): +async def test_save_openai_vector_store(vector_io_adapter): store_id = "vs_1234" openai_vector_store = { "id": store_id, @@ -260,14 +149,14 @@ async def test_save_openai_vector_store(milvus_vec_adapter): "embedding_model": "test_model", } - await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store) + await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store) - assert openai_vector_store["id"] in milvus_vec_adapter.openai_vector_stores - assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store + assert openai_vector_store["id"] in vector_io_adapter.openai_vector_stores + 
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store @pytest.mark.asyncio -async def test_update_openai_vector_store(milvus_vec_adapter): +async def test_update_openai_vector_store(vector_io_adapter): store_id = "vs_1234" openai_vector_store = { "id": store_id, @@ -277,14 +166,14 @@ async def test_update_openai_vector_store(milvus_vec_adapter): "embedding_model": "test_model", } - await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store) + await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store) openai_vector_store["description"] = "Updated description" - await milvus_vec_adapter._update_openai_vector_store(store_id, openai_vector_store) - assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store + await vector_io_adapter._update_openai_vector_store(store_id, openai_vector_store) + assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store @pytest.mark.asyncio -async def test_delete_openai_vector_store(milvus_vec_adapter): +async def test_delete_openai_vector_store(vector_io_adapter): store_id = "vs_1234" openai_vector_store = { "id": store_id, @@ -294,13 +183,13 @@ async def test_delete_openai_vector_store(milvus_vec_adapter): "embedding_model": "test_model", } - await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store) - await milvus_vec_adapter._delete_openai_vector_store_from_storage(store_id) - assert openai_vector_store["id"] not in milvus_vec_adapter.openai_vector_stores + await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store) + await vector_io_adapter._delete_openai_vector_store_from_storage(store_id) + assert openai_vector_store["id"] not in vector_io_adapter.openai_vector_stores @pytest.mark.asyncio -async def test_load_openai_vector_stores(milvus_vec_adapter): +async def test_load_openai_vector_stores(vector_io_adapter): store_id = "vs_1234" openai_vector_store = { "id": store_id, @@ -310,13 +199,13 @@ async def test_load_openai_vector_stores(milvus_vec_adapter): "embedding_model": "test_model", } - await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store) - loaded_stores = await milvus_vec_adapter._load_openai_vector_stores() + await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store) + loaded_stores = await vector_io_adapter._load_openai_vector_stores() assert loaded_stores[store_id] == openai_vector_store @pytest.mark.asyncio -async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory): +async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory): store_id = "vs_1234" file_id = "file_1234" @@ -334,11 +223,11 @@ async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factor ] # validating we don't raise an exception - await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) + await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) @pytest.mark.asyncio -async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory): +async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_factory): store_id = "vs_1234" file_id = "file_1234" @@ -355,24 +244,24 @@ async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_fact {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}} 
] - await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) + await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) updated_file_info = file_info.copy() updated_file_info["filename"] = "updated_test_file.txt" - await milvus_vec_adapter._update_openai_vector_store_file( + await vector_io_adapter._update_openai_vector_store_file( store_id, file_id, updated_file_info, ) - loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file(store_id, file_id) + loaded_contents = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id) assert loaded_contents == updated_file_info assert loaded_contents != file_info @pytest.mark.asyncio -async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_path_factory): +async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_path_factory): store_id = "vs_1234" file_id = "file_1234" @@ -389,14 +278,14 @@ async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_pa {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}} ] - await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) + await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) - loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id) + loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id) assert loaded_contents == file_contents @pytest.mark.asyncio -async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter, tmp_path_factory): +async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, tmp_path_factory): store_id = "vs_1234" file_id = "file_1234" @@ -413,8 +302,8 @@ async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter, {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}} ] - await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) - await milvus_vec_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id) + await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) + await vector_io_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id) - loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id) + loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id) assert loaded_contents == []