mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-23 00:27:26 +00:00
chore: Updating how default embedding model is set in stack (#3818)
# What does this PR do? Refactor setting default vector store provider and embedding model to use an optional `vector_stores` config in the `StackRunConfig` and clean up code to do so (had to add back in some pieces of VectorDB). Also added remote Qdrant and Weaviate to starter distro (based on other PR where inference providers were added for UX). New config is simply (default for Starter distro): ```yaml vector_stores: default_provider_id: faiss default_embedding_model: provider_id: sentence-transformers model_id: nomic-ai/nomic-embed-text-v1.5 ``` ## Test Plan CI and Unit tests. --------- Signed-off-by: Francisco Javier Arceo <farceo@redhat.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
2c43285e22
commit
48581bf651
48 changed files with 973 additions and 818 deletions
|
@ -59,7 +59,6 @@ class SentenceTransformersInferenceImpl(
|
|||
provider_id=self.__provider_id__,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"default_configured": True,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
|
|
|
@ -12,15 +12,8 @@ from .config import ChromaVectorIOConfig
|
|||
|
||||
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import (
|
||||
ChromaVectorIOAdapter,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaVectorIOAdapter
|
||||
|
||||
impl = ChromaVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -16,11 +16,6 @@ async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
|
|||
|
||||
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = FaissVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -17,27 +17,14 @@ from numpy.typing import NDArray
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
VectorDBsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
|
@ -155,12 +142,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
await self._save_index()
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
||||
chunks = []
|
||||
scores = []
|
||||
|
@ -175,12 +157,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError(
|
||||
"Keyword search is not supported - underlying DB FAISS does not support this search mode"
|
||||
)
|
||||
|
@ -200,17 +177,10 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: FaissVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
@ -252,17 +222,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
assert self.kvstore is not None
|
||||
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=vector_db.model_dump_json(),
|
||||
)
|
||||
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
|
||||
|
||||
# Store in cache
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
|
@ -285,12 +249,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
del self.cache[vector_db_id]
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}")
|
||||
|
@ -298,10 +257,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
|
|
|
@ -14,11 +14,6 @@ from .config import MilvusVectorIOConfig
|
|||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
impl = MilvusVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -15,11 +15,6 @@ async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
|
|||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
|
||||
assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = QdrantVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -15,11 +15,6 @@ async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
|
|||
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
||||
|
||||
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = SQLiteVecVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -17,13 +17,8 @@ from numpy.typing import NDArray
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
@ -175,32 +170,18 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
# Insert vector embeddings
|
||||
embedding_data = [
|
||||
(
|
||||
(
|
||||
chunk.chunk_id,
|
||||
serialize_vector(emb.tolist()),
|
||||
)
|
||||
)
|
||||
((chunk.chunk_id, serialize_vector(emb.tolist())))
|
||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
]
|
||||
cur.executemany(
|
||||
f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
|
||||
embedding_data,
|
||||
)
|
||||
cur.executemany(f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);", embedding_data)
|
||||
|
||||
# Insert FTS content
|
||||
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
|
||||
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
|
||||
cur.executemany(
|
||||
f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
|
||||
[(row[0],) for row in fts_data],
|
||||
)
|
||||
cur.executemany(f"DELETE FROM [{self.fts_table}] WHERE id = ?;", [(row[0],) for row in fts_data])
|
||||
|
||||
# INSERT new entries
|
||||
cur.executemany(
|
||||
f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
|
||||
fts_data,
|
||||
)
|
||||
cur.executemany(f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);", fts_data)
|
||||
|
||||
connection.commit()
|
||||
|
||||
|
@ -216,12 +197,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
# Run batch insertion in a background thread
|
||||
await asyncio.to_thread(_execute_all_batch_inserts)
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs vector-based search using a virtual table for vector similarity.
|
||||
"""
|
||||
|
@ -261,12 +237,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
scores.append(score)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
|
||||
"""
|
||||
|
@ -410,17 +381,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.vector_db_store = None
|
||||
|
||||
|
@ -433,9 +397,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
for db_json in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(db_json)
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
|
@ -450,11 +412,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
return [v.vector_db for v in self.cache.values()]
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
)
|
||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue