From aad0dedc85366004c1ec443aee94ea212724aff4 Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Thu, 12 Dec 2024 14:13:17 -0800 Subject: [PATCH] address comments --- .../inline/post_training/torchtune/utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/llama_stack/providers/inline/post_training/torchtune/utils.py b/llama_stack/providers/inline/post_training/torchtune/utils.py index 6f7c18f72..2e22777de 100644 --- a/llama_stack/providers/inline/post_training/torchtune/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/utils.py @@ -43,6 +43,10 @@ class ModelConfigs(BaseModel): Llama_3_8B_Instruct: ModelConfig +class DatasetSchema(BaseModel): + alpaca: List[Dict[str, ParamType]] + + MODEL_CONFIGS = ModelConfigs( Llama3_2_3B_Instruct=ModelConfig( model_definition=lora_llama3_2_3b, @@ -56,8 +60,9 @@ MODEL_CONFIGS = ModelConfigs( ), ) -EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = { - "alpaca": [ + +EXPECTED_DATASET_SCHEMA = DatasetSchema( + alpaca=[ { ColumnName.instruction.value: StringType(), ColumnName.input.value: StringType(), @@ -74,7 +79,7 @@ EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = { ColumnName.output.value: StringType(), }, ] -} +) BuildLoraModelCallable = Callable[..., torch.nn.Module] BuildTokenizerCallable = Callable[..., Llama3Tokenizer] @@ -138,7 +143,10 @@ async def validate_input_dataset_schema( if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: raise ValueError(f"Dataset {dataset_id} does not have a schema defined.") - if dataset_def.dataset_schema not in EXPECTED_DATASET_SCHEMA[dataset_type]: + if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type): + raise ValueError(f"Dataset type {dataset_type} is not supported.") + + if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type): raise ValueError( - f"Dataset {dataset_id} does not have a correct input schema in {EXPECTED_DATASET_SCHEMA[dataset_type]}" + 
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}" )