diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index ba1af944c..713dfa377 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -22,7 +22,7 @@ class CommonModelFields(BaseModel): class ModelType(Enum): - llm_model = "llm_model" + llm = "llm" embedding_model = "embedding_model" @@ -40,14 +40,14 @@ class Model(CommonModelFields, Resource): model_config = ConfigDict(protected_namespaces=()) - model_type: ModelType = Field(default=ModelType.llm_model) + model_type: ModelType = Field(default=ModelType.llm) class ModelInput(CommonModelFields): model_id: str provider_id: Optional[str] = None provider_model_id: Optional[str] = None - model_type: Optional[ModelType] = ModelType.llm_model + model_type: Optional[ModelType] = ModelType.llm model_config = ConfigDict(protected_namespaces=()) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 2e182a6b8..51be318cb 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -168,7 +168,7 @@ class InferenceRouter(Inference): model = await self.routing_table.get_model(model_id) if model is None: raise ValueError(f"Model '{model_id}' not found") - if model.model_type == ModelType.llm_model: + if model.model_type == ModelType.llm: raise ValueError( f"Model '{model_id}' is an LLM model and does not support embeddings" ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 9c24faf8d..bc3de8be0 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -224,7 +224,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): if metadata is None: metadata = {} if model_type is None: - model_type = ModelType.llm_model + model_type = ModelType.llm if ( "embedding_dimension" not in metadata and model_type == ModelType.embedding_model