fixes for all providers

This commit is contained in:
Dinesh Yeduguru 2024-11-12 14:25:28 -08:00
parent d5874735ea
commit 948f6ece6e
8 changed files with 133 additions and 135 deletions

View file

@ -5,13 +5,35 @@
# the root directory of this source tree.
from collections import namedtuple
from typing import List
from typing import List, Optional
from llama_models.datatypes import CoreModelId
from llama_models.sku_list import all_registered_models
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
def get_huggingface_repo(core_model_id: CoreModelId) -> Optional[str]:
"""Get the Hugging Face repository for a given CoreModelId."""
for model in all_registered_models():
if model.core_model_id == core_model_id:
return model.huggingface_repo
return None
def build_model_alias(provider_model_id: str, core_model_id: CoreModelId) -> ModelAlias:
return ModelAlias(
provider_model_id=provider_model_id,
aliases=[
core_model_id.value,
get_huggingface_repo(core_model_id),
],
llama_model=core_model_id.value,
)
class ModelLookup:
def __init__(
self,