chore: Updating how default embedding model is set in stack (#3818)

# What does this PR do?

Refactor setting default vector store provider and embedding model to
use an optional `vector_stores` config in the `StackRunConfig` and clean
up code to do so (had to add back in some pieces of VectorDB). Also
added remote Qdrant and Weaviate to starter distro (based on other PR
where inference providers were added for UX).

New config is simply (default for Starter distro):

```yaml
vector_stores:
  default_provider_id: faiss
  default_embedding_model:
    provider_id: sentence-transformers
    model_id: nomic-ai/nomic-embed-text-v1.5
```

## Test Plan
CI and Unit tests.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
Francisco Arceo 2025-10-20 17:22:45 -04:00 committed by GitHub
parent 2c43285e22
commit 48581bf651
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
48 changed files with 973 additions and 818 deletions

View file

@ -12,11 +12,6 @@ from .config import QdrantVectorIOConfig
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorIOAdapter
impl = QdrantVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -16,7 +16,6 @@ from qdrant_client.models import PointStruct
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
@ -30,11 +29,7 @@ from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
@ -99,8 +94,7 @@ class QdrantIndex(EmbeddingIndex):
chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
try:
await self.client.delete(
collection_name=self.collection_name,
points_selector=models.PointIdsList(points=chunk_ids),
collection_name=self.collection_name, points_selector=models.PointIdsList(points=chunk_ids)
)
except Exception as e:
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
@ -133,12 +127,7 @@ class QdrantIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Qdrant")
async def query_hybrid(
@ -161,7 +150,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self,
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
@ -169,7 +157,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.client: AsyncQdrantClient = None
self.cache = {}
self.inference_api = inference_api
self.models_api = models_api
self.vector_db_store = None
self._qdrant_lock = asyncio.Lock()
@ -184,11 +171,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
QdrantIndex(self.client, vector_db.identifier),
self.inference_api,
)
index = VectorDBWithIndex(vector_db, QdrantIndex(self.client, vector_db.identifier), self.inference_api)
self.cache[vector_db.identifier] = index
self.openai_vector_stores = await self._load_openai_vector_stores()
@ -197,18 +180,13 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_db(self, vector_db: VectorDB) -> None:
assert self.kvstore is not None
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
index = VectorDBWithIndex(
vector_db=vector_db,
index=QdrantIndex(self.client, vector_db.identifier),
inference_api=self.inference_api,
vector_db=vector_db, index=QdrantIndex(self.client, vector_db.identifier), inference_api=self.inference_api
)
self.cache[vector_db.identifier] = index
@ -240,12 +218,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache[vector_db_id] = index
return index
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
@ -253,10 +226,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index: