forked from phoenix-oss/llama-stack-mirror
add model type to APIs (#588)
# What does this PR do? This PR adds a new model type field to support embedding models to be registered. Summary of changes: 1) Each registered model by default is an llm model. 2) User can specify an embedding model type, while registering.If specified, the model bypass the llama model checks since embedding models can by of any type and based on llama. 3) User needs to include the required embedding dimension in metadata. This will be used by embedding generation to generate the requried size of embeddings. ## Test Plan This PR will go together will need to be merged with two follow up PRs that will include test plans.
This commit is contained in:
parent
7e1d628864
commit
8e33db6015
6 changed files with 77 additions and 13 deletions
|
@ -9,6 +9,7 @@ from typing import List, Optional
|
|||
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference import (
|
||||
|
@ -77,7 +78,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
return None
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
if model.model_type == ModelType.embedding_model:
|
||||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||
provider_resource_id = model.provider_resource_id
|
||||
else:
|
||||
provider_resource_id = self.get_provider_model_id(
|
||||
model.provider_resource_id
|
||||
)
|
||||
if provider_resource_id:
|
||||
model.provider_resource_id = provider_resource_id
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue