fixes after rebase

This commit is contained in:
Dinesh Yeduguru 2024-11-12 15:37:07 -08:00
parent 948f6ece6e
commit 919d421bcf
11 changed files with 72 additions and 70 deletions

View file

@ -7,7 +7,6 @@
from collections import namedtuple
from typing import List, Optional
from llama_models.datatypes import CoreModelId
from llama_models.sku_list import all_registered_models
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
@ -15,22 +14,22 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
def get_huggingface_repo(core_model_id: CoreModelId) -> Optional[str]:
def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
"""Get the Hugging Face repository for a given CoreModelId."""
for model in all_registered_models():
if model.core_model_id == core_model_id:
if model.descriptor() == model_descriptor:
return model.huggingface_repo
return None
def build_model_alias(provider_model_id: str, core_model_id: CoreModelId) -> ModelAlias:
def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
return ModelAlias(
provider_model_id=provider_model_id,
aliases=[
core_model_id.value,
get_huggingface_repo(core_model_id),
model_descriptor,
get_huggingface_repo(model_descriptor),
],
llama_model=core_model_id.value,
llama_model=model_descriptor,
)