address comments

This commit is contained in:
Botao Chen 2025-01-06 14:35:48 -08:00
parent 2a992d4f05
commit f39dcdec9d
2 changed files with 13 additions and 14 deletions

View file

@ -10,7 +10,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
@ -20,6 +19,10 @@ from llama_models.sku_list import resolve_model
from llama_stack.apis.common.type_system import ParamType, StringType from llama_stack.apis.common.type_system import ParamType, StringType
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import DatasetFormat from llama_stack.apis.post_training import DatasetFormat
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
validate_dataset_schema,
)
from pydantic import BaseModel from pydantic import BaseModel
from torchtune.data._messages import ( from torchtune.data._messages import (
@ -36,15 +39,6 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
from torchtune.modules.transforms import Transform from torchtune.modules.transforms import Transform
class ColumnName(Enum):
instruction = "instruction"
input = "input"
output = "output"
text = "text"
conversations = "conversations"
messages = "messages"
class ModelConfig(BaseModel): class ModelConfig(BaseModel):
model_definition: Any model_definition: Any
tokenizer_type: Any tokenizer_type: Any
@ -191,7 +185,6 @@ async def validate_input_dataset_schema(
else: else:
dataset_schema = dataset_def.dataset_schema dataset_schema = dataset_def.dataset_schema
if dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type): validate_dataset_schema(
raise ValueError( dataset_schema, getattr(EXPECTED_DATASET_SCHEMA, dataset_type)
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}" )
)

View file

@ -23,6 +23,12 @@ class ColumnName(Enum):
completion_input = "completion_input" completion_input = "completion_input"
generated_answer = "generated_answer" generated_answer = "generated_answer"
context = "context" context = "context"
instruction = "instruction"
input = "input"
output = "output"
text = "text"
conversations = "conversations"
messages = "messages"
VALID_SCHEMAS_FOR_SCORING = [ VALID_SCHEMAS_FOR_SCORING = [