address comments

This commit is contained in:
Botao Chen 2025-01-06 14:35:48 -08:00
parent 2a992d4f05
commit f39dcdec9d
2 changed files with 13 additions and 14 deletions

View file

@ -10,7 +10,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
import torch
@ -20,6 +19,10 @@ from llama_models.sku_list import resolve_model
from llama_stack.apis.common.type_system import ParamType, StringType
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import DatasetFormat
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
validate_dataset_schema,
)
from pydantic import BaseModel
from torchtune.data._messages import (
@ -36,15 +39,6 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
from torchtune.modules.transforms import Transform
class ColumnName(Enum):
instruction = "instruction"
input = "input"
output = "output"
text = "text"
conversations = "conversations"
messages = "messages"
class ModelConfig(BaseModel):
model_definition: Any
tokenizer_type: Any
@ -191,7 +185,6 @@ async def validate_input_dataset_schema(
else:
dataset_schema = dataset_def.dataset_schema
if dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
)
validate_dataset_schema(
dataset_schema, getattr(EXPECTED_DATASET_SCHEMA, dataset_type)
)

View file

@ -23,6 +23,12 @@ class ColumnName(Enum):
completion_input = "completion_input"
generated_answer = "generated_answer"
context = "context"
instruction = "instruction"
input = "input"
output = "output"
text = "text"
conversations = "conversations"
messages = "messages"
VALID_SCHEMAS_FOR_SCORING = [