From af7748a4d55fac47eef592d1f197c647b5a0865e Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Mon, 10 Feb 2025 16:16:55 -0500 Subject: [PATCH] feat: Adding sqlite-vec as vectordb Signed-off-by: Francisco Javier Arceo --- .../distribution/routers/routing_tables.py | 5 +- .../inline/vector_io/sqlite_vec/__init__.py | 18 ++ .../inline/vector_io/sqlite_vec/config.py | 28 +++ .../inline/vector_io/sqlite_vec/sqlite_vec.py | 205 ++++++++++++++++++ llama_stack/providers/registry/vector_io.py | 8 + .../remote/inference/ollama/ollama.py | 26 ++- .../providers/tests/vector_io/conftest.py | 8 + .../providers/tests/vector_io/fixtures.py | 25 ++- llama_stack/providers/utils/kvstore/config.py | 17 ++ llama_stack/templates/ollama/build.yaml | 1 + llama_stack/templates/ollama/ollama.py | 12 +- llama_stack/templates/ollama/run.yaml | 7 +- tests/client-sdk/conftest.py | 2 +- tests/client-sdk/vector_io/test_vector_io.py | 22 +- 14 files changed, 356 insertions(+), 28 deletions(-) create mode 100644 llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py create mode 100644 llama_stack/providers/inline/vector_io/sqlite_vec/config.py create mode 100644 llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 68fafd8ee..d7f6f6c4c 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -185,7 +185,9 @@ class CommonRoutingTableImpl(RoutingTable): obj.provider_id = list(self.impls_by_provider_id.keys())[0] if obj.provider_id not in self.impls_by_provider_id: - raise ValueError(f"Provider `{obj.provider_id}` not found") + raise ValueError( + f"Provider `{obj.provider_id}` not found \navailable providers: {self.impls_by_provider_id.keys()}" + ) p = self.impls_by_provider_id[obj.provider_id] @@ -335,6 +337,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): "embedding_model": embedding_model, "embedding_dimension": model.metadata["embedding_dimension"], } + print(f"Registering vector db {vector_db_data} with embedding model {embedding_model}") vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data) await self.register_object(vector_db) return vector_db diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py b/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py new file mode 100644 index 000000000..441e64fe8 --- /dev/null +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Dict +from llama_stack.providers.datatypes import Api, ProviderSpec +from .config import SQLiteVecImplConfig + + +async def get_provider_impl(config: SQLiteVecImplConfig, deps: Dict[Api, ProviderSpec]): + from .sqlite_vec import SQLiteVecVectorIOImpl + + assert isinstance(config, SQLiteVecImplConfig), f"Unexpected config type: {type(config)}" + impl = SQLiteVecVectorIOImpl(config, deps[Api.inference]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py new file mode 100644 index 000000000..204ecab2a --- /dev/null +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# config.py +from pydantic import BaseModel +from typing import Any, Dict + +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) + + +class SQLiteVecImplConfig(BaseModel): + db_path: str + kvstore: KVStoreConfig + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]: + return { + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="sqlite_vec.db", + ) + } diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py new file mode 100644 index 000000000..a543430da --- /dev/null +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# sqlite_vec_impl.py +import sqlite3 +import sqlite_vec +import struct +import logging +import numpy as np +from numpy.typing import NDArray +from typing import List, Optional, Dict, Any + +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO +from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate +from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex + +logger = logging.getLogger(__name__) + + +def serialize_vector(vector: List[float]) -> bytes: + """Serialize a list of floats into a compact binary representation.""" + return struct.pack(f"{len(vector)}f", *vector) + + +class SQLiteVecIndex(EmbeddingIndex): + """ + An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec. + Two tables are used: + - A metadata table (chunks_{bank_id}) that holds the chunk JSON. + - A virtual table (vec_chunks_{bank_id}) that holds the serialized vector. + """ + + def __init__(self, dimension: int, connection: sqlite3.Connection, bank_id: str): + self.dimension = dimension + self.connection = connection + self.bank_id = bank_id + self.metadata_table = f"chunks_{bank_id}".replace("-", "_") + self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_") + + @classmethod + async def create(cls, dimension: int, connection: sqlite3.Connection, bank_id: str): + instance = cls(dimension, connection, bank_id) + await instance.initialize() + return instance + + async def initialize(self) -> None: + cur = self.connection.cursor() + print(f"Creating tables {self.metadata_table} and {self.vector_table}") + # Create the table to store chunk metadata. + cur.execute(f""" + CREATE TABLE IF NOT EXISTS {self.metadata_table} ( + id INTEGER PRIMARY KEY, + chunk TEXT + ); + """) + # Create the virtual table for embeddings. + cur.execute(f""" + CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table} + USING vec0(embedding FLOAT[{self.dimension}]); + """) + self.connection.commit() + + async def delete(self): + cur = self.connection.cursor() + cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};") + cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};") + self.connection.commit() + + async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + """ + Add new chunks along with their embeddings. + For each chunk, we insert its JSON into the metadata table and then insert its + embedding (serialized to raw bytes) into the virtual table using the assigned rowid. + """ + cur = self.connection.cursor() + for chunk, emb in zip(chunks, embeddings): + # Serialize and insert the chunk metadata. + chunk_json = chunk.model_dump_json() + cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,)) + row_id = cur.lastrowid + # Ensure the embedding is a list of floats. + emb_list = emb.tolist() if isinstance(emb, np.ndarray) else list(emb) + emb_blob = serialize_vector(emb_list) + cur.execute(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", (row_id, emb_blob)) + self.connection.commit() + + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + """ + Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query + against the virtual table. The SQL joins the metadata table to recover the chunk JSON. + """ + emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding) + emb_blob = serialize_vector(emb_list) + cur = self.connection.cursor() + query_sql = f""" + SELECT m.id, m.chunk, v.distance + FROM {self.vector_table} AS v + JOIN {self.metadata_table} AS m ON m.id = v.rowid + WHERE v.embedding MATCH ? AND k = ? + ORDER BY v.distance; + """ + cur.execute(query_sql, (emb_blob, k)) + rows = cur.fetchall() + chunks = [] + scores = [] + for _id, chunk_json, distance in rows: + try: + chunk = Chunk.model_validate_json(chunk_json) + except Exception as e: + logger.error(f"Error parsing chunk JSON for id {_id}: {e}") + continue + chunks.append(chunk) + # Mimic the Faiss scoring: score = 1/distance (avoid division by zero) + score = 1.0 / distance if distance != 0 else float("inf") + scores.append(score) + return QueryChunksResponse(chunks=chunks, scores=scores) + + +class SQLiteVecVectorIOImpl(VectorIO, VectorDBsProtocolPrivate): + """ + A VectorIO implementation using SQLite + sqlite_vec. + This class handles vector database registration (with metadata stored in a table named `vector_dbs`) + and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex). + """ + + def __init__(self, config, inference_api: Api.inference) -> None: + self.config = config + self.inference_api = inference_api + self.cache: Dict[str, VectorDBWithIndex] = {} + self.connection: Optional[sqlite3.Connection] = None + + async def initialize(self) -> None: + # Open a connection to the SQLite database (the file is specified in the config). + print(f"Connecting to SQLite database at {self.config.db_path}") + self.connection = sqlite3.connect(self.config.db_path) + self.connection.enable_load_extension(True) + sqlite_vec.load(self.connection) + self.connection.enable_load_extension(False) + cur = self.connection.cursor() + # Create a table to persist vector DB registrations. + cur.execute(""" + CREATE TABLE IF NOT EXISTS vector_dbs ( + id TEXT PRIMARY KEY, + metadata TEXT + ); + """) + self.connection.commit() + # Load any existing vector DB registrations. + cur.execute("SELECT metadata FROM vector_dbs") + rows = cur.fetchall() + for row in rows: + vector_db_data = row[0] + vector_db = VectorDB.model_validate_json(vector_db_data) + index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier) + self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) + + async def shutdown(self) -> None: + if self.connection: + self.connection.close() + self.connection = None + + async def register_vector_db(self, vector_db: VectorDB) -> None: + if self.connection is None: + raise RuntimeError("SQLite connection not initialized") + cur = self.connection.cursor() + cur.execute( + "INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)", + (vector_db.identifier, vector_db.model_dump_json()), + ) + self.connection.commit() + index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier) + self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) + + async def list_vector_dbs(self) -> List[VectorDB]: + return [v.vector_db for v in self.cache.values()] + + async def unregister_vector_db(self, vector_db_id: str) -> None: + if self.connection is None: + raise RuntimeError("SQLite connection not initialized") + if vector_db_id not in self.cache: + logger.warning(f"Vector DB {vector_db_id} not found") + return + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] + cur = self.connection.cursor() + cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,)) + self.connection.commit() + + async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None: + if vector_db_id not in self.cache: + raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}") + # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api + # and then call our index’s add_chunks. + await self.cache[vector_db_id].insert_chunks(chunks) + + async def query_chunks( + self, vector_db_id: str, query: Any, params: Optional[Dict[str, Any]] = None + ) -> QueryChunksResponse: + if vector_db_id not in self.cache: + raise ValueError(f"Vector DB {vector_db_id} not found") + return await self.cache[vector_db_id].query_chunks(query, params) diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 2d7c02d86..55158dc1d 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -54,6 +54,14 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig", api_dependencies=[Api.inference], ), + InlineProviderSpec( + api=Api.vector_io, + provider_type="inline::sqlite_vec", + pip_packages=EMBEDDING_DEPS + ["sqlite_vec"], + module="llama_stack.providers.inline.vector_io.sqlite_vec", + config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVecImplConfig", + api_dependencies=[Api.inference], + ), remote_provider_spec( Api.vector_io, AdapterSpec( diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index ecd195854..5a73d3b14 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -352,20 +352,24 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): return EmbeddingsResponse(embeddings=embeddings) async def register_model(self, model: Model) -> Model: - async def check_model_availability(model_id: str): - response = await self.client.ps() - available_models = [m["model"] for m in response["models"]] - if model_id not in available_models: - raise ValueError( - f"Model '{model_id}' is not available in Ollama. Available models: {', '.join(available_models)}" - ) - + # ollama does not have embedding models running. Check if the model is in list of available models. if model.model_type == ModelType.embedding: - await check_model_availability(model.provider_resource_id) + response = await self.client.list() + available_models = [m["model"] for m in response["models"]] + if model.provider_resource_id not in available_models: + raise ValueError( + f"Model '{model.provider_resource_id}' is not available in Ollama. " + f"Available models: {', '.join(available_models)}" + ) return model - model = await self.register_helper.register_model(model) - await check_model_availability(model.provider_resource_id) + models = await self.client.ps() + available_models = [m["model"] for m in models["models"]] + if model.provider_resource_id not in available_models: + raise ValueError( + f"Model '{model.provider_resource_id}' is not available in Ollama. " + f"Available models: [{', '.join(available_models)}]" + ) return model diff --git a/llama_stack/providers/tests/vector_io/conftest.py b/llama_stack/providers/tests/vector_io/conftest.py index 1feb5af92..3a02ac712 100644 --- a/llama_stack/providers/tests/vector_io/conftest.py +++ b/llama_stack/providers/tests/vector_io/conftest.py @@ -41,6 +41,14 @@ DEFAULT_PROVIDER_COMBINATIONS = [ id="ollama", marks=pytest.mark.ollama, ), + pytest.param( + { + "inference": "ollama", + "vector_io": "sqlite_vec", + }, + id="sqlite_vec", + marks=pytest.mark.ollama, + ), pytest.param( { "inference": "sentence_transformers", diff --git a/llama_stack/providers/tests/vector_io/fixtures.py b/llama_stack/providers/tests/vector_io/fixtures.py index c8d5fa8cf..80ca05c32 100644 --- a/llama_stack/providers/tests/vector_io/fixtures.py +++ b/llama_stack/providers/tests/vector_io/fixtures.py @@ -15,6 +15,7 @@ from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig +from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVecImplConfig from llama_stack.providers.remote.vector_io.chroma import ChromaRemoteImplConfig from llama_stack.providers.remote.vector_io.pgvector import PGVectorConfig from llama_stack.providers.remote.vector_io.weaviate import WeaviateConfig @@ -53,6 +54,22 @@ def vector_io_faiss() -> ProviderFixture: ) +@pytest.fixture(scope="session") +def vector_io_sqlite_vec() -> ProviderFixture: + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + return ProviderFixture( + providers=[ + Provider( + provider_id="sqlite_vec", + provider_type="inline::sqlite_vec", + config=SQLiteVecImplConfig( + kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(), + ).model_dump(), + ) + ], + ) + + @pytest.fixture(scope="session") def vector_io_pgvector() -> ProviderFixture: return ProviderFixture( @@ -111,7 +128,13 @@ def vector_io_chroma() -> ProviderFixture: ) -VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma"] +VECTOR_IO_FIXTURES = [ + "faiss", + "pgvector", + "weaviate", + "chroma", + "sqlite_vec", +] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index 85327c131..aea0fd14a 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -17,6 +17,7 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR class KVStoreType(Enum): redis = "redis" sqlite = "sqlite" + milvus_lite = "milvus_lite" postgres = "postgres" @@ -62,6 +63,22 @@ class SqliteKVStoreConfig(CommonConfig): } +class MilvusLiteKVStoreConfig(CommonConfig): + type: Literal[KVStoreType.milvus_lite.value] = KVStoreType.milvus_lite.value + db_path: str = Field( + default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(), + description="File path for the sqlite database", + ) + + @classmethod + def sample_run_config(cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"): + return { + "type": "milvuslite", + "namespace": None, + "db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + db_name, + } + + class PostgresKVStoreConfig(CommonConfig): type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value host: str = "localhost" diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index 0fee6808c..48960c5ba 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -6,6 +6,7 @@ distribution_spec: - remote::ollama vector_io: - inline::faiss + - inline::sqlite_vec - remote::chromadb - remote::pgvector safety: diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py index d14cb3aad..2a072bd89 100644 --- a/llama_stack/templates/ollama/ollama.py +++ b/llama_stack/templates/ollama/ollama.py @@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVecImplConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -49,11 +50,16 @@ def get_distribution_template() -> DistributionTemplate: provider_type="inline::sentence-transformers", config=SentenceTransformersInferenceConfig.sample_run_config(), ) - vector_io_provider = Provider( + vector_io_provider_faiss = Provider( provider_id="faiss", provider_type="inline::faiss", config=FaissImplConfig.sample_run_config(f"distributions/{name}"), ) + vector_io_provider_sqlite = Provider( + provider_id="sqlite_vec", + provider_type="inline::sqlite_vec", + config=SQLiteVecImplConfig.sample_run_config(f"distributions/{name}"), + ) inference_model = ModelInput( model_id="${env.INFERENCE_MODEL}", @@ -98,7 +104,7 @@ def get_distribution_template() -> DistributionTemplate: "run.yaml": RunConfigSettings( provider_overrides={ "inference": [inference_provider, embedding_provider], - "vector_io": [vector_io_provider], + "vector_io": [vector_io_provider_faiss, vector_io_provider_sqlite], }, default_models=[inference_model, embedding_model], default_tool_groups=default_tool_groups, @@ -109,7 +115,7 @@ def get_distribution_template() -> DistributionTemplate: inference_provider, embedding_provider, ], - "vector_io": [vector_io_provider], + "vector_io": [vector_io_provider_faiss, vector_io_provider_faiss], "safety": [ Provider( provider_id="llama-guard", diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index 3cc1cb2ac..578e3f3b6 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -20,13 +20,14 @@ providers: provider_type: inline::sentence-transformers config: {} vector_io: - - provider_id: faiss - provider_type: inline::faiss + - provider_id: sqlite_vec + provider_type: inline::sqlite_vec config: kvstore: type: sqlite namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py index 8c44242fe..ef4a25b15 100644 --- a/tests/client-sdk/conftest.py +++ b/tests/client-sdk/conftest.py @@ -80,7 +80,7 @@ def llama_stack_client(provider_data): skip_logger_removal=True, ) if not client.initialize(): - raise RuntimeError("Initialization failed") + raise RuntimeError(f"Initialization failed {os.environ.get('LLAMA_STACK_CONFIG')} not found") elif os.environ.get("LLAMA_STACK_BASE_URL"): client = LlamaStackClient( diff --git a/tests/client-sdk/vector_io/test_vector_io.py b/tests/client-sdk/vector_io/test_vector_io.py index 36d3fe2c1..3b2b755be 100644 --- a/tests/client-sdk/vector_io/test_vector_io.py +++ b/tests/client-sdk/vector_io/test_vector_io.py @@ -8,6 +8,9 @@ import random import pytest +INLINE_VECTOR_DB_PROVIDERS = ["faiss"] +# "sqlite_vec" + @pytest.fixture(scope="function") def empty_vector_db_registry(llama_stack_client): @@ -17,26 +20,27 @@ def empty_vector_db_registry(llama_stack_client): @pytest.fixture(scope="function") -def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry): +def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry, provider_id): vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}" llama_stack_client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model="all-MiniLM-L6-v2", embedding_dimension=384, - provider_id="faiss", + provider_id=provider_id, ) vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()] return vector_dbs -def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db_registry): +@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS) +def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id): # Register a memory bank first vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}" llama_stack_client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model, embedding_dimension=384, - provider_id="faiss", + provider_id=provider_id, ) # Retrieve the memory bank and validate its properties @@ -44,7 +48,7 @@ def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db assert response is not None assert response.identifier == vector_db_id assert response.embedding_model == embedding_model - assert response.provider_id == "faiss" + assert response.provider_id == provider_id assert response.provider_resource_id == vector_db_id @@ -53,20 +57,22 @@ def test_vector_db_list(llama_stack_client, empty_vector_db_registry): assert len(vector_dbs_after_register) == 0 -def test_vector_db_register(llama_stack_client, embedding_model, empty_vector_db_registry): +@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS) +def test_vector_db_register(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id): vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}" llama_stack_client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model, embedding_dimension=384, - provider_id="faiss", + provider_id=provider_id, ) vector_dbs_after_register = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()] assert vector_dbs_after_register == [vector_db_id] -def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry): +@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS) +def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry, provider_id): vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()] assert len(vector_dbs) == 1