address comments

This commit is contained in:
Botao Chen 2024-12-12 14:13:17 -08:00
parent 3378c100f6
commit aad0dedc85

View file

@ -43,6 +43,10 @@ class ModelConfigs(BaseModel):
Llama_3_8B_Instruct: ModelConfig
class DatasetSchema(BaseModel):
alpaca: List[Dict[str, ParamType]]
MODEL_CONFIGS = ModelConfigs(
Llama3_2_3B_Instruct=ModelConfig(
model_definition=lora_llama3_2_3b,
@ -56,8 +60,9 @@ MODEL_CONFIGS = ModelConfigs(
),
)
EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = {
"alpaca": [
EXPECTED_DATASET_SCHEMA = DatasetSchema(
alpaca=[
{
ColumnName.instruction.value: StringType(),
ColumnName.input.value: StringType(),
@ -74,7 +79,7 @@ EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = {
ColumnName.output.value: StringType(),
},
]
}
)
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
@ -138,7 +143,10 @@ async def validate_input_dataset_schema(
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
if dataset_def.dataset_schema not in EXPECTED_DATASET_SCHEMA[dataset_type]:
if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(f"Dataset type {dataset_type} is not supported.")
if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {EXPECTED_DATASET_SCHEMA[dataset_type]}"
f"Dataset {dataset_id} does not have a correct input schema in {EXPECTED_DATASET_SCHEMA.dataset_type}"
)