mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-31 05:53:53 +00:00
add nvidia distribution
This commit is contained in:
parent
b5c6a80b2e
commit
103a3b1a4f
7 changed files with 67 additions and 12 deletions
|
|
@ -10,6 +10,7 @@ from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupIn
|
|||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.models import _MODEL_ENTRIES
|
||||
from llama_stack.providers.remote.post_training.nvidia import NvidiaPostTrainingConfig
|
||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
|
||||
|
||||
|
||||
|
|
@ -18,6 +19,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"inference": ["remote::nvidia"],
|
||||
"vector_io": ["inline::faiss"],
|
||||
"safety": ["inline::llama-guard"],
|
||||
"post_training": ["remote::nvidia"],
|
||||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
"eval": ["inline::meta-reference"],
|
||||
|
|
@ -38,6 +40,12 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
config=NVIDIAConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
post_training_provider = Provider(
|
||||
provider_id="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
config=NvidiaPostTrainingConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
|
||||
default_models = [
|
||||
ModelInput(
|
||||
|
|
@ -90,5 +98,30 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"",
|
||||
"NVIDIA API Key",
|
||||
),
|
||||
## Nemo Customizer related variables
|
||||
"NVIDIA_USER_ID": (
|
||||
"llama-stack-user",
|
||||
"NVIDIA User ID",
|
||||
),
|
||||
"NVIDIA_DATASET_NAMESPACE": (
|
||||
"default",
|
||||
"NVIDIA Dataset Namespace",
|
||||
),
|
||||
"NVIDIA_ACCESS_POLICIES": (
|
||||
"{}",
|
||||
"NVIDIA Access Policies",
|
||||
),
|
||||
"NVIDIA_PROJECT_ID": (
|
||||
"test-project",
|
||||
"NVIDIA Project ID",
|
||||
),
|
||||
"NVIDIA_CUSTOMIZER_URL": (
|
||||
"https://customizer.api.nvidia.com",
|
||||
"NVIDIA Customizer URL",
|
||||
),
|
||||
"NVIDIA_OUTPUT_MODEL_DIR": (
|
||||
"test-example-model@v1",
|
||||
"NVIDIA Output Model Directory",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue