chore: Updating how default embedding model is set in stack (#3818)

# What does this PR do?

Refactor setting default vector store provider and embedding model to
use an optional `vector_stores` config in the `StackRunConfig` and clean
up code to do so (had to add back in some pieces of VectorDB). Also
added remote Qdrant and Weaviate to starter distro (based on other PR
where inference providers were added for UX).

New config is simply (default for Starter distro):

```yaml
vector_stores:
  default_provider_id: faiss
  default_embedding_model:
    provider_id: sentence-transformers
    model_id: nomic-ai/nomic-embed-text-v1.5
```

## Test Plan
CI and Unit tests.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
Francisco Arceo 2025-10-20 17:22:45 -04:00 committed by GitHub
parent 2c43285e22
commit 48581bf651
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
48 changed files with 973 additions and 818 deletions

View file

@ -35,7 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@ -108,30 +108,6 @@ REGISTRY_REFRESH_TASK = None
TEST_RECORDING_CONTEXT = None
async def validate_default_embedding_model(impls: dict[Api, Any]):
"""Validate that at most one embedding model is marked as default."""
if Api.models not in impls:
return
models_impl = impls[Api.models]
response = await models_impl.list_models()
models_list = response.data if hasattr(response, "data") else response
default_embedding_models = []
for model in models_list:
if model.model_type == "embedding" and model.metadata.get("default_configured") is True:
default_embedding_models.append(model.identifier)
if len(default_embedding_models) > 1:
raise ValueError(
f"Multiple embedding models marked as default_configured=True: {default_embedding_models}. "
"Only one embedding model can be marked as default."
)
if default_embedding_models:
logger.info(f"Default embedding model configured: {default_embedding_models[0]}")
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
@ -162,7 +138,41 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
)
await validate_default_embedding_model(impls)
async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig | None, impls: dict[Api, Any]):
"""Validate vector stores configuration."""
if vector_stores_config is None:
return
default_embedding_model = vector_stores_config.default_embedding_model
if default_embedding_model is None:
return
provider_id = default_embedding_model.provider_id
model_id = default_embedding_model.model_id
default_model_id = f"{provider_id}/{model_id}"
if Api.models not in impls:
raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
models_impl = impls[Api.models]
response = await models_impl.list_models()
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
default_model = models_list.get(default_model_id)
if default_model is None:
raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
embedding_dimension = default_model.metadata.get("embedding_dimension")
if embedding_dimension is None:
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
try:
int(embedding_dimension)
except ValueError as err:
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
class EnvVarError(Exception):
@ -400,8 +410,8 @@ class Stack:
await impls[Api.conversations].initialize()
await register_resources(self.run_config, impls)
await refresh_registry_once(impls)
await validate_vector_stores_config(self.run_config.vector_stores, impls)
self.impls = impls
def create_registry_refresh_task(self):