Revert "Revert "add model type to APIs" (#605)"

This reverts commit 47b2dc8ae3.
This commit is contained in:
Dinesh Yeduguru 2024-12-11 10:18:00 -08:00 committed by GitHub
parent 47b2dc8ae3
commit 310c15bada
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 77 additions and 13 deletions

View file

@ -9,6 +9,7 @@ from typing import List, Optional
from llama_models.sku_list import all_registered_models
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import (
@ -77,7 +78,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
return None
async def register_model(self, model: Model) -> Model:
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
if model.model_type == ModelType.embedding_model:
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
else:
provider_resource_id = self.get_provider_model_id(
model.provider_resource_id
)
if provider_resource_id:
model.provider_resource_id = provider_resource_id
else: