forked from phoenix-oss/llama-stack-mirror
parent
7ea14ae62e
commit
a7b929f17e
9 changed files with 53 additions and 56 deletions
|
@ -8,7 +8,12 @@ from pathlib import Path
|
|||
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ModelInput,
|
||||
Provider,
|
||||
ShieldInput,
|
||||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
|
||||
from llama_stack.providers.remote.inference.sambanova.sambanova import MODEL_ALIASES
|
||||
|
||||
|
@ -29,10 +34,11 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"inline::rag-runtime",
|
||||
],
|
||||
}
|
||||
name = "sambanova"
|
||||
|
||||
inference_provider = Provider(
|
||||
provider_id="sambanova",
|
||||
provider_type="remote::sambanova",
|
||||
provider_id=name,
|
||||
provider_type=f"remote::{name}",
|
||||
config=SambaNovaImplConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
|
@ -43,12 +49,28 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
ModelInput(
|
||||
model_id=core_model_to_hf_repo[m.llama_model],
|
||||
provider_model_id=m.provider_model_id,
|
||||
provider_id=name,
|
||||
)
|
||||
for m in MODEL_ALIASES
|
||||
]
|
||||
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::code_interpreter",
|
||||
provider_id="code-interpreter",
|
||||
),
|
||||
]
|
||||
|
||||
return DistributionTemplate(
|
||||
name="sambanova",
|
||||
name=name,
|
||||
distro_type="self_hosted",
|
||||
description="Use SambaNova.AI for running LLM inference",
|
||||
docker_image=None,
|
||||
|
@ -62,6 +84,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
},
|
||||
default_models=default_models,
|
||||
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue