diff --git a/llama_stack/core/routers/__init__.py b/llama_stack/core/routers/__init__.py index f851473c1..df4df0463 100644 --- a/llama_stack/core/routers/__init__.py +++ b/llama_stack/core/routers/__init__.py @@ -93,7 +93,7 @@ async def get_auto_router_impl( await inference_store.initialize() api_to_dep_impl["store"] = inference_store - if api == Api.vector_io and run_config.vector_stores: + elif api == Api.vector_io: api_to_dep_impl["vector_stores_config"] = run_config.vector_stores impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index e06d1d45c..bfc5f7164 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -31,6 +31,7 @@ from llama_stack.apis.vector_io import ( VectorStoreObject, VectorStoreSearchResponsePage, ) +from llama_stack.core.datatypes import VectorStoresConfig from llama_stack.log import get_logger from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable @@ -43,7 +44,7 @@ class VectorIORouter(VectorIO): def __init__( self, routing_table: RoutingTable, - vector_stores_config=None, + vector_stores_config: VectorStoresConfig | None = None, ) -> None: logger.debug("Initializing VectorIORouter") self.routing_table = routing_table @@ -125,12 +126,15 @@ class VectorIORouter(VectorIO): provider_id = extra.get("provider_id") # Use default embedding model if not specified - if embedding_model is None and self.vector_stores_config is not None: - if self.vector_stores_config.default_embedding_model is not None: - # Construct the full model ID with provider prefix - embedding_provider_id = self.vector_stores_config.default_embedding_model.provider_id - model_id = self.vector_stores_config.default_embedding_model.model_id - embedding_model = f"{embedding_provider_id}/{model_id}" + if ( + embedding_model is None + and self.vector_stores_config + and self.vector_stores_config.default_embedding_model is not None + ): + # Construct the full model ID with provider prefix + embedding_provider_id = self.vector_stores_config.default_embedding_model.provider_id + model_id = self.vector_stores_config.default_embedding_model.model_id + embedding_model = f"{embedding_provider_id}/{model_id}" if embedding_model is not None and embedding_dimension is None: embedding_dimension = await self._get_embedding_model_dimension(embedding_model)