diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 70a091b77..1c5d26a84 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -16,6 +16,7 @@ from ollama import AsyncClient
 
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
+    build_model_alias_with_just_provider_model_id,
     ModelRegistryHelper,
 )
 
@@ -44,7 +45,7 @@ model_aliases = [
         "llama3.1:8b-instruct-fp16",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    build_model_alias(
+    build_model_alias_with_just_provider_model_id(
         "llama3.1:8b",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
@@ -52,7 +53,7 @@ model_aliases = [
         "llama3.1:70b-instruct-fp16",
         CoreModelId.llama3_1_70b_instruct.value,
     ),
-    build_model_alias(
+    build_model_alias_with_just_provider_model_id(
         "llama3.1:70b",
         CoreModelId.llama3_1_70b_instruct.value,
     ),
@@ -64,27 +65,27 @@ model_aliases = [
         "llama3.2:3b-instruct-fp16",
         CoreModelId.llama3_2_3b_instruct.value,
     ),
-    build_model_alias(
+    build_model_alias_with_just_provider_model_id(
         "llama3.2:1b",
         CoreModelId.llama3_2_1b_instruct.value,
     ),
-    build_model_alias(
+    build_model_alias_with_just_provider_model_id(
         "llama3.2:3b",
         CoreModelId.llama3_2_3b_instruct.value,
     ),
-    build_model_alias(
+    build_model_alias_with_just_provider_model_id(
         "llama-guard3:8b",
         CoreModelId.llama_guard_3_8b.value,
     ),
-    build_model_alias(
+    build_model_alias_with_just_provider_model_id(
         "llama-guard3:1b",
         CoreModelId.llama_guard_3_1b.value,
     ),
     build_model_alias(
-        "x/llama3.2-vision:11b-instruct-fp16",
+        "llama3.2-vision:11b-instruct-fp16",
         CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
-    build_model_alias(
+    build_model_alias_with_just_provider_model_id(
         "llama3.2-vision",
         CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 3834946f5..07225fac0 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -36,6 +36,16 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli
     )
 
 
+def build_model_alias_with_just_provider_model_id(
+    provider_model_id: str, model_descriptor: str
+) -> ModelAlias:
+    return ModelAlias(
+        provider_model_id=provider_model_id,
+        aliases=[],
+        llama_model=model_descriptor,
+    )
+
+
 class ModelRegistryHelper(ModelsProtocolPrivate):
     def __init__(self, model_aliases: List[ModelAlias]):
         self.alias_to_provider_id_map = {}