diff --git a/llama_stack/providers/inline/post_training/common/validator.py b/llama_stack/providers/inline/post_training/common/validator.py index b0aec6187..950b75f86 100644 --- a/llama_stack/providers/inline/post_training/common/validator.py +++ b/llama_stack/providers/inline/post_training/common/validator.py @@ -17,10 +17,8 @@ from llama_stack.apis.common.type_system import ( DialogType, StringType, ) -from llama_stack.apis.datasets import Datasets from llama_stack.providers.utils.common.data_schema_validator import ( ColumnName, - validate_dataset_schema, ) EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = { @@ -36,21 +34,3 @@ EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = { } ], } - - -async def validate_input_dataset_schema( - datasets_api: Datasets, - dataset_id: str, - dataset_type: str, -) -> None: - dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id) - if not dataset_def: - raise ValueError(f"Dataset {dataset_id} does not exist.") - - if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: - raise ValueError(f"Dataset {dataset_id} does not have a schema defined.") - - if dataset_type not in EXPECTED_DATASET_SCHEMA: - raise ValueError(f"Dataset type {dataset_type} is not supported.") - - validate_dataset_schema(dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]) diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 04bf86b97..5cf15824d 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -48,9 +48,6 @@ from llama_stack.apis.post_training import ( from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.models.llama.sku_list import resolve_model -from llama_stack.providers.inline.post_training.common.validator import ( - validate_input_dataset_schema, -) from llama_stack.providers.inline.post_training.torchtune.common import utils from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import ( TorchtuneCheckpointer, @@ -348,11 +345,9 @@ class LoraFinetuningSingleDevice: all_rows = await fetch_rows(dataset_id) rows = all_rows.data - await validate_input_dataset_schema( - datasets_api=self.datasets_api, - dataset_id=dataset_id, - dataset_type=self._data_format.value, - ) + # TODO (xiyan): validate dataset schema + # dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) + data_transform = await utils.get_data_transform(self._data_format) ds = SFTDataset( rows,