mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 02:58:21 +00:00
feat: Adding sqlite-vec as vectordb
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
371f11a569
commit
af7748a4d5
14 changed files with 356 additions and 28 deletions
|
@ -185,7 +185,9 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
|
|
||||||
if obj.provider_id not in self.impls_by_provider_id:
|
if obj.provider_id not in self.impls_by_provider_id:
|
||||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
raise ValueError(
|
||||||
|
f"Provider `{obj.provider_id}` not found \navailable providers: {self.impls_by_provider_id.keys()}"
|
||||||
|
)
|
||||||
|
|
||||||
p = self.impls_by_provider_id[obj.provider_id]
|
p = self.impls_by_provider_id[obj.provider_id]
|
||||||
|
|
||||||
|
@ -335,6 +337,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
"embedding_model": embedding_model,
|
"embedding_model": embedding_model,
|
||||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||||
}
|
}
|
||||||
|
print(f"Registering vector db {vector_db_data} with embedding model {embedding_model}")
|
||||||
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
|
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
|
||||||
await self.register_object(vector_db)
|
await self.register_object(vector_db)
|
||||||
return vector_db
|
return vector_db
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||||
|
from .config import SQLiteVecImplConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(config: SQLiteVecImplConfig, deps: Dict[Api, ProviderSpec]):
|
||||||
|
from .sqlite_vec import SQLiteVecVectorIOImpl
|
||||||
|
|
||||||
|
assert isinstance(config, SQLiteVecImplConfig), f"Unexpected config type: {type(config)}"
|
||||||
|
impl = SQLiteVecVectorIOImpl(config, deps[Api.inference])
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
28
llama_stack/providers/inline/vector_io/sqlite_vec/config.py
Normal file
28
llama_stack/providers/inline/vector_io/sqlite_vec/config.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# config.py
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.kvstore.config import (
|
||||||
|
KVStoreConfig,
|
||||||
|
SqliteKVStoreConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteVecImplConfig(BaseModel):
|
||||||
|
db_path: str
|
||||||
|
kvstore: KVStoreConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||||
|
__distro_dir__=__distro_dir__,
|
||||||
|
db_name="sqlite_vec.db",
|
||||||
|
)
|
||||||
|
}
|
205
llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
Normal file
205
llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
Normal file
|
@ -0,0 +1,205 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# sqlite_vec_impl.py
|
||||||
|
import sqlite3
|
||||||
|
import sqlite_vec
|
||||||
|
import struct
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
|
||||||
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||||
|
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_vector(vector: List[float]) -> bytes:
|
||||||
|
"""Serialize a list of floats into a compact binary representation."""
|
||||||
|
return struct.pack(f"{len(vector)}f", *vector)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteVecIndex(EmbeddingIndex):
|
||||||
|
"""
|
||||||
|
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
|
||||||
|
Two tables are used:
|
||||||
|
- A metadata table (chunks_{bank_id}) that holds the chunk JSON.
|
||||||
|
- A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dimension: int, connection: sqlite3.Connection, bank_id: str):
|
||||||
|
self.dimension = dimension
|
||||||
|
self.connection = connection
|
||||||
|
self.bank_id = bank_id
|
||||||
|
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
|
||||||
|
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, dimension: int, connection: sqlite3.Connection, bank_id: str):
|
||||||
|
instance = cls(dimension, connection, bank_id)
|
||||||
|
await instance.initialize()
|
||||||
|
return instance
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
cur = self.connection.cursor()
|
||||||
|
print(f"Creating tables {self.metadata_table} and {self.vector_table}")
|
||||||
|
# Create the table to store chunk metadata.
|
||||||
|
cur.execute(f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
chunk TEXT
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
# Create the virtual table for embeddings.
|
||||||
|
cur.execute(f"""
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
|
||||||
|
USING vec0(embedding FLOAT[{self.dimension}]);
|
||||||
|
""")
|
||||||
|
self.connection.commit()
|
||||||
|
|
||||||
|
async def delete(self):
|
||||||
|
cur = self.connection.cursor()
|
||||||
|
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
|
||||||
|
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
|
||||||
|
self.connection.commit()
|
||||||
|
|
||||||
|
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||||
|
"""
|
||||||
|
Add new chunks along with their embeddings.
|
||||||
|
For each chunk, we insert its JSON into the metadata table and then insert its
|
||||||
|
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
|
||||||
|
"""
|
||||||
|
cur = self.connection.cursor()
|
||||||
|
for chunk, emb in zip(chunks, embeddings):
|
||||||
|
# Serialize and insert the chunk metadata.
|
||||||
|
chunk_json = chunk.model_dump_json()
|
||||||
|
cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,))
|
||||||
|
row_id = cur.lastrowid
|
||||||
|
# Ensure the embedding is a list of floats.
|
||||||
|
emb_list = emb.tolist() if isinstance(emb, np.ndarray) else list(emb)
|
||||||
|
emb_blob = serialize_vector(emb_list)
|
||||||
|
cur.execute(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", (row_id, emb_blob))
|
||||||
|
self.connection.commit()
|
||||||
|
|
||||||
|
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
|
"""
|
||||||
|
Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query
|
||||||
|
against the virtual table. The SQL joins the metadata table to recover the chunk JSON.
|
||||||
|
"""
|
||||||
|
emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
|
||||||
|
emb_blob = serialize_vector(emb_list)
|
||||||
|
cur = self.connection.cursor()
|
||||||
|
query_sql = f"""
|
||||||
|
SELECT m.id, m.chunk, v.distance
|
||||||
|
FROM {self.vector_table} AS v
|
||||||
|
JOIN {self.metadata_table} AS m ON m.id = v.rowid
|
||||||
|
WHERE v.embedding MATCH ? AND k = ?
|
||||||
|
ORDER BY v.distance;
|
||||||
|
"""
|
||||||
|
cur.execute(query_sql, (emb_blob, k))
|
||||||
|
rows = cur.fetchall()
|
||||||
|
chunks = []
|
||||||
|
scores = []
|
||||||
|
for _id, chunk_json, distance in rows:
|
||||||
|
try:
|
||||||
|
chunk = Chunk.model_validate_json(chunk_json)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||||
|
continue
|
||||||
|
chunks.append(chunk)
|
||||||
|
# Mimic the Faiss scoring: score = 1/distance (avoid division by zero)
|
||||||
|
score = 1.0 / distance if distance != 0 else float("inf")
|
||||||
|
scores.append(score)
|
||||||
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteVecVectorIOImpl(VectorIO, VectorDBsProtocolPrivate):
|
||||||
|
"""
|
||||||
|
A VectorIO implementation using SQLite + sqlite_vec.
|
||||||
|
This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
|
||||||
|
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config, inference_api: Api.inference) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.inference_api = inference_api
|
||||||
|
self.cache: Dict[str, VectorDBWithIndex] = {}
|
||||||
|
self.connection: Optional[sqlite3.Connection] = None
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
# Open a connection to the SQLite database (the file is specified in the config).
|
||||||
|
print(f"Connecting to SQLite database at {self.config.db_path}")
|
||||||
|
self.connection = sqlite3.connect(self.config.db_path)
|
||||||
|
self.connection.enable_load_extension(True)
|
||||||
|
sqlite_vec.load(self.connection)
|
||||||
|
self.connection.enable_load_extension(False)
|
||||||
|
cur = self.connection.cursor()
|
||||||
|
# Create a table to persist vector DB registrations.
|
||||||
|
cur.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS vector_dbs (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
metadata TEXT
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
self.connection.commit()
|
||||||
|
# Load any existing vector DB registrations.
|
||||||
|
cur.execute("SELECT metadata FROM vector_dbs")
|
||||||
|
rows = cur.fetchall()
|
||||||
|
for row in rows:
|
||||||
|
vector_db_data = row[0]
|
||||||
|
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||||
|
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
|
||||||
|
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
if self.connection:
|
||||||
|
self.connection.close()
|
||||||
|
self.connection = None
|
||||||
|
|
||||||
|
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||||
|
if self.connection is None:
|
||||||
|
raise RuntimeError("SQLite connection not initialized")
|
||||||
|
cur = self.connection.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
|
||||||
|
(vector_db.identifier, vector_db.model_dump_json()),
|
||||||
|
)
|
||||||
|
self.connection.commit()
|
||||||
|
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
|
||||||
|
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||||
|
|
||||||
|
async def list_vector_dbs(self) -> List[VectorDB]:
|
||||||
|
return [v.vector_db for v in self.cache.values()]
|
||||||
|
|
||||||
|
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||||
|
if self.connection is None:
|
||||||
|
raise RuntimeError("SQLite connection not initialized")
|
||||||
|
if vector_db_id not in self.cache:
|
||||||
|
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||||
|
return
|
||||||
|
await self.cache[vector_db_id].index.delete()
|
||||||
|
del self.cache[vector_db_id]
|
||||||
|
cur = self.connection.cursor()
|
||||||
|
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
|
||||||
|
self.connection.commit()
|
||||||
|
|
||||||
|
async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None:
|
||||||
|
if vector_db_id not in self.cache:
|
||||||
|
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
|
||||||
|
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
|
||||||
|
# and then call our index’s add_chunks.
|
||||||
|
await self.cache[vector_db_id].insert_chunks(chunks)
|
||||||
|
|
||||||
|
async def query_chunks(
|
||||||
|
self, vector_db_id: str, query: Any, params: Optional[Dict[str, Any]] = None
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
if vector_db_id not in self.cache:
|
||||||
|
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||||
|
return await self.cache[vector_db_id].query_chunks(query, params)
|
|
@ -54,6 +54,14 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig",
|
config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig",
|
||||||
api_dependencies=[Api.inference],
|
api_dependencies=[Api.inference],
|
||||||
),
|
),
|
||||||
|
InlineProviderSpec(
|
||||||
|
api=Api.vector_io,
|
||||||
|
provider_type="inline::sqlite_vec",
|
||||||
|
pip_packages=EMBEDDING_DEPS + ["sqlite_vec"],
|
||||||
|
module="llama_stack.providers.inline.vector_io.sqlite_vec",
|
||||||
|
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVecImplConfig",
|
||||||
|
api_dependencies=[Api.inference],
|
||||||
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
Api.vector_io,
|
Api.vector_io,
|
||||||
AdapterSpec(
|
AdapterSpec(
|
||||||
|
|
|
@ -352,20 +352,24 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
return EmbeddingsResponse(embeddings=embeddings)
|
return EmbeddingsResponse(embeddings=embeddings)
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
async def check_model_availability(model_id: str):
|
# ollama does not have embedding models running. Check if the model is in list of available models.
|
||||||
response = await self.client.ps()
|
|
||||||
available_models = [m["model"] for m in response["models"]]
|
|
||||||
if model_id not in available_models:
|
|
||||||
raise ValueError(
|
|
||||||
f"Model '{model_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if model.model_type == ModelType.embedding:
|
if model.model_type == ModelType.embedding:
|
||||||
await check_model_availability(model.provider_resource_id)
|
response = await self.client.list()
|
||||||
|
available_models = [m["model"] for m in response["models"]]
|
||||||
|
if model.provider_resource_id not in available_models:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{model.provider_resource_id}' is not available in Ollama. "
|
||||||
|
f"Available models: {', '.join(available_models)}"
|
||||||
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
model = await self.register_helper.register_model(model)
|
model = await self.register_helper.register_model(model)
|
||||||
await check_model_availability(model.provider_resource_id)
|
models = await self.client.ps()
|
||||||
|
available_models = [m["model"] for m in models["models"]]
|
||||||
|
if model.provider_resource_id not in available_models:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{model.provider_resource_id}' is not available in Ollama. "
|
||||||
|
f"Available models: [{', '.join(available_models)}]"
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,14 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
id="ollama",
|
id="ollama",
|
||||||
marks=pytest.mark.ollama,
|
marks=pytest.mark.ollama,
|
||||||
),
|
),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
"inference": "ollama",
|
||||||
|
"vector_io": "sqlite_vec",
|
||||||
|
},
|
||||||
|
id="sqlite_vec",
|
||||||
|
marks=pytest.mark.ollama,
|
||||||
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
{
|
{
|
||||||
"inference": "sentence_transformers",
|
"inference": "sentence_transformers",
|
||||||
|
|
|
@ -15,6 +15,7 @@ from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
|
||||||
from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
|
from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
|
||||||
from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig
|
from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig
|
||||||
|
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVecImplConfig
|
||||||
from llama_stack.providers.remote.vector_io.chroma import ChromaRemoteImplConfig
|
from llama_stack.providers.remote.vector_io.chroma import ChromaRemoteImplConfig
|
||||||
from llama_stack.providers.remote.vector_io.pgvector import PGVectorConfig
|
from llama_stack.providers.remote.vector_io.pgvector import PGVectorConfig
|
||||||
from llama_stack.providers.remote.vector_io.weaviate import WeaviateConfig
|
from llama_stack.providers.remote.vector_io.weaviate import WeaviateConfig
|
||||||
|
@ -53,6 +54,22 @@ def vector_io_faiss() -> ProviderFixture:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def vector_io_sqlite_vec() -> ProviderFixture:
|
||||||
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||||
|
return ProviderFixture(
|
||||||
|
providers=[
|
||||||
|
Provider(
|
||||||
|
provider_id="sqlite_vec",
|
||||||
|
provider_type="inline::sqlite_vec",
|
||||||
|
config=SQLiteVecImplConfig(
|
||||||
|
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
|
||||||
|
).model_dump(),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def vector_io_pgvector() -> ProviderFixture:
|
def vector_io_pgvector() -> ProviderFixture:
|
||||||
return ProviderFixture(
|
return ProviderFixture(
|
||||||
|
@ -111,7 +128,13 @@ def vector_io_chroma() -> ProviderFixture:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma"]
|
VECTOR_IO_FIXTURES = [
|
||||||
|
"faiss",
|
||||||
|
"pgvector",
|
||||||
|
"weaviate",
|
||||||
|
"chroma",
|
||||||
|
"sqlite_vec",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||||
class KVStoreType(Enum):
|
class KVStoreType(Enum):
|
||||||
redis = "redis"
|
redis = "redis"
|
||||||
sqlite = "sqlite"
|
sqlite = "sqlite"
|
||||||
|
milvus_lite = "milvus_lite"
|
||||||
postgres = "postgres"
|
postgres = "postgres"
|
||||||
|
|
||||||
|
|
||||||
|
@ -62,6 +63,22 @@ class SqliteKVStoreConfig(CommonConfig):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusLiteKVStoreConfig(CommonConfig):
|
||||||
|
type: Literal[KVStoreType.milvus_lite.value] = KVStoreType.milvus_lite.value
|
||||||
|
db_path: str = Field(
|
||||||
|
default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(),
|
||||||
|
description="File path for the sqlite database",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"):
|
||||||
|
return {
|
||||||
|
"type": "milvuslite",
|
||||||
|
"namespace": None,
|
||||||
|
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + db_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class PostgresKVStoreConfig(CommonConfig):
|
class PostgresKVStoreConfig(CommonConfig):
|
||||||
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
|
|
|
@ -6,6 +6,7 @@ distribution_spec:
|
||||||
- remote::ollama
|
- remote::ollama
|
||||||
vector_io:
|
vector_io:
|
||||||
- inline::faiss
|
- inline::faiss
|
||||||
|
- inline::sqlite_vec
|
||||||
- remote::chromadb
|
- remote::chromadb
|
||||||
- remote::pgvector
|
- remote::pgvector
|
||||||
safety:
|
safety:
|
||||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||||
SentenceTransformersInferenceConfig,
|
SentenceTransformersInferenceConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig
|
from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig
|
||||||
|
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVecImplConfig
|
||||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
|
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
|
||||||
|
|
||||||
|
@ -49,11 +50,16 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
provider_type="inline::sentence-transformers",
|
provider_type="inline::sentence-transformers",
|
||||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||||
)
|
)
|
||||||
vector_io_provider = Provider(
|
vector_io_provider_faiss = Provider(
|
||||||
provider_id="faiss",
|
provider_id="faiss",
|
||||||
provider_type="inline::faiss",
|
provider_type="inline::faiss",
|
||||||
config=FaissImplConfig.sample_run_config(f"distributions/{name}"),
|
config=FaissImplConfig.sample_run_config(f"distributions/{name}"),
|
||||||
)
|
)
|
||||||
|
vector_io_provider_sqlite = Provider(
|
||||||
|
provider_id="sqlite_vec",
|
||||||
|
provider_type="inline::sqlite_vec",
|
||||||
|
config=SQLiteVecImplConfig.sample_run_config(f"distributions/{name}"),
|
||||||
|
)
|
||||||
|
|
||||||
inference_model = ModelInput(
|
inference_model = ModelInput(
|
||||||
model_id="${env.INFERENCE_MODEL}",
|
model_id="${env.INFERENCE_MODEL}",
|
||||||
|
@ -98,7 +104,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
"run.yaml": RunConfigSettings(
|
"run.yaml": RunConfigSettings(
|
||||||
provider_overrides={
|
provider_overrides={
|
||||||
"inference": [inference_provider, embedding_provider],
|
"inference": [inference_provider, embedding_provider],
|
||||||
"vector_io": [vector_io_provider],
|
"vector_io": [vector_io_provider_faiss, vector_io_provider_sqlite],
|
||||||
},
|
},
|
||||||
default_models=[inference_model, embedding_model],
|
default_models=[inference_model, embedding_model],
|
||||||
default_tool_groups=default_tool_groups,
|
default_tool_groups=default_tool_groups,
|
||||||
|
@ -109,7 +115,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
inference_provider,
|
inference_provider,
|
||||||
embedding_provider,
|
embedding_provider,
|
||||||
],
|
],
|
||||||
"vector_io": [vector_io_provider],
|
"vector_io": [vector_io_provider_faiss, vector_io_provider_faiss],
|
||||||
"safety": [
|
"safety": [
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="llama-guard",
|
provider_id="llama-guard",
|
||||||
|
|
|
@ -20,13 +20,14 @@ providers:
|
||||||
provider_type: inline::sentence-transformers
|
provider_type: inline::sentence-transformers
|
||||||
config: {}
|
config: {}
|
||||||
vector_io:
|
vector_io:
|
||||||
- provider_id: faiss
|
- provider_id: sqlite_vec
|
||||||
provider_type: inline::faiss
|
provider_type: inline::sqlite_vec
|
||||||
config:
|
config:
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
namespace: null
|
namespace: null
|
||||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
|
||||||
safety:
|
safety:
|
||||||
- provider_id: llama-guard
|
- provider_id: llama-guard
|
||||||
provider_type: inline::llama-guard
|
provider_type: inline::llama-guard
|
||||||
|
|
|
@ -80,7 +80,7 @@ def llama_stack_client(provider_data):
|
||||||
skip_logger_removal=True,
|
skip_logger_removal=True,
|
||||||
)
|
)
|
||||||
if not client.initialize():
|
if not client.initialize():
|
||||||
raise RuntimeError("Initialization failed")
|
raise RuntimeError(f"Initialization failed {os.environ.get('LLAMA_STACK_CONFIG')} not found")
|
||||||
|
|
||||||
elif os.environ.get("LLAMA_STACK_BASE_URL"):
|
elif os.environ.get("LLAMA_STACK_BASE_URL"):
|
||||||
client = LlamaStackClient(
|
client = LlamaStackClient(
|
||||||
|
|
|
@ -8,6 +8,9 @@ import random
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
INLINE_VECTOR_DB_PROVIDERS = ["faiss"]
|
||||||
|
# "sqlite_vec"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def empty_vector_db_registry(llama_stack_client):
|
def empty_vector_db_registry(llama_stack_client):
|
||||||
|
@ -17,26 +20,27 @@ def empty_vector_db_registry(llama_stack_client):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry):
|
def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry, provider_id):
|
||||||
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
||||||
llama_stack_client.vector_dbs.register(
|
llama_stack_client.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_id,
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
embedding_dimension=384,
|
embedding_dimension=384,
|
||||||
provider_id="faiss",
|
provider_id=provider_id,
|
||||||
)
|
)
|
||||||
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
||||||
return vector_dbs
|
return vector_dbs
|
||||||
|
|
||||||
|
|
||||||
def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db_registry):
|
@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
|
||||||
|
def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id):
|
||||||
# Register a memory bank first
|
# Register a memory bank first
|
||||||
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
||||||
llama_stack_client.vector_dbs.register(
|
llama_stack_client.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_id,
|
||||||
embedding_model=embedding_model,
|
embedding_model=embedding_model,
|
||||||
embedding_dimension=384,
|
embedding_dimension=384,
|
||||||
provider_id="faiss",
|
provider_id=provider_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Retrieve the memory bank and validate its properties
|
# Retrieve the memory bank and validate its properties
|
||||||
|
@ -44,7 +48,7 @@ def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert response.identifier == vector_db_id
|
assert response.identifier == vector_db_id
|
||||||
assert response.embedding_model == embedding_model
|
assert response.embedding_model == embedding_model
|
||||||
assert response.provider_id == "faiss"
|
assert response.provider_id == provider_id
|
||||||
assert response.provider_resource_id == vector_db_id
|
assert response.provider_resource_id == vector_db_id
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,20 +57,22 @@ def test_vector_db_list(llama_stack_client, empty_vector_db_registry):
|
||||||
assert len(vector_dbs_after_register) == 0
|
assert len(vector_dbs_after_register) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_vector_db_register(llama_stack_client, embedding_model, empty_vector_db_registry):
|
@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
|
||||||
|
def test_vector_db_register(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id):
|
||||||
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
||||||
llama_stack_client.vector_dbs.register(
|
llama_stack_client.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_id,
|
||||||
embedding_model=embedding_model,
|
embedding_model=embedding_model,
|
||||||
embedding_dimension=384,
|
embedding_dimension=384,
|
||||||
provider_id="faiss",
|
provider_id=provider_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_dbs_after_register = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
vector_dbs_after_register = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
||||||
assert vector_dbs_after_register == [vector_db_id]
|
assert vector_dbs_after_register == [vector_db_id]
|
||||||
|
|
||||||
|
|
||||||
def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry):
|
@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS)
|
||||||
|
def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry, provider_id):
|
||||||
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
||||||
assert len(vector_dbs) == 1
|
assert len(vector_dbs) == 1
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue