From f39dcdec9dfa7848a235a97274b7e9790b8ad8be Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Mon, 6 Jan 2025 14:35:48 -0800 Subject: [PATCH] address comments --- .../post_training/torchtune/common/utils.py | 21 +++++++------------ .../utils/common/data_schema_validator.py | 6 ++++++ 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index c4d230e2d..51d60ac3a 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -10,7 +10,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum from typing import Any, Callable, Dict, List, Optional import torch @@ -20,6 +19,10 @@ from llama_models.sku_list import resolve_model from llama_stack.apis.common.type_system import ParamType, StringType from llama_stack.apis.datasets import Datasets from llama_stack.apis.post_training import DatasetFormat +from llama_stack.providers.utils.common.data_schema_validator import ( + ColumnName, + validate_dataset_schema, +) from pydantic import BaseModel from torchtune.data._messages import ( @@ -36,15 +39,6 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b from torchtune.modules.transforms import Transform -class ColumnName(Enum): - instruction = "instruction" - input = "input" - output = "output" - text = "text" - conversations = "conversations" - messages = "messages" - - class ModelConfig(BaseModel): model_definition: Any tokenizer_type: Any @@ -191,7 +185,6 @@ async def validate_input_dataset_schema( else: dataset_schema = dataset_def.dataset_schema - if dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type): - raise ValueError( - f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}" - ) + validate_dataset_schema( + dataset_schema, getattr(EXPECTED_DATASET_SCHEMA, dataset_type) + ) diff --git a/llama_stack/providers/utils/common/data_schema_validator.py b/llama_stack/providers/utils/common/data_schema_validator.py index af58a4592..0322602b7 100644 --- a/llama_stack/providers/utils/common/data_schema_validator.py +++ b/llama_stack/providers/utils/common/data_schema_validator.py @@ -23,6 +23,12 @@ class ColumnName(Enum): completion_input = "completion_input" generated_answer = "generated_answer" context = "context" + instruction = "instruction" + input = "input" + output = "output" + text = "text" + conversations = "conversations" + messages = "messages" VALID_SCHEMAS_FOR_SCORING = [