From 3d891fc9ba764baf50fbb7d4ecc194a3a7b680ba Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 20 Feb 2025 11:21:13 -0800 Subject: [PATCH] ModelAlias cleanup --- .../providers/utils/inference/model_registry.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index c5f6cd6b5..5cb785843 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -4,9 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections import namedtuple from typing import List, Optional +from pydantic import BaseModel, Field + from llama_stack.apis.models.models import ModelType from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate @@ -14,7 +15,14 @@ from llama_stack.providers.utils.inference import ( ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, ) -ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"]) + +# TODO: this class is more confusing than useful right now. We need to make it +# more closer to the Model class. +class ModelAlias(BaseModel): + provider_model_id: str + aliases: List[str] = Field(default_factory=list) + llama_model: Optional[str] = None + model_type: ModelType = ModelType.llm def get_huggingface_repo(model_descriptor: str) -> Optional[str]: