diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py index 5501044c4..e4e4e9ebb 100644 --- a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py +++ b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py @@ -47,11 +47,10 @@ class SFTDataset(Dataset): def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]: if self._dataset_type == "instruct": sample = llama_stack_instruct_to_torchtune_instruct(sample) - elif self._dataset_type == "chat": + elif self._dataset_type == "dialog": sample = llama_stack_chat_to_torchtune_chat(sample) else: raise ValueError(f"Invalid dataset type: {self._dataset_type}") - transformed_sample = self._message_transform(sample) if "messages" in transformed_sample: validate_messages(transformed_sample["messages"])