Support adding an alias for models without an HF repo/SKU entry

This commit is contained in:
Dinesh Yeduguru 2024-11-18 23:31:04 -08:00
parent fcc2132e6f
commit 8bd0a33206
2 changed files with 18 additions and 7 deletions

View file

@ -16,6 +16,7 @@ from ollama import AsyncClient
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
build_model_alias, build_model_alias,
build_model_alias_with_just_llama_model,
ModelRegistryHelper, ModelRegistryHelper,
) )
@ -44,7 +45,7 @@ model_aliases = [
"llama3.1:8b-instruct-fp16", "llama3.1:8b-instruct-fp16",
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_8b_instruct.value,
), ),
build_model_alias( build_model_alias_with_just_llama_model(
"llama3.1:8b", "llama3.1:8b",
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_8b_instruct.value,
), ),
@ -52,7 +53,7 @@ model_aliases = [
"llama3.1:70b-instruct-fp16", "llama3.1:70b-instruct-fp16",
CoreModelId.llama3_1_70b_instruct.value, CoreModelId.llama3_1_70b_instruct.value,
), ),
build_model_alias( build_model_alias_with_just_llama_model(
"llama3.1:70b", "llama3.1:70b",
CoreModelId.llama3_1_70b_instruct.value, CoreModelId.llama3_1_70b_instruct.value,
), ),
@ -64,19 +65,19 @@ model_aliases = [
"llama3.2:3b-instruct-fp16", "llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value, CoreModelId.llama3_2_3b_instruct.value,
), ),
build_model_alias( build_model_alias_with_just_llama_model(
"llama3.2:1b", "llama3.2:1b",
CoreModelId.llama3_2_1b_instruct.value, CoreModelId.llama3_2_1b_instruct.value,
), ),
build_model_alias( build_model_alias_with_just_llama_model(
"llama3.2:3b", "llama3.2:3b",
CoreModelId.llama3_2_3b_instruct.value, CoreModelId.llama3_2_3b_instruct.value,
), ),
build_model_alias( build_model_alias_with_just_llama_model(
"llama-guard3:8b", "llama-guard3:8b",
CoreModelId.llama_guard_3_8b.value, CoreModelId.llama_guard_3_8b.value,
), ),
build_model_alias( build_model_alias_with_just_llama_model(
"llama-guard3:1b", "llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value, CoreModelId.llama_guard_3_1b.value,
), ),
@ -84,7 +85,7 @@ model_aliases = [
"x/llama3.2-vision:11b-instruct-fp16", "x/llama3.2-vision:11b-instruct-fp16",
CoreModelId.llama3_2_11b_vision_instruct.value, CoreModelId.llama3_2_11b_vision_instruct.value,
), ),
build_model_alias( build_model_alias_with_just_llama_model(
"llama3.2-vision", "llama3.2-vision",
CoreModelId.llama3_2_11b_vision_instruct.value, CoreModelId.llama3_2_11b_vision_instruct.value,
), ),

View file

@ -36,6 +36,16 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli
) )
def build_model_alias_with_just_llama_model(
    provider_model_id: str, model_descriptor: str
) -> ModelAlias:
    """Build a ModelAlias that carries only the provider id and llama model.

    The returned alias registers no extra alias names (``aliases`` is empty);
    it simply associates *provider_model_id* with the llama model identified
    by *model_descriptor*.
    """
    alias = ModelAlias(
        provider_model_id=provider_model_id,
        aliases=[],
        llama_model=model_descriptor,
    )
    return alias
class ModelRegistryHelper(ModelsProtocolPrivate): class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, model_aliases: List[ModelAlias]): def __init__(self, model_aliases: List[ModelAlias]):
self.alias_to_provider_id_map = {} self.alias_to_provider_id_map = {}