mirror of https://github.com/meta-llama/llama-stack.git
synced 2026-01-03 16:52:17 +00:00
refine
This commit is contained in:
parent 49e1a25343
commit 929e7c0e69

6 changed files with 9 additions and 25 deletions
@@ -49,7 +49,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {

 DATA_FORMATS: Dict[str, Transform] = {
     "instruct": InputOutputToMessages,
     "chat": ShareGPTToMessages,
     "dialog": ShareGPTToMessages,
 }
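The hunk above shows DATA_FORMATS mapping dataset format names to torchtune message-transform classes, with "chat" and "dialog" both routed to ShareGPTToMessages. A minimal sketch of how such a registry is typically resolved; the stand-in classes and the error handling are illustrative assumptions, not llama-stack's actual helper:

# Hypothetical sketch of resolving a transform class from DATA_FORMATS;
# the real llama-stack helper may differ.
from typing import Dict, Type

class Transform:
    """Stand-in for torchtune's message-transform protocol."""

class InputOutputToMessages(Transform):
    pass

class ShareGPTToMessages(Transform):
    pass

DATA_FORMATS: Dict[str, Type[Transform]] = {
    "instruct": InputOutputToMessages,
    "chat": ShareGPTToMessages,
    "dialog": ShareGPTToMessages,
}

def get_data_transform(data_format: str) -> Type[Transform]:
    # Raising on an unknown format keeps misconfigurations loud
    # rather than failing later inside dataset construction.
    try:
        return DATA_FORMATS[data_format]
    except KeyError:
        raise ValueError(f"unsupported data format: {data_format!r}")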
@@ -30,6 +30,7 @@ def llama_stack_instruct_to_torchtune_instruct(
     input_message = input_messages[0]

+    assert "content" in input_message, "content not found in input message"
     input = input_message["content"]
     output = sample[ColumnName.expected_answer.value]

     return {
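The added assert fails fast on rows whose first message lacks a "content" key, instead of surfacing a bare KeyError on the next line. A self-contained sketch of the conversion this hunk touches; the ColumnName member for the input column and the returned dict keys are assumptions inferred from the snippet:

# Hypothetical, simplified version of the row conversion; scaffolding
# such as the input-column name is assumed, not taken from the diff.
from enum import Enum

class ColumnName(Enum):
    chat_completion_input = "chat_completion_input"  # assumed member name
    expected_answer = "expected_answer"

def llama_stack_instruct_to_torchtune_instruct(sample: dict) -> dict:
    input_messages = sample[ColumnName.chat_completion_input.value]
    input_message = input_messages[0]

    # The new assert rejects malformed rows with a clear message
    # instead of a less informative KeyError below.
    assert "content" in input_message, "content not found in input message"
    return {
        "input": input_message["content"],
        "output": sample[ColumnName.expected_answer.value],
    }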
@@ -135,7 +135,7 @@ class LoraFinetuningSingleDevice:
         self._data_format = training_config.data_config.data_format
         self._shuffle = training_config.data_config.shuffle
         self._batch_size = training_config.data_config.batch_size
-        self._training_on_input = training_config.data_config.training_on_input
+        self._train_on_input = training_config.data_config.train_on_input

         # this is important for debugging purpose
         self.max_steps_per_epoch = training_config.max_steps_per_epoch
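The rename keeps the attribute aligned with the underlying config field, which the dataset-construction hunk below reads as self._train_on_input. A rough sketch of the config shape implied by these reads, using plain dataclasses as stand-ins (the real llama-stack configs are likely pydantic models, and fields not read here are omitted):

# Minimal stand-in for the fields this constructor reads; only the
# attributes visible in the hunk above are included.
from dataclasses import dataclass

@dataclass
class DataConfig:
    data_format: str
    shuffle: bool
    batch_size: int
    train_on_input: bool = False  # must match the attribute read above

@dataclass
class TrainingConfig:
    data_config: DataConfig
    max_steps_per_epoch: int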
@@ -367,7 +367,7 @@ class LoraFinetuningSingleDevice:
         data_transform = await utils.get_data_transform(self._data_format)
         ds = SFTDataset(
             rows,
-            message_transform=data_transform(train_on_input=self._training_on_input),
+            message_transform=data_transform(train_on_input=self._train_on_input),
             model_transform=tokenizer,
             dataset_type=self._data_format.value,
         )
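Here train_on_input controls whether prompt tokens contribute to the fine-tuning loss. A simplified illustration of the masking it toggles, not torchtune's actual implementation; -100 is the conventional cross-entropy ignore index:

# Simplified illustration of what train_on_input toggles: when False,
# prompt tokens are masked out of the loss; when True, the model is
# trained on prompt and answer alike. Not torchtune's real code.
IGNORE_INDEX = -100  # default ignore_index for PyTorch cross-entropy

def build_labels(prompt_ids: list[int], answer_ids: list[int],
                 train_on_input: bool) -> list[int]:
    prompt_labels = prompt_ids if train_on_input else [IGNORE_INDEX] * len(prompt_ids)
    return prompt_labels + answer_ids

# With train_on_input=False, only the answer tokens carry loss:
assert build_labels([1, 2], [3, 4], train_on_input=False) == [-100, -100, 3, 4]
assert build_labels([1, 2], [3, 4], train_on_input=True) == [1, 2, 3, 4]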