change to llm type

This commit is contained in:
Dinesh Yeduguru 2024-12-10 16:57:00 -08:00
parent 62890b3171
commit 2b6aa71a21
3 changed files with 5 additions and 5 deletions

View file

@ -22,7 +22,7 @@ class CommonModelFields(BaseModel):
class ModelType(Enum): class ModelType(Enum):
llm_model = "llm_model" llm = "llm"
embedding_model = "embedding_model" embedding_model = "embedding_model"
@ -40,14 +40,14 @@ class Model(CommonModelFields, Resource):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
model_type: ModelType = Field(default=ModelType.llm_model) model_type: ModelType = Field(default=ModelType.llm)
class ModelInput(CommonModelFields): class ModelInput(CommonModelFields):
model_id: str model_id: str
provider_id: Optional[str] = None provider_id: Optional[str] = None
provider_model_id: Optional[str] = None provider_model_id: Optional[str] = None
model_type: Optional[ModelType] = ModelType.llm_model model_type: Optional[ModelType] = ModelType.llm
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())

View file

@ -168,7 +168,7 @@ class InferenceRouter(Inference):
model = await self.routing_table.get_model(model_id) model = await self.routing_table.get_model(model_id)
if model is None: if model is None:
raise ValueError(f"Model '{model_id}' not found") raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm_model: if model.model_type == ModelType.llm:
raise ValueError( raise ValueError(
f"Model '{model_id}' is an LLM model and does not support embeddings" f"Model '{model_id}' is an LLM model and does not support embeddings"
) )

View file

@ -224,7 +224,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if metadata is None: if metadata is None:
metadata = {} metadata = {}
if model_type is None: if model_type is None:
model_type = ModelType.llm_model model_type = ModelType.llm
if ( if (
"embedding_dimension" not in metadata "embedding_dimension" not in metadata
and model_type == ModelType.embedding_model and model_type == ModelType.embedding_model