diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index d617f7e27..fe52a0abb 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -440,7 +440,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): # infer provider from source if metadata: if metadata.get("provider"): - provider_id = metadata.get("provider") # pass through from nvidia datasetio + provider_id = metadata.get("provider") # pass through from nvidia datasetio elif source.type == DatasetType.rows.value: provider_id = "localfs" elif source.type == DatasetType.uri.value: diff --git a/llama_stack/providers/remote/datasetio/nvidia/datasetio.py b/llama_stack/providers/remote/datasetio/nvidia/datasetio.py index 5c6f13ac2..5d3cf5c03 100644 --- a/llama_stack/providers/remote/datasetio/nvidia/datasetio.py +++ b/llama_stack/providers/remote/datasetio/nvidia/datasetio.py @@ -9,12 +9,14 @@ from typing import Any, Dict, List, Optional import aiohttp from llama_stack.apis.common.content_types import URL -from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.apis.common.type_system import ParamType +from llama_stack.apis.datasets import Dataset from llama_stack.schema_utils import webmethod -from llama_stack.apis.datasets import DatasetPurpose, DataSource, Dataset + from .config import NvidiaDatasetIOConfig + class NvidiaDatasetIOAdapter: """Nvidia NeMo DatasetIO API.""" diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index ae0ab32b9..c393a292c 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -52,11 +52,6 @@ def get_distribution_template() -> DistributionTemplate: model_id="${env.SAFETY_MODEL}", provider_id="nvidia", ) - datasetio_provider = Provider( - provider_id="nvidia", - provider_type="remote::nvidia", - config=NvidiaDatasetIOConfig.sample_run_config(), - ) available_models = { "nvidia": MODEL_ENTRIES,