chore: Updating how default embedding model is set in stack (#3818)

# What does this PR do?

Refactor setting default vector store provider and embedding model to
use an optional `vector_stores` config in the `StackRunConfig` and clean
up code to do so (had to add back in some pieces of VectorDB). Also
added remote Qdrant and Weaviate to starter distro (based on other PR
where inference providers were added for UX).

New config is simply (default for Starter distro):

```yaml
vector_stores:
  default_provider_id: faiss
  default_embedding_model:
    provider_id: sentence-transformers
    model_id: nomic-ai/nomic-embed-text-v1.5
```

## Test Plan
CI and Unit tests.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
Francisco Arceo 2025-10-20 17:22:45 -04:00 committed by GitHub
parent 2c43285e22
commit 48581bf651
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
48 changed files with 973 additions and 818 deletions

View file

@ -29,6 +29,7 @@ async def get_routing_table_impl(
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from ..routing_tables.shields import ShieldsRoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
api_to_tables = {
"models": ModelsRoutingTable,
@ -37,6 +38,7 @@ async def get_routing_table_impl(
"scoring_functions": ScoringFunctionsRoutingTable,
"benchmarks": BenchmarksRoutingTable,
"tool_groups": ToolGroupsRoutingTable,
"vector_dbs": VectorDBsRoutingTable,
}
if api.value not in api_to_tables:
@ -91,6 +93,9 @@ async def get_auto_router_impl(
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store
elif api == Api.vector_io:
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize()
return impl

View file

@ -31,6 +31,7 @@ from llama_stack.apis.vector_io import (
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
@ -43,9 +44,11 @@ class VectorIORouter(VectorIO):
def __init__(
self,
routing_table: RoutingTable,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
self.vector_stores_config = vector_stores_config
async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")
@ -122,6 +125,17 @@ class VectorIORouter(VectorIO):
embedding_dimension = extra.get("embedding_dimension")
provider_id = extra.get("provider_id")
# Use default embedding model if not specified
if (
embedding_model is None
and self.vector_stores_config
and self.vector_stores_config.default_embedding_model is not None
):
# Construct the full model ID with provider prefix
embedding_provider_id = self.vector_stores_config.default_embedding_model.provider_id
model_id = self.vector_stores_config.default_embedding_model.model_id
embedding_model = f"{embedding_provider_id}/{model_id}"
if embedding_model is not None and embedding_dimension is None:
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
@ -132,11 +146,24 @@ class VectorIORouter(VectorIO):
raise ValueError("No vector_io providers available")
if num_providers > 1:
available_providers = list(self.routing_table.impls_by_provider_id.keys())
raise ValueError(
f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
f"Available providers: {available_providers}"
)
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
# Use default configured provider
if self.vector_stores_config and self.vector_stores_config.default_provider_id:
default_provider = self.vector_stores_config.default_provider_id
if default_provider in available_providers:
provider_id = default_provider
logger.debug(f"Using configured default vector store provider: {provider_id}")
else:
raise ValueError(
f"Configured default vector store provider '{default_provider}' not found. "
f"Available providers: {available_providers}"
)
else:
raise ValueError(
f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
f"Available providers: {available_providers}"
)
else:
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
vector_db_id = f"vs_{uuid.uuid4()}"
registered_vector_db = await self.routing_table.register_vector_db(
@ -243,8 +270,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
) -> VectorStoreDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store(vector_store_id)
return await self.routing_table.openai_delete_vector_store(vector_store_id)
async def openai_search_vector_store(
self,