mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 09:09:48 +00:00
fixes for all providers
This commit is contained in:
parent
d5874735ea
commit
948f6ece6e
8 changed files with 133 additions and 135 deletions
|
|
@ -5,13 +5,35 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_models.datatypes import CoreModelId
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
|
||||
ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
|
||||
|
||||
|
||||
def get_huggingface_repo(core_model_id: CoreModelId) -> Optional[str]:
|
||||
"""Get the Hugging Face repository for a given CoreModelId."""
|
||||
for model in all_registered_models():
|
||||
if model.core_model_id == core_model_id:
|
||||
return model.huggingface_repo
|
||||
return None
|
||||
|
||||
|
||||
def build_model_alias(provider_model_id: str, core_model_id: CoreModelId) -> ModelAlias:
|
||||
return ModelAlias(
|
||||
provider_model_id=provider_model_id,
|
||||
aliases=[
|
||||
core_model_id.value,
|
||||
get_huggingface_repo(core_model_id),
|
||||
],
|
||||
llama_model=core_model_id.value,
|
||||
)
|
||||
|
||||
|
||||
class ModelLookup:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue