diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 18b0c891f..d617f7e27 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -438,7 +438,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): provider_dataset_id = dataset_id # infer provider from source - if source.type == DatasetType.rows.value: + if metadata: + if metadata.get("provider"): + provider_id = metadata.get("provider") # pass through from nvidia datasetio + elif source.type == DatasetType.rows.value: provider_id = "localfs" elif source.type == DatasetType.uri.value: # infer provider from uri diff --git a/llama_stack/providers/remote/datasetio/nvidia/datasetio.py b/llama_stack/providers/remote/datasetio/nvidia/datasetio.py index 95bd155a8..5c6f13ac2 100644 --- a/llama_stack/providers/remote/datasetio/nvidia/datasetio.py +++ b/llama_stack/providers/remote/datasetio/nvidia/datasetio.py @@ -10,20 +10,17 @@ import aiohttp from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType -from llama_stack.apis.datasetio import IterrowsResponse +from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.schema_utils import webmethod - +from llama_stack.apis.datasets import DatasetPurpose, DataSource, Dataset from .config import NvidiaDatasetIOConfig - class NvidiaDatasetIOAdapter: """Nvidia NeMo DatasetIO API.""" def __init__(self, config: NvidiaDatasetIOConfig): self.config = config self.headers = {} - if config.api_key: - self.headers["Authorization"] = f"Bearer {config.api_key}" async def _make_request( self, @@ -36,46 +33,49 @@ class NvidiaDatasetIOAdapter: ) -> Dict[str, Any]: """Helper method to make HTTP requests to the Customizer API.""" url = f"{self.config.datasets_url}{path}" - request_headers = self.headers.copy() # Create a copy to avoid modifying the original + request_headers = self.headers.copy() if headers: request_headers.update(headers) - # Add content-type header for JSON requests - if json and "Content-Type" not in request_headers: - request_headers["Content-Type"] = "application/json" - async with aiohttp.ClientSession(headers=request_headers) as session: async with session.request(method, url, params=params, json=json, **kwargs) as response: - if response.status >= 400: + if response.status != 200: error_data = await response.json() raise Exception(f"API request failed: {error_data}") return await response.json() - @webmethod(route="/datasets", method="POST") async def register_dataset( self, - dataset_id: str, - dataset_schema: Dict[str, ParamType], - url: URL, - provider_dataset_id: Optional[str] = None, - provider_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> None: + dataset_def: Dataset, + ) -> Dataset: """Register a new dataset. Args: - dataset_id: The ID of the dataset. - dataset_schema: The schema of the dataset. - url: The URL of the dataset. - provider_dataset_id: The ID of the provider dataset. - provider_id: The ID of the provider. - metadata: The metadata of the dataset. - + dataset_def : The dataset definition. + dataset_id: The ID of the dataset. + source: The source of the dataset. + metadata: The metadata of the dataset. + format: The format of the dataset. + description: The description of the dataset. Returns: None """ - ... + ## add warnings for unsupported params + request_body = { + "name": dataset_def.identifier, + "namespace": self.config.dataset_namespace, + "files_url": dataset_def.source.uri, + "project": self.config.project_id, + } + if dataset_def.metadata: + request_body["format"] = dataset_def.metadata.get("format") + request_body["description"] = dataset_def.metadata.get("description") + await self._make_request( + "POST", + "/v1/datasets", + json=request_body, + ) @webmethod(route="/datasets/{dataset_id:path}", method="POST") async def update_dataset( @@ -93,16 +93,19 @@ class NvidiaDatasetIOAdapter: async def unregister_dataset( self, dataset_id: str, - namespace: Optional[str] = "default", ) -> None: - raise NotImplementedError("Not implemented") + await self._make_request( + "DELETE", + f"/v1/datasets/{self.config.dataset_namespace}/{dataset_id}", + headers={"Accept": "application/json", "Content-Type": "application/json"}, + ) async def iterrows( self, dataset_id: str, start_index: Optional[int] = None, limit: Optional[int] = None, - ) -> IterrowsResponse: + ) -> PaginatedResponse: raise NotImplementedError("Not implemented") async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index 3b421c0b4..7488997d8 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -62,6 +62,13 @@ providers: project_id: ${env.NVIDIA_PROJECT_ID:test-project} customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test} datasetio: + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/localfs_datasetio.db - provider_id: nvidia provider_type: remote::nvidia config: