diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py
index 35592e76f..6d06adb84 100644
--- a/llama_stack/core/datatypes.py
+++ b/llama_stack/core/datatypes.py
@@ -354,17 +354,11 @@ class AuthenticationRequiredError(Exception):
     pass
 
 
-class DefaultEmbeddingModel(BaseModel):
-    """Configuration for default embedding model."""
+class QualifiedModel(BaseModel):
+    """A qualified model identifier, consisting of a provider ID and a model ID."""
 
-    provider_id: str = Field(
-        ...,
-        description="ID of the inference provider that serves the embedding model (e.g., 'sentence-transformers').",
-    )
-    model_id: str = Field(
-        ...,
-        description="ID of the embedding model (e.g., 'nomic-ai/nomic-embed-text-v1.5').",
-    )
+    provider_id: str
+    model_id: str
 
 
 class VectorStoresConfig(BaseModel):
@@ -374,7 +368,7 @@ class VectorStoresConfig(BaseModel):
         default=None,
         description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
     )
-    default_embedding_model: DefaultEmbeddingModel | None = Field(
+    default_embedding_model: QualifiedModel | None = Field(
         default=None,
         description="Default embedding model configuration for vector stores.",
     )
diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py
index a55f26998..a2f7babd2 100644
--- a/llama_stack/core/stack.py
+++ b/llama_stack/core/stack.py
@@ -35,7 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
-from llama_stack.core.datatypes import Provider, StackRunConfig
+from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig
 from llama_stack.core.distribution import get_provider_registry
 from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
 from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@@ -139,58 +139,40 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
     )
 
 
-async def validate_vector_stores_config(run_config: StackRunConfig, impls: dict[Api, Any]):
+async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig | None, impls: dict[Api, Any]):
     """Validate vector stores configuration."""
-    if not run_config.vector_stores:
+    if vector_stores_config is None:
         return
 
-    vector_stores_config = run_config.vector_stores
+    default_embedding_model = vector_stores_config.default_embedding_model
+    if default_embedding_model is None:
+        return
 
-    # Validate default embedding model if configured
-    if vector_stores_config.default_embedding_model:
-        default_embedding_model = vector_stores_config.default_embedding_model
-        provider_id = default_embedding_model.provider_id
-        model_id = default_embedding_model.model_id
-        # Construct the full model identifier
-        default_model_id = f"{provider_id}/{model_id}"
+    provider_id = default_embedding_model.provider_id
+    model_id = default_embedding_model.model_id
+    default_model_id = f"{provider_id}/{model_id}"
 
-        if Api.models not in impls:
-            raise ValueError(
-                f"Models API is not available but vector_stores config requires model '{default_model_id}'"
-            )
+    if Api.models not in impls:
+        raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
 
-        models_impl = impls[Api.models]
-        response = await models_impl.list_models()
-        models_list = response.data if hasattr(response, "data") else response
+    models_impl = impls[Api.models]
+    response = await models_impl.list_models()
+    models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
 
-        # find default embedding model
-        default_model = None
-        for model in models_list:
-            if model.identifier == default_model_id:
-                default_model = model
-                break
+    default_model = models_list.get(default_model_id)
+    if default_model is None:
+        raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
 
-        if not default_model:
-            available_models = [m.identifier for m in models_list if m.model_type == "embedding"]
-            raise ValueError(
-                f"Embedding model '{default_model_id}' not found. Available embedding models: {available_models}"
-            )
+    embedding_dimension = default_model.metadata.get("embedding_dimension")
+    if embedding_dimension is None:
+        raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
 
-        if default_model.model_type != "embedding":
-            raise ValueError(f"Model '{default_model_id}' is type '{default_model.model_type}', not 'embedding'")
+    try:
+        int(embedding_dimension)
+    except ValueError as err:
+        raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
 
-        embedding_dimension = default_model.metadata.get("embedding_dimension")
-        if embedding_dimension is None:
-            raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
-
-        try:
-            int(embedding_dimension)
-        except ValueError as err:
-            raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
-
-        logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
-
-    # If no default embedding model is configured, that's fine - validation passes
+    logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
 
 
 class EnvVarError(Exception):
@@ -429,7 +411,7 @@ class Stack:
         await register_resources(self.run_config, impls)
 
         await refresh_registry_once(impls)
-        await validate_vector_stores_config(self.run_config, impls)
+        await validate_vector_stores_config(self.run_config.vector_stores, impls)
         self.impls = impls
 
     def create_registry_refresh_task(self):
diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py
index 6a3f5e3ac..c8c7101a6 100644
--- a/llama_stack/distributions/starter/starter.py
+++ b/llama_stack/distributions/starter/starter.py
@@ -9,9 +9,9 @@ from typing import Any
 
 from llama_stack.core.datatypes import (
     BuildProvider,
-    DefaultEmbeddingModel,
     Provider,
     ProviderSpec,
+    QualifiedModel,
     ShieldInput,
     ToolGroupInput,
     VectorStoresConfig,
@@ -251,7 +251,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
         default_shields=default_shields,
         vector_stores_config=VectorStoresConfig(
             default_provider_id="faiss",
-            default_embedding_model=DefaultEmbeddingModel(
+            default_embedding_model=QualifiedModel(
                 provider_id="sentence-transformers",
                 model_id="nomic-ai/nomic-embed-text-v1.5",
             ),
diff --git a/tests/unit/core/test_stack_validation.py b/tests/unit/core/test_stack_validation.py
index 148d48a12..b50aff559 100644
--- a/tests/unit/core/test_stack_validation.py
+++ b/tests/unit/core/test_stack_validation.py
@@ -11,7 +11,7 @@ from unittest.mock import AsyncMock
 import pytest
 
 from llama_stack.apis.models import Model, ModelType
-from llama_stack.core.datatypes import DefaultEmbeddingModel, StackRunConfig, StorageConfig, VectorStoresConfig
+from llama_stack.core.datatypes import QualifiedModel, StackRunConfig, StorageConfig, VectorStoresConfig
 from llama_stack.core.stack import validate_vector_stores_config
 from llama_stack.providers.datatypes import Api
 
@@ -25,7 +25,7 @@ class TestVectorStoresValidation:
             storage=StorageConfig(backends={}, stores={}),
             vector_stores=VectorStoresConfig(
                 default_provider_id="faiss",
-                default_embedding_model=DefaultEmbeddingModel(
+                default_embedding_model=QualifiedModel(
                     provider_id="p",
                     model_id="missing",
                 ),
@@ -45,7 +45,7 @@ class TestVectorStoresValidation:
             storage=StorageConfig(backends={}, stores={}),
             vector_stores=VectorStoresConfig(
                 default_provider_id="faiss",
-                default_embedding_model=DefaultEmbeddingModel(
+                default_embedding_model=QualifiedModel(
                     provider_id="p",
                     model_id="valid",
                 ),
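
For reference, a minimal sketch of exercising the refactored validator in isolation, in the same spirit as the unit tests above (the SimpleNamespace stand-ins for the list_models() response and the 768 embedding dimension are illustrative assumptions, not part of this patch):

    # Standalone sketch; mirrors the AsyncMock pattern used in the unit tests.
    import asyncio
    from types import SimpleNamespace
    from unittest.mock import AsyncMock

    from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
    from llama_stack.core.stack import validate_vector_stores_config
    from llama_stack.providers.datatypes import Api

    # Fake entry shaped like the objects returned by models.list_models().
    embedding_model = SimpleNamespace(
        identifier="sentence-transformers/nomic-ai/nomic-embed-text-v1.5",
        model_type="embedding",
        metadata={"embedding_dimension": 768},
    )
    models_impl = AsyncMock()
    models_impl.list_models.return_value = SimpleNamespace(data=[embedding_model])

    config = VectorStoresConfig(
        default_provider_id="faiss",
        default_embedding_model=QualifiedModel(
            provider_id="sentence-transformers",
            model_id="nomic-ai/nomic-embed-text-v1.5",
        ),
    )

    # Succeeds: "{provider_id}/{model_id}" resolves to a registered embedding
    # model with a usable embedding_dimension. Passing None returns immediately.
    asyncio.run(validate_vector_stores_config(config, {Api.models: models_impl}))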