feat: Adding sqlite-vec as a vectordb (#1040)

# What does this PR do?
This PR adds `sqlite_vec` as an additional inline vectordb.

Tested with `ollama` by adding the following `vector_io` entry to
`./llama_stack/templates/ollama/run.yaml`:

```yaml
  vector_io:
  - provider_id: sqlite_vec
    provider_type: inline::sqlite_vec
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
```
I also updated the `./tests/client-sdk/vector_io/test_vector_io.py` test
file with:
```python
INLINE_VECTOR_DB_PROVIDERS = ["faiss", "sqlite_vec"]
```
and parameterized the relevant tests over that list.
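For illustration, the parameterization pattern (the full change appears in the test-file diff at the bottom of this page) looks like this:

```python
import pytest

INLINE_VECTOR_DB_PROVIDERS = ["faiss", "sqlite_vec"]


@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
def test_vector_db_register(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id):
    # Each test now runs once per inline provider and registers against it.
    ...
```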

Closes https://github.com/meta-llama/llama-stack/issues/1005

## Test Plan
I ran the tests with:
```bash
INFERENCE_MODEL=llama3.2:3b-instruct-fp16 LLAMA_STACK_CONFIG=ollama pytest -s -v tests/client-sdk/vector_io/test_vector_io.py
```
Which outputs:
```
...
PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_retrieve[all-MiniLM-L6-v2-sqlite_vec] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_list PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-faiss] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-sqlite_vec] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[faiss] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[sqlite_vec] PASSED
```

In addition, I ran a version of the `rag_with_vector_db.py`
[example](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/rag_with_vector_db.py),
adapted to use `sqlite_vec` (script below), with `uv run rag_example.py`.
<details>
<summary>CLICK TO SHOW SCRIPT 👋  </summary>

```python
#!/usr/bin/env python3
import os
import uuid
from termcolor import cprint

# Set environment variables
os.environ['INFERENCE_MODEL'] = 'llama3.2:3b-instruct-fp16'
os.environ['LLAMA_STACK_CONFIG'] = 'ollama'

# Import libraries after setting environment variables
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types import Document


def main():
    # Initialize the client
    client = LlamaStackAsLibraryClient("ollama")
    vector_db_id = f"test-vector-db-{uuid.uuid4().hex}"

    _ = client.initialize()

    model_id = 'llama3.2:3b-instruct-fp16'

    # Define the list of document URLs and create Document objects
    urls = [
        "chat.rst",
        "llama3.rst",
        "memory_optimizations.rst",
        "lora_finetune.rst",
    ]
    documents = [
        Document(
            document_id=f"num-{i}",
            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
            mime_type="text/plain",
            metadata={},
        )
        for i, url in enumerate(urls)
    ]

    client.vector_dbs.register(
        provider_id='sqlite_vec',
        vector_db_id=vector_db_id,
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
    )

    client.tool_runtime.rag_tool.insert(
        documents=documents,
        vector_db_id=vector_db_id,
        chunk_size_in_tokens=512,
    )
    # Create agent configuration
    agent_config = AgentConfig(
        model=model_id,
        instructions="You are a helpful assistant",
        enable_session_persistence=False,
        toolgroups=[
            {
                "name": "builtin::rag",
                "args": {
                    "vector_db_ids": [vector_db_id],
                }
            }
        ],
    )

    # Instantiate the Agent
    agent = Agent(client, agent_config)

    # List of user prompts
    user_prompts = [
        "What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
        "Was anything related to 'Llama3' discussed, if so what?",
        "Tell me how to use LoRA",
        "What about Quantization?",
    ]

    # Create a session for the agent
    session_id = agent.create_session("test-session")

    # Process each prompt and display the output
    for prompt in user_prompts:
        cprint(f"User> {prompt}", "green")
        response = agent.create_turn(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            session_id=session_id,
        )
        # Log and print events from the response
        for log in EventLogger().log(response):
            log.print()


if __name__ == "__main__":
    main()
```
</details>

This produces a long stream of RAG-grounded agent output.

# Documentation

Documentation updates will be handled in a follow-up PR.


---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
10 changed files with 331 additions and 12 deletions:

New file — the provider package `__init__.py`:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Dict

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import SQLiteVectorIOConfig


async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]):
    from .sqlite_vec import SQLiteVecVectorIOAdapter

    assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference])
    await impl.initialize()
    return impl
```

New file — the provider `config.py`:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# config.py
from pydantic import BaseModel
from typing import Any, Dict

from llama_stack.providers.utils.kvstore.config import (
    KVStoreConfig,
    SqliteKVStoreConfig,
)


class SQLiteVectorIOConfig(BaseModel):
    db_path: str
    kvstore: KVStoreConfig

    @classmethod
    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
        return {
            "kvstore": SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=__distro_dir__,
                db_name="sqlite_vec.db",
            )
        }
```
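As a quick sanity check, the config can also be constructed directly. A minimal sketch with made-up paths (the `SqliteKVStoreConfig(db_path=...)` call mirrors the test fixture further down):

```python
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

# Hypothetical local paths, for illustration only.
config = SQLiteVectorIOConfig(
    db_path="/tmp/sqlite_vec.db",  # where the vec0 tables live
    kvstore=SqliteKVStoreConfig(db_path="/tmp/sqlite_vec_registry.db"),  # registration metadata
)
print(config.model_dump())
```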

New file — the provider implementation `sqlite_vec.py`:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import sqlite3
import sqlite_vec
import struct
import logging
import numpy as np
from numpy.typing import NDArray
from typing import List, Optional, Dict, Any

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex

logger = logging.getLogger(__name__)


def serialize_vector(vector: List[float]) -> bytes:
    """Serialize a list of floats into a compact binary representation."""
    return struct.pack(f"{len(vector)}f", *vector)


class SQLiteVecIndex(EmbeddingIndex):
    """
    An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
    Two tables are used:
      - A metadata table (chunks_{bank_id}) that holds the chunk JSON.
      - A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
    """

    def __init__(self, dimension: int, connection: sqlite3.Connection, bank_id: str):
        self.dimension = dimension
        self.connection = connection
        self.bank_id = bank_id
        self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
        self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")

    @classmethod
    async def create(cls, dimension: int, connection: sqlite3.Connection, bank_id: str):
        instance = cls(dimension, connection, bank_id)
        await instance.initialize()
        return instance

    async def initialize(self) -> None:
        cur = self.connection.cursor()
        # Create the table to store chunk metadata.
        cur.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.metadata_table} (
                id INTEGER PRIMARY KEY,
                chunk TEXT
            );
        """)
        # Create the virtual table for embeddings.
        cur.execute(f"""
            CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
            USING vec0(embedding FLOAT[{self.dimension}]);
        """)
        self.connection.commit()

    async def delete(self):
        cur = self.connection.cursor()
        cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
        cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
        self.connection.commit()

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        """
        Add new chunks along with their embeddings.
        For each chunk, we insert its JSON into the metadata table and then insert its
        embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
        If any insert fails, the transaction is rolled back to maintain consistency.
        """
        cur = self.connection.cursor()
        try:
            # Start transaction
            cur.execute("BEGIN TRANSACTION")
            for chunk, emb in zip(chunks, embeddings):
                # Serialize and insert the chunk metadata.
                chunk_json = chunk.model_dump_json()
                cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,))
                row_id = cur.lastrowid
                # Ensure the embedding is a list of floats.
                emb_list = emb.tolist() if isinstance(emb, np.ndarray) else list(emb)
                emb_blob = serialize_vector(emb_list)
                cur.execute(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", (row_id, emb_blob))
            # Commit transaction if all inserts succeed
            self.connection.commit()
        except sqlite3.Error as e:
            self.connection.rollback()  # Rollback on failure
            logger.error(f"Error inserting into {self.vector_table}: {e}")
        finally:
            cur.close()  # Ensure cursor is closed

    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        """
        Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query
        against the virtual table. The SQL joins the metadata table to recover the chunk JSON.
        """
        emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
        emb_blob = serialize_vector(emb_list)
        cur = self.connection.cursor()
        query_sql = f"""
            SELECT m.id, m.chunk, v.distance
            FROM {self.vector_table} AS v
            JOIN {self.metadata_table} AS m ON m.id = v.rowid
            WHERE v.embedding MATCH ? AND k = ?
            ORDER BY v.distance;
        """
        cur.execute(query_sql, (emb_blob, k))
        rows = cur.fetchall()
        chunks = []
        scores = []
        for _id, chunk_json, distance in rows:
            try:
                chunk = Chunk.model_validate_json(chunk_json)
            except Exception as e:
                logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
                continue
            chunks.append(chunk)
            # Mimic the Faiss scoring: score = 1/distance (avoid division by zero)
            score = 1.0 / distance if distance != 0 else float("inf")
            scores.append(score)
        return QueryChunksResponse(chunks=chunks, scores=scores)


class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    """
    A VectorIO implementation using SQLite + sqlite_vec.
    This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
    and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
    """

    def __init__(self, config, inference_api: Api.inference) -> None:
        self.config = config
        self.inference_api = inference_api
        self.cache: Dict[str, VectorDBWithIndex] = {}
        self.connection: Optional[sqlite3.Connection] = None

    async def initialize(self) -> None:
        # Open a connection to the SQLite database (the file is specified in the config).
        self.connection = sqlite3.connect(self.config.db_path)
        self.connection.enable_load_extension(True)
        sqlite_vec.load(self.connection)
        self.connection.enable_load_extension(False)
        cur = self.connection.cursor()
        # Create a table to persist vector DB registrations.
        cur.execute("""
            CREATE TABLE IF NOT EXISTS vector_dbs (
                id TEXT PRIMARY KEY,
                metadata TEXT
            );
        """)
        self.connection.commit()
        # Load any existing vector DB registrations.
        cur.execute("SELECT metadata FROM vector_dbs")
        rows = cur.fetchall()
        for row in rows:
            vector_db_data = row[0]
            vector_db = VectorDB.model_validate_json(vector_db_data)
            index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
            self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

    async def shutdown(self) -> None:
        if self.connection:
            self.connection.close()
            self.connection = None

    async def register_vector_db(self, vector_db: VectorDB) -> None:
        if self.connection is None:
            raise RuntimeError("SQLite connection not initialized")
        cur = self.connection.cursor()
        cur.execute(
            "INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
            (vector_db.identifier, vector_db.model_dump_json()),
        )
        self.connection.commit()
        index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
        self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

    async def list_vector_dbs(self) -> List[VectorDB]:
        return [v.vector_db for v in self.cache.values()]

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        if self.connection is None:
            raise RuntimeError("SQLite connection not initialized")
        if vector_db_id not in self.cache:
            logger.warning(f"Vector DB {vector_db_id} not found")
            return
        await self.cache[vector_db_id].index.delete()
        del self.cache[vector_db_id]
        cur = self.connection.cursor()
        cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
        self.connection.commit()

    async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None:
        if vector_db_id not in self.cache:
            raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
        # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
        # and then call our index's add_chunks.
        await self.cache[vector_db_id].insert_chunks(chunks)

    async def query_chunks(
        self, vector_db_id: str, query: Any, params: Optional[Dict[str, Any]] = None
    ) -> QueryChunksResponse:
        if vector_db_id not in self.cache:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        return await self.cache[vector_db_id].query_chunks(query, params)
```
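For context, the core sqlite-vec pattern the index relies on (a `vec0` virtual table, a `MATCH`-based KNN query with a `k = ?` constraint, and a `distance` column) can be exercised standalone. A minimal sketch, independent of llama-stack, with made-up table names:

```python
import sqlite3
import struct

import sqlite_vec

conn = sqlite3.connect(":memory:")
conn.enable_load_extension(True)
sqlite_vec.load(conn)
conn.enable_load_extension(False)

# A vec0 virtual table holding 4-dimensional float vectors.
conn.execute("CREATE VIRTUAL TABLE demo_vecs USING vec0(embedding FLOAT[4])")


def pack(vector):
    # Same packing as serialize_vector above: raw float32s in native byte order.
    return struct.pack(f"{len(vector)}f", *vector)


conn.execute("INSERT INTO demo_vecs (rowid, embedding) VALUES (?, ?)", (1, pack([0.1, 0.2, 0.3, 0.4])))
conn.execute("INSERT INTO demo_vecs (rowid, embedding) VALUES (?, ?)", (2, pack([0.9, 0.8, 0.7, 0.6])))
conn.commit()

# KNN query: the two nearest rows to the query vector, closest first.
rows = conn.execute(
    "SELECT rowid, distance FROM demo_vecs WHERE embedding MATCH ? AND k = ? ORDER BY distance",
    (pack([0.1, 0.2, 0.3, 0.4]), 2),
).fetchall()
print(rows)  # rowid 1 comes back first with distance 0.0
```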

Registering the new provider spec in the vector_io provider registry:

```diff
@@ -54,6 +54,14 @@ def available_providers() -> List[ProviderSpec]:
         config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig",
         api_dependencies=[Api.inference],
     ),
+    InlineProviderSpec(
+        api=Api.vector_io,
+        provider_type="inline::sqlite_vec",
+        pip_packages=EMBEDDING_DEPS + ["sqlite-vec"],
+        module="llama_stack.providers.inline.vector_io.sqlite_vec",
+        config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
+        api_dependencies=[Api.inference],
+    ),
     remote_provider_spec(
         Api.vector_io,
         AdapterSpec(
```
Adding an ollama + sqlite_vec combination to the provider test matrix:

```diff
@@ -41,6 +41,14 @@ DEFAULT_PROVIDER_COMBINATIONS = [
        id="ollama",
        marks=pytest.mark.ollama,
    ),
+   pytest.param(
+       {
+           "inference": "ollama",
+           "vector_io": "sqlite_vec",
+       },
+       id="sqlite_vec",
+       marks=pytest.mark.ollama,
+   ),
    pytest.param(
        {
            "inference": "sentence_transformers",
```

Adding a `vector_io_sqlite_vec` fixture alongside the existing ones:

```diff
@@ -15,6 +15,7 @@ from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
 from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig
+from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
 from llama_stack.providers.remote.vector_io.chroma import ChromaRemoteImplConfig
 from llama_stack.providers.remote.vector_io.pgvector import PGVectorConfig
 from llama_stack.providers.remote.vector_io.weaviate import WeaviateConfig
@@ -53,6 +54,22 @@ def vector_io_faiss() -> ProviderFixture:
     )


+@pytest.fixture(scope="session")
+def vector_io_sqlite_vec() -> ProviderFixture:
+    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="sqlite_vec",
+                provider_type="inline::sqlite_vec",
+                config=SQLiteVectorIOConfig(
+                    kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
+                ).model_dump(),
+            )
+        ],
+    )
+
+
 @pytest.fixture(scope="session")
 def vector_io_pgvector() -> ProviderFixture:
     return ProviderFixture(
@@ -111,7 +128,13 @@ def vector_io_chroma() -> ProviderFixture:
     )


-VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma"]
+VECTOR_IO_FIXTURES = [
+    "faiss",
+    "pgvector",
+    "weaviate",
+    "chroma",
+    "sqlite_vec",
+]


 @pytest_asyncio.fixture(scope="session")
```

Enabling the provider in the ollama template's build spec:

```diff
@@ -6,6 +6,7 @@ distribution_spec:
     - remote::ollama
     vector_io:
     - inline::faiss
+    - inline::sqlite_vec
     - remote::chromadb
     - remote::pgvector
     safety:
```

Wiring the provider into the ollama distribution template:

```diff
@@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
 from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig
+from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -49,11 +50,16 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="inline::sentence-transformers",
         config=SentenceTransformersInferenceConfig.sample_run_config(),
     )
-    vector_io_provider = Provider(
+    vector_io_provider_faiss = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
         config=FaissImplConfig.sample_run_config(f"distributions/{name}"),
     )
+    vector_io_provider_sqlite = Provider(
+        provider_id="sqlite_vec",
+        provider_type="inline::sqlite_vec",
+        config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
+    )

     inference_model = ModelInput(
         model_id="${env.INFERENCE_MODEL}",
@@ -98,7 +104,7 @@ def get_distribution_template() -> DistributionTemplate:
         "run.yaml": RunConfigSettings(
             provider_overrides={
                 "inference": [inference_provider, embedding_provider],
-                "vector_io": [vector_io_provider],
+                "vector_io": [vector_io_provider_faiss, vector_io_provider_sqlite],
             },
             default_models=[inference_model, embedding_model],
             default_tool_groups=default_tool_groups,
@@ -109,7 +115,7 @@ def get_distribution_template() -> DistributionTemplate:
                 inference_provider,
                 embedding_provider,
             ],
-            "vector_io": [vector_io_provider],
+            "vector_io": [vector_io_provider_faiss, vector_io_provider_sqlite],
             "safety": [
                 Provider(
                     provider_id="llama-guard",
```

The generated `run.yaml` for the ollama distribution then includes the new provider:

```diff
@@ -27,6 +27,14 @@ providers:
         type: sqlite
         namespace: null
         db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
+  - provider_id: sqlite_vec
+    provider_type: inline::sqlite_vec
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
+      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
```

Parameterizing `tests/client-sdk/vector_io/test_vector_io.py` over the inline providers:

```diff
@@ -8,6 +8,8 @@
 import random

 import pytest

+INLINE_VECTOR_DB_PROVIDERS = ["faiss", "sqlite_vec"]
+

 @pytest.fixture(scope="function")
 def empty_vector_db_registry(llama_stack_client):
@@ -17,26 +19,27 @@ def empty_vector_db_registry(llama_stack_client):
 @pytest.fixture(scope="function")
-def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry):
+def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry, provider_id):
     vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
     llama_stack_client.vector_dbs.register(
         vector_db_id=vector_db_id,
         embedding_model="all-MiniLM-L6-v2",
         embedding_dimension=384,
-        provider_id="faiss",
+        provider_id=provider_id,
     )
     vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
     return vector_dbs


-def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db_registry):
+@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
+def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id):
     # Register a memory bank first
     vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
     llama_stack_client.vector_dbs.register(
         vector_db_id=vector_db_id,
         embedding_model=embedding_model,
         embedding_dimension=384,
-        provider_id="faiss",
+        provider_id=provider_id,
     )

     # Retrieve the memory bank and validate its properties
@@ -44,7 +47,7 @@ def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db
     assert response is not None
     assert response.identifier == vector_db_id
     assert response.embedding_model == embedding_model
-    assert response.provider_id == "faiss"
+    assert response.provider_id == provider_id
     assert response.provider_resource_id == vector_db_id
@@ -53,20 +56,22 @@ def test_vector_db_list(llama_stack_client, empty_vector_db_registry):
     assert len(vector_dbs_after_register) == 0


-def test_vector_db_register(llama_stack_client, embedding_model, empty_vector_db_registry):
+@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
+def test_vector_db_register(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id):
     vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
     llama_stack_client.vector_dbs.register(
         vector_db_id=vector_db_id,
         embedding_model=embedding_model,
         embedding_dimension=384,
-        provider_id="faiss",
+        provider_id=provider_id,
     )
     vector_dbs_after_register = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
     assert vector_dbs_after_register == [vector_db_id]


-def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry):
+@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
+def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry, provider_id):
     vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
     assert len(vector_dbs) == 1
```