From 02f1c47416f68f5dbe7d7e4878f1eddfbe9f124e Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 18 Nov 2024 23:50:18 -0800 Subject: [PATCH] support adding alias for models without hf repo/sku entry (#481) # What does this PR do? adds a new method build_model_alias_with_just_provider_model_id which is needed for cases like ollama's quantized models which do not really have a repo in hf and an entry in SKU list. ## Test Plan pytest -v -s -m "ollama" llama_stack/providers/tests/inference/test_text_inference.py --------- Co-authored-by: Dinesh Yeduguru --- .../providers/remote/inference/ollama/ollama.py | 17 +++++++++-------- .../providers/utils/inference/model_registry.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 70a091b77..1c5d26a84 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -16,6 +16,7 @@ from ollama import AsyncClient from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, + build_model_alias_with_just_provider_model_id, ModelRegistryHelper, ) @@ -44,7 +45,7 @@ model_aliases = [ "llama3.1:8b-instruct-fp16", CoreModelId.llama3_1_8b_instruct.value, ), - build_model_alias( + build_model_alias_with_just_provider_model_id( "llama3.1:8b", CoreModelId.llama3_1_8b_instruct.value, ), @@ -52,7 +53,7 @@ model_aliases = [ "llama3.1:70b-instruct-fp16", CoreModelId.llama3_1_70b_instruct.value, ), - build_model_alias( + build_model_alias_with_just_provider_model_id( "llama3.1:70b", CoreModelId.llama3_1_70b_instruct.value, ), @@ -64,27 +65,27 @@ model_aliases = [ "llama3.2:3b-instruct-fp16", CoreModelId.llama3_2_3b_instruct.value, ), - build_model_alias( + build_model_alias_with_just_provider_model_id( "llama3.2:1b", CoreModelId.llama3_2_1b_instruct.value, ), - build_model_alias( + 
build_model_alias_with_just_provider_model_id( "llama3.2:3b", CoreModelId.llama3_2_3b_instruct.value, ), - build_model_alias( + build_model_alias_with_just_provider_model_id( "llama-guard3:8b", CoreModelId.llama_guard_3_8b.value, ), - build_model_alias( + build_model_alias_with_just_provider_model_id( "llama-guard3:1b", CoreModelId.llama_guard_3_1b.value, ), build_model_alias( - "x/llama3.2-vision:11b-instruct-fp16", + "llama3.2-vision:11b-instruct-fp16", CoreModelId.llama3_2_11b_vision_instruct.value, ), - build_model_alias( + build_model_alias_with_just_provider_model_id( "llama3.2-vision", CoreModelId.llama3_2_11b_vision_instruct.value, ), diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 3834946f5..07225fac0 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -36,6 +36,16 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli ) +def build_model_alias_with_just_provider_model_id( + provider_model_id: str, model_descriptor: str +) -> ModelAlias: + return ModelAlias( + provider_model_id=provider_model_id, + aliases=[], + llama_model=model_descriptor, + ) + + class ModelRegistryHelper(ModelsProtocolPrivate): def __init__(self, model_aliases: List[ModelAlias]): self.alias_to_provider_id_map = {}