From 8bd0a332066bf74b4fbd4ea9cc4217ec8f2f844f Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 18 Nov 2024 23:31:04 -0800 Subject: [PATCH] support adding alias for models without hf repo/sku entry --- .../providers/remote/inference/ollama/ollama.py | 15 ++++++++------- .../providers/utils/inference/model_registry.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 70a091b77..f06d9fad7 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -16,6 +16,7 @@ from ollama import AsyncClient from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, + build_model_alias_with_just_llama_model, ModelRegistryHelper, ) @@ -44,7 +45,7 @@ model_aliases = [ "llama3.1:8b-instruct-fp16", CoreModelId.llama3_1_8b_instruct.value, ), - build_model_alias( + build_model_alias_with_just_llama_model( "llama3.1:8b", CoreModelId.llama3_1_8b_instruct.value, ), @@ -52,7 +53,7 @@ model_aliases = [ "llama3.1:70b-instruct-fp16", CoreModelId.llama3_1_70b_instruct.value, ), - build_model_alias( + build_model_alias_with_just_llama_model( "llama3.1:70b", CoreModelId.llama3_1_70b_instruct.value, ), @@ -64,19 +65,19 @@ model_aliases = [ "llama3.2:3b-instruct-fp16", CoreModelId.llama3_2_3b_instruct.value, ), - build_model_alias( + build_model_alias_with_just_llama_model( "llama3.2:1b", CoreModelId.llama3_2_1b_instruct.value, ), - build_model_alias( + build_model_alias_with_just_llama_model( "llama3.2:3b", CoreModelId.llama3_2_3b_instruct.value, ), - build_model_alias( + build_model_alias_with_just_llama_model( "llama-guard3:8b", CoreModelId.llama_guard_3_8b.value, ), - build_model_alias( + build_model_alias_with_just_llama_model( "llama-guard3:1b", CoreModelId.llama_guard_3_1b.value, ), @@ -84,7 +85,7 @@ model_aliases = [ "x/llama3.2-vision:11b-instruct-fp16", CoreModelId.llama3_2_11b_vision_instruct.value, ), - build_model_alias( + build_model_alias_with_just_llama_model( "llama3.2-vision", CoreModelId.llama3_2_11b_vision_instruct.value, ), diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 3834946f5..5ac738f47 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -36,6 +36,16 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli ) +def build_model_alias_with_just_llama_model( + provider_model_id: str, model_descriptor: str +) -> ModelAlias: + return ModelAlias( + provider_model_id=provider_model_id, + aliases=[], + llama_model=model_descriptor, + ) + + class ModelRegistryHelper(ModelsProtocolPrivate): def __init__(self, model_aliases: List[ModelAlias]): self.alias_to_provider_id_map = {}