Honglin Cao 2025-03-12 22:23:25 -04:00
parent 0cef9adda5
commit 8943755156
3 changed files with 48 additions and 86 deletions

@@ -7,19 +7,19 @@
 from pathlib import Path
 from llama_models.sku_list import all_registered_models
 from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
-from llama_stack.providers.remote.inference.centml.config import (
-    CentMLImplConfig,
-)
 # If your CentML adapter has a MODEL_ALIASES constant with known model mappings:
 from llama_stack.providers.remote.inference.centml.centml import MODEL_ALIASES
+from llama_stack.providers.remote.inference.centml.config import (
+    CentMLImplConfig,
+)
 from llama_stack.templates.template import (
     DistributionTemplate,
     RunConfigSettings,
@@ -68,9 +68,7 @@ def get_distribution_template() -> DistributionTemplate:
     )
     # Map Llama Models to provider IDs if needed
-    core_model_to_hf_repo = {
-        m.descriptor(): m.huggingface_repo for m in all_registered_models()
-    }
+    core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
     default_models = [
         ModelInput(
             model_id=core_model_to_hf_repo[m.llama_model],
@@ -103,9 +101,7 @@ def get_distribution_template() -> DistributionTemplate:
                     "memory": [memory_provider],
                 },
                 default_models=default_models + [embedding_model],
-                default_shields=[
-                    ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")
-                ],
+                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
             ),
         },
         run_config_env_vars={
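
Note on the MODEL_ALIASES import in the first hunk: in llama-stack, remote inference adapters typically expose a module-level table that maps provider-side model ids to core Llama model descriptors, which the template then turns into default ModelInput entries. Below is a minimal sketch of what such a table might look like for the CentML adapter, assuming it follows the same build_model_alias pattern used by other remote providers of this period; the helper name and the CentML-side model ids are assumptions, not taken from this commit.

# Hypothetical sketch only: the build_model_alias helper and the
# provider-side model ids are assumptions, not confirmed by this diff.
from llama_models.datatypes import CoreModelId

from llama_stack.providers.utils.inference.model_registry import build_model_alias

MODEL_ALIASES = [
    build_model_alias(
        "meta-llama/Llama-3.1-8B-Instruct",   # id the CentML endpoint expects (assumed)
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "meta-llama/Llama-3.1-70B-Instruct",  # id the CentML endpoint expects (assumed)
        CoreModelId.llama3_1_70b_instruct.value,
    ),
]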