diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 2cddc3970..7543256fc 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -236,8 +236,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             metadata = {}
         if model_type is None:
             model_type = ModelType.llm
-        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
-            raise ValueError("Embedding model must have an embedding dimension in its metadata")
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 6a83836e6..f605553ab 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -19,6 +19,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.apis.models import ModelType
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
@@ -44,6 +45,8 @@ class SentenceTransformersInferenceImpl(
         pass
 
     async def register_model(self, model: Model) -> None:
+        if "embedding_dimension" not in model.metadata and model.model_type == ModelType.embedding:
+            raise ValueError("Embedding model must have an embedding dimension in its metadata")
         _ = self._load_sentence_transformer_model(model.provider_resource_id)
         return model
 