mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
refine
This commit is contained in:
parent
cc21a99fbf
commit
1e915d87fb
1 changed files with 1 additions and 2 deletions
|
@ -47,11 +47,10 @@ class SFTDataset(Dataset):
|
||||||
def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
|
def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
|
||||||
if self._dataset_type == "instruct":
|
if self._dataset_type == "instruct":
|
||||||
sample = llama_stack_instruct_to_torchtune_instruct(sample)
|
sample = llama_stack_instruct_to_torchtune_instruct(sample)
|
||||||
elif self._dataset_type == "chat":
|
elif self._dataset_type == "dialog":
|
||||||
sample = llama_stack_chat_to_torchtune_chat(sample)
|
sample = llama_stack_chat_to_torchtune_chat(sample)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid dataset type: {self._dataset_type}")
|
raise ValueError(f"Invalid dataset type: {self._dataset_type}")
|
||||||
|
|
||||||
transformed_sample = self._message_transform(sample)
|
transformed_sample = self._message_transform(sample)
|
||||||
if "messages" in transformed_sample:
|
if "messages" in transformed_sample:
|
||||||
validate_messages(transformed_sample["messages"])
|
validate_messages(transformed_sample["messages"])
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue