add datasetio to distribution

This commit is contained in:
raspawar 2025-03-26 15:21:38 +05:30 committed by Rashmi Pawar
parent ae973c9595
commit 1e77873a02
7 changed files with 37 additions and 35 deletions

View file

@ -7,6 +7,7 @@
from pathlib import Path
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
@ -23,7 +24,7 @@ def get_distribution_template() -> DistributionTemplate:
"telemetry": ["inline::meta-reference"],
"eval": ["remote::nvidia"],
"post_training": ["remote::nvidia"],
"datasetio": ["inline::localfs"],
"datasetio": ["inline::localfs", "remote::nvidia"],
"scoring": ["inline::basic"],
"tool_runtime": ["inline::rag-runtime"],
}
@ -51,6 +52,11 @@ def get_distribution_template() -> DistributionTemplate:
model_id="${env.SAFETY_MODEL}",
provider_id="nvidia",
)
datasetio_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NvidiaDatasetIOConfig.sample_run_config(),
)
available_models = {
"nvidia": MODEL_ENTRIES,
@ -75,6 +81,7 @@ def get_distribution_template() -> DistributionTemplate:
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"datasetio": [datasetio_provider],
"eval": [eval_provider],
},
default_models=default_models,