feat: add sqlite-vec as a vector db provider

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo 2025-02-10 16:16:55 -05:00
parent 371f11a569
commit af7748a4d5
14 changed files with 356 additions and 28 deletions


@@ -185,7 +185,9 @@ class CommonRoutingTableImpl(RoutingTable):
             obj.provider_id = list(self.impls_by_provider_id.keys())[0]
         if obj.provider_id not in self.impls_by_provider_id:
-            raise ValueError(f"Provider `{obj.provider_id}` not found")
+            raise ValueError(
+                f"Provider `{obj.provider_id}` not found \navailable providers: {self.impls_by_provider_id.keys()}"
+            )
         p = self.impls_by_provider_id[obj.provider_id]

@@ -335,6 +337,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
             "embedding_model": embedding_model,
             "embedding_dimension": model.metadata["embedding_dimension"],
         }
+        print(f"Registering vector db {vector_db_data} with embedding model {embedding_model}")
         vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
         await self.register_object(vector_db)
         return vector_db
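With the widened error, registering a resource against an unknown provider now reports what is actually available. Illustratively (the provider ids here are hypothetical):

    ValueError: Provider `sqlite_vec` not found
    available providers: dict_keys(['faiss'])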


@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Dict

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import SQLiteVecImplConfig


async def get_provider_impl(config: SQLiteVecImplConfig, deps: Dict[Api, ProviderSpec]):
    from .sqlite_vec import SQLiteVecVectorIOImpl

    assert isinstance(config, SQLiteVecImplConfig), f"Unexpected config type: {type(config)}"
    impl = SQLiteVecVectorIOImpl(config, deps[Api.inference])
    await impl.initialize()
    return impl


@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# config.py
from typing import Any, Dict

from pydantic import BaseModel

from llama_stack.providers.utils.kvstore.config import (
    KVStoreConfig,
    SqliteKVStoreConfig,
)


class SQLiteVecImplConfig(BaseModel):
    db_path: str
    kvstore: KVStoreConfig

    @classmethod
    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
        return {
            # db_path is a required field above, so the sample must provide it too
            # (the generated run.yaml at the end of this commit includes both paths).
            "db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/sqlite_vec.db",
            "kvstore": SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=__distro_dir__,
                db_name="sqlite_vec.db",
            ),
        }
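For concreteness, a sketch of what this sample renders to for the ollama distro; the exact dict shape is inferred from SqliteKVStoreConfig.sample_run_config and the generated run.yaml at the end of this commit:

    SQLiteVecImplConfig.sample_run_config(__distro_dir__="distributions/ollama")
    # -> {
    #     "db_path": "${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db",
    #     "kvstore": {
    #         "type": "sqlite",
    #         "namespace": None,
    #         "db_path": "${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db",
    #     },
    # }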


@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# sqlite_vec_impl.py
import logging
import sqlite3
import struct
from typing import Any, Dict, List, Optional

import numpy as np
import sqlite_vec
from numpy.typing import NDArray

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex

logger = logging.getLogger(__name__)


def serialize_vector(vector: List[float]) -> bytes:
    """Serialize a list of floats into a compact binary representation."""
    return struct.pack(f"{len(vector)}f", *vector)


class SQLiteVecIndex(EmbeddingIndex):
    """
    An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
    Two tables are used:
    - A metadata table (chunks_{bank_id}) that holds the chunk JSON.
    - A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
    """

    def __init__(self, dimension: int, connection: sqlite3.Connection, bank_id: str):
        self.dimension = dimension
        self.connection = connection
        self.bank_id = bank_id
        self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
        self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")

    @classmethod
    async def create(cls, dimension: int, connection: sqlite3.Connection, bank_id: str):
        instance = cls(dimension, connection, bank_id)
        await instance.initialize()
        return instance

    async def initialize(self) -> None:
        cur = self.connection.cursor()
        print(f"Creating tables {self.metadata_table} and {self.vector_table}")
        # Create the table to store chunk metadata.
        cur.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.metadata_table} (
                id INTEGER PRIMARY KEY,
                chunk TEXT
            );
        """)
        # Create the virtual table for embeddings.
        cur.execute(f"""
            CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
            USING vec0(embedding FLOAT[{self.dimension}]);
        """)
        self.connection.commit()

    async def delete(self):
        cur = self.connection.cursor()
        cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
        cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
        self.connection.commit()

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        """
        Add new chunks along with their embeddings.
        For each chunk, we insert its JSON into the metadata table and then insert its
        embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
        """
        cur = self.connection.cursor()
        for chunk, emb in zip(chunks, embeddings):
            # Serialize and insert the chunk metadata.
            chunk_json = chunk.model_dump_json()
            cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,))
            row_id = cur.lastrowid
            # Ensure the embedding is a list of floats.
            emb_list = emb.tolist() if isinstance(emb, np.ndarray) else list(emb)
            emb_blob = serialize_vector(emb_list)
            cur.execute(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", (row_id, emb_blob))
        self.connection.commit()

    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        """
        Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query
        against the virtual table. The SQL joins the metadata table to recover the chunk JSON.
        """
        emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
        emb_blob = serialize_vector(emb_list)
        cur = self.connection.cursor()
        query_sql = f"""
            SELECT m.id, m.chunk, v.distance
            FROM {self.vector_table} AS v
            JOIN {self.metadata_table} AS m ON m.id = v.rowid
            WHERE v.embedding MATCH ? AND k = ?
            ORDER BY v.distance;
        """
        cur.execute(query_sql, (emb_blob, k))
        rows = cur.fetchall()
        chunks = []
        scores = []
        for _id, chunk_json, distance in rows:
            try:
                chunk = Chunk.model_validate_json(chunk_json)
            except Exception as e:
                logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
                continue
            chunks.append(chunk)
            # Mimic the Faiss scoring: score = 1/distance (avoid division by zero)
            score = 1.0 / distance if distance != 0 else float("inf")
            scores.append(score)
        return QueryChunksResponse(chunks=chunks, scores=scores)


class SQLiteVecVectorIOImpl(VectorIO, VectorDBsProtocolPrivate):
    """
    A VectorIO implementation using SQLite + sqlite_vec.
    This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
    and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
    """

    def __init__(self, config, inference_api: Api.inference) -> None:
        self.config = config
        self.inference_api = inference_api
        self.cache: Dict[str, VectorDBWithIndex] = {}
        self.connection: Optional[sqlite3.Connection] = None

    async def initialize(self) -> None:
        # Open a connection to the SQLite database (the file is specified in the config).
        print(f"Connecting to SQLite database at {self.config.db_path}")
        self.connection = sqlite3.connect(self.config.db_path)
        self.connection.enable_load_extension(True)
        sqlite_vec.load(self.connection)
        self.connection.enable_load_extension(False)
        cur = self.connection.cursor()
        # Create a table to persist vector DB registrations.
        cur.execute("""
            CREATE TABLE IF NOT EXISTS vector_dbs (
                id TEXT PRIMARY KEY,
                metadata TEXT
            );
        """)
        self.connection.commit()
        # Load any existing vector DB registrations.
        cur.execute("SELECT metadata FROM vector_dbs")
        rows = cur.fetchall()
        for row in rows:
            vector_db_data = row[0]
            vector_db = VectorDB.model_validate_json(vector_db_data)
            index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
            self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

    async def shutdown(self) -> None:
        if self.connection:
            self.connection.close()
            self.connection = None

    async def register_vector_db(self, vector_db: VectorDB) -> None:
        if self.connection is None:
            raise RuntimeError("SQLite connection not initialized")
        cur = self.connection.cursor()
        cur.execute(
            "INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
            (vector_db.identifier, vector_db.model_dump_json()),
        )
        self.connection.commit()
        index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
        self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

    async def list_vector_dbs(self) -> List[VectorDB]:
        return [v.vector_db for v in self.cache.values()]

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        if self.connection is None:
            raise RuntimeError("SQLite connection not initialized")
        if vector_db_id not in self.cache:
            logger.warning(f"Vector DB {vector_db_id} not found")
            return
        await self.cache[vector_db_id].index.delete()
        del self.cache[vector_db_id]
        cur = self.connection.cursor()
        cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
        self.connection.commit()

    async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None:
        if vector_db_id not in self.cache:
            raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
        # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
        # and then call our index's add_chunks.
        await self.cache[vector_db_id].insert_chunks(chunks)

    async def query_chunks(
        self, vector_db_id: str, query: Any, params: Optional[Dict[str, Any]] = None
    ) -> QueryChunksResponse:
        if vector_db_id not in self.cache:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        return await self.cache[vector_db_id].query_chunks(query, params)
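Note for reviewers unfamiliar with sqlite-vec: the access pattern above (pack float32 values into a blob, store it in a vec0 virtual table, KNN via MATCH ... AND k = ?) can be exercised standalone. A minimal sketch, assuming sqlite-vec is installed; the table name and vectors are illustrative:

    import sqlite3
    import struct

    import sqlite_vec

    conn = sqlite3.connect(":memory:")
    conn.enable_load_extension(True)
    sqlite_vec.load(conn)
    conn.enable_load_extension(False)

    # A 4-dimensional vec0 table, analogous to vec_chunks_{bank_id} above.
    conn.execute("CREATE VIRTUAL TABLE vec_demo USING vec0(embedding FLOAT[4]);")

    def pack(vec):
        # Same serialization as serialize_vector(): raw float32 bytes.
        return struct.pack(f"{len(vec)}f", *vec)

    for rowid, vec in enumerate([[0.1, 0.2, 0.3, 0.4], [0.9, 0.8, 0.7, 0.6]], start=1):
        conn.execute("INSERT INTO vec_demo (rowid, embedding) VALUES (?, ?)", (rowid, pack(vec)))

    # KNN search: MATCH takes the query blob, k bounds the result set,
    # and sqlite-vec exposes a distance column for ordering.
    rows = conn.execute(
        "SELECT rowid, distance FROM vec_demo WHERE embedding MATCH ? AND k = ? ORDER BY distance;",
        (pack([0.1, 0.2, 0.3, 0.4]), 2),
    ).fetchall()
    print(rows)  # nearest rowid first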


@@ -54,6 +54,14 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig",
             api_dependencies=[Api.inference],
         ),
+        InlineProviderSpec(
+            api=Api.vector_io,
+            provider_type="inline::sqlite_vec",
+            pip_packages=EMBEDDING_DEPS + ["sqlite_vec"],
+            module="llama_stack.providers.inline.vector_io.sqlite_vec",
+            config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVecImplConfig",
+            api_dependencies=[Api.inference],
+        ),
         remote_provider_spec(
             Api.vector_io,
             AdapterSpec(
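With the spec in place, selecting the new provider from a client mirrors the existing faiss flow. A sketch (client construction omitted; the vector db id is illustrative, and the call mirrors the test fixtures at the bottom of this diff):

    client.vector_dbs.register(
        vector_db_id="my_documents",  # illustrative
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
        provider_id="sqlite_vec",
    )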


@@ -352,20 +352,24 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return EmbeddingsResponse(embeddings=embeddings)

     async def register_model(self, model: Model) -> Model:
-        async def check_model_availability(model_id: str):
-            response = await self.client.ps()
-            available_models = [m["model"] for m in response["models"]]
-            if model_id not in available_models:
-                raise ValueError(
-                    f"Model '{model_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
-                )
-
+        # ollama does not keep embedding models running; check the list of available models instead.
         if model.model_type == ModelType.embedding:
-            await check_model_availability(model.provider_resource_id)
+            response = await self.client.list()
+            available_models = [m["model"] for m in response["models"]]
+            if model.provider_resource_id not in available_models:
+                raise ValueError(
+                    f"Model '{model.provider_resource_id}' is not available in Ollama. "
+                    f"Available models: {', '.join(available_models)}"
+                )
             return model
+
         model = await self.register_helper.register_model(model)
-        await check_model_availability(model.provider_resource_id)
+        models = await self.client.ps()
+        available_models = [m["model"] for m in models["models"]]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model '{model.provider_resource_id}' is not available in Ollama. "
+                f"Available models: [{', '.join(available_models)}]"
+            )
         return model
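(Context for the split above: in the ollama client, list() returns models that have been pulled locally, while ps() returns models currently loaded in memory. Embedding models are loaded on demand and never show up as running, so their availability has to be checked against list() instead.)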


@@ -41,6 +41,14 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="ollama",
         marks=pytest.mark.ollama,
     ),
+    pytest.param(
+        {
+            "inference": "ollama",
+            "vector_io": "sqlite_vec",
+        },
+        id="sqlite_vec",
+        marks=pytest.mark.ollama,
+    ),
     pytest.param(
         {
             "inference": "sentence_transformers",


@@ -15,6 +15,7 @@ from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
 from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig
+from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVecImplConfig
 from llama_stack.providers.remote.vector_io.chroma import ChromaRemoteImplConfig
 from llama_stack.providers.remote.vector_io.pgvector import PGVectorConfig
 from llama_stack.providers.remote.vector_io.weaviate import WeaviateConfig

@@ -53,6 +54,22 @@ def vector_io_faiss() -> ProviderFixture:
     )


+@pytest.fixture(scope="session")
+def vector_io_sqlite_vec() -> ProviderFixture:
+    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="sqlite_vec",
+                provider_type="inline::sqlite_vec",
+                config=SQLiteVecImplConfig(
+                    db_path=temp_file.name,  # required field on the config; reuse the temp file
+                    kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
+                ).model_dump(),
+            )
+        ],
+    )
+
+
 @pytest.fixture(scope="session")
 def vector_io_pgvector() -> ProviderFixture:
     return ProviderFixture(

@@ -111,7 +128,13 @@ def vector_io_chroma() -> ProviderFixture:
     )

-VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma"]
+VECTOR_IO_FIXTURES = [
+    "faiss",
+    "pgvector",
+    "weaviate",
+    "chroma",
+    "sqlite_vec",
+]


 @pytest_asyncio.fixture(scope="session")


@@ -17,6 +17,7 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR

 class KVStoreType(Enum):
     redis = "redis"
     sqlite = "sqlite"
+    milvus_lite = "milvus_lite"
     postgres = "postgres"

@@ -62,6 +63,22 @@ class SqliteKVStoreConfig(CommonConfig):
         }


+class MilvusLiteKVStoreConfig(CommonConfig):
+    type: Literal[KVStoreType.milvus_lite.value] = KVStoreType.milvus_lite.value
+    db_path: str = Field(
+        default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(),
+        description="File path for the Milvus Lite database",
+    )
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"):
+        return {
+            "type": "milvus_lite",
+            "namespace": None,
+            "db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + db_name,
+        }
+
+
 class PostgresKVStoreConfig(CommonConfig):
     type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
     host: str = "localhost"


@@ -6,6 +6,7 @@ distribution_spec:
     - remote::ollama
   vector_io:
     - inline::faiss
+    - inline::sqlite_vec
     - remote::chromadb
     - remote::pgvector
   safety:


@@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
 from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig
+from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVecImplConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

@@ -49,11 +50,16 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="inline::sentence-transformers",
         config=SentenceTransformersInferenceConfig.sample_run_config(),
     )
-    vector_io_provider = Provider(
+    vector_io_provider_faiss = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
         config=FaissImplConfig.sample_run_config(f"distributions/{name}"),
     )
+    vector_io_provider_sqlite = Provider(
+        provider_id="sqlite_vec",
+        provider_type="inline::sqlite_vec",
+        config=SQLiteVecImplConfig.sample_run_config(f"distributions/{name}"),
+    )

     inference_model = ModelInput(
         model_id="${env.INFERENCE_MODEL}",

@@ -98,7 +104,7 @@ def get_distribution_template() -> DistributionTemplate:
             "run.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [inference_provider, embedding_provider],
-                    "vector_io": [vector_io_provider],
+                    "vector_io": [vector_io_provider_faiss, vector_io_provider_sqlite],
                 },
                 default_models=[inference_model, embedding_model],
                 default_tool_groups=default_tool_groups,
@@ -109,7 +115,7 @@ def get_distribution_template() -> DistributionTemplate:
                         inference_provider,
                         embedding_provider,
                     ],
-                    "vector_io": [vector_io_provider],
+                    "vector_io": [vector_io_provider_faiss, vector_io_provider_sqlite],
                     "safety": [
                         Provider(
                             provider_id="llama-guard",


@@ -20,13 +20,14 @@ providers:
       provider_type: inline::sentence-transformers
       config: {}
   vector_io:
-  - provider_id: faiss
-    provider_type: inline::faiss
+  - provider_id: sqlite_vec
+    provider_type: inline::sqlite_vec
     config:
       kvstore:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
+      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard


@@ -80,7 +80,7 @@ def llama_stack_client(provider_data):
             skip_logger_removal=True,
         )
         if not client.initialize():
-            raise RuntimeError("Initialization failed")
+            raise RuntimeError(f"Initialization failed: config {os.environ.get('LLAMA_STACK_CONFIG')} not found")
     elif os.environ.get("LLAMA_STACK_BASE_URL"):
         client = LlamaStackClient(


@@ -8,6 +8,9 @@ import random

 import pytest

+INLINE_VECTOR_DB_PROVIDERS = ["faiss"]
+# "sqlite_vec"
+

 @pytest.fixture(scope="function")
 def empty_vector_db_registry(llama_stack_client):

@@ -17,26 +20,27 @@ def empty_vector_db_registry(llama_stack_client):

 @pytest.fixture(scope="function")
-def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry):
+def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry, provider_id):
     vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
     llama_stack_client.vector_dbs.register(
         vector_db_id=vector_db_id,
         embedding_model="all-MiniLM-L6-v2",
         embedding_dimension=384,
-        provider_id="faiss",
+        provider_id=provider_id,
     )
     vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
     return vector_dbs


-def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db_registry):
+@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
+def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id):
     # Register a memory bank first
     vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
     llama_stack_client.vector_dbs.register(
         vector_db_id=vector_db_id,
         embedding_model=embedding_model,
         embedding_dimension=384,
-        provider_id="faiss",
+        provider_id=provider_id,
     )

     # Retrieve the memory bank and validate its properties

@@ -44,7 +48,7 @@ def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db
     assert response is not None
     assert response.identifier == vector_db_id
     assert response.embedding_model == embedding_model
-    assert response.provider_id == "faiss"
+    assert response.provider_id == provider_id
     assert response.provider_resource_id == vector_db_id

@@ -53,20 +57,22 @@ def test_vector_db_list(llama_stack_client, empty_vector_db_registry):
     assert len(vector_dbs_after_register) == 0


-def test_vector_db_register(llama_stack_client, embedding_model, empty_vector_db_registry):
+@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
+def test_vector_db_register(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id):
     vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
     llama_stack_client.vector_dbs.register(
         vector_db_id=vector_db_id,
         embedding_model=embedding_model,
         embedding_dimension=384,
-        provider_id="faiss",
+        provider_id=provider_id,
     )

     vector_dbs_after_register = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
     assert vector_dbs_after_register == [vector_db_id]


-def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry):
+@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
+def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry, provider_id):
     vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
     assert len(vector_dbs) == 1