mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
refine
This commit is contained in:
parent
49e1a25343
commit
929e7c0e69
6 changed files with 9 additions and 25 deletions
|
@ -30,7 +30,7 @@ class OptimizerType(Enum):
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class DatasetFormat(Enum):
|
class DatasetFormat(Enum):
|
||||||
instruct = "instruct"
|
instruct = "instruct"
|
||||||
chat = "chat"
|
chat = "dialog"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -23,7 +23,7 @@ EXPECTED_DATASET_SCHEMA = {
|
||||||
ColumnName.expected_answer.value: StringType(),
|
ColumnName.expected_answer.value: StringType(),
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"chat": [
|
"dialog": [
|
||||||
{
|
{
|
||||||
ColumnName.dialog.value: StringType(),
|
ColumnName.dialog.value: StringType(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
||||||
|
|
||||||
DATA_FORMATS: Dict[str, Transform] = {
|
DATA_FORMATS: Dict[str, Transform] = {
|
||||||
"instruct": InputOutputToMessages,
|
"instruct": InputOutputToMessages,
|
||||||
"chat": ShareGPTToMessages,
|
"dialog": ShareGPTToMessages,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,7 @@ def llama_stack_instruct_to_torchtune_instruct(
|
||||||
input_message = input_messages[0]
|
input_message = input_messages[0]
|
||||||
|
|
||||||
assert "content" in input_message, "content not found in input message"
|
assert "content" in input_message, "content not found in input message"
|
||||||
|
input = input_message["content"]
|
||||||
output = sample[ColumnName.expected_answer.value]
|
output = sample[ColumnName.expected_answer.value]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -135,7 +135,7 @@ class LoraFinetuningSingleDevice:
|
||||||
self._data_format = training_config.data_config.data_format
|
self._data_format = training_config.data_config.data_format
|
||||||
self._shuffle = training_config.data_config.shuffle
|
self._shuffle = training_config.data_config.shuffle
|
||||||
self._batch_size = training_config.data_config.batch_size
|
self._batch_size = training_config.data_config.batch_size
|
||||||
self._training_on_input = training_config.data_config.training_on_input
|
self._train_on_input = training_config.data_config.train_on_input
|
||||||
|
|
||||||
# this is important for debugging purpose
|
# this is important for debugging purpose
|
||||||
self.max_steps_per_epoch = training_config.max_steps_per_epoch
|
self.max_steps_per_epoch = training_config.max_steps_per_epoch
|
||||||
|
@ -367,7 +367,7 @@ class LoraFinetuningSingleDevice:
|
||||||
data_transform = await utils.get_data_transform(self._data_format)
|
data_transform = await utils.get_data_transform(self._data_format)
|
||||||
ds = SFTDataset(
|
ds = SFTDataset(
|
||||||
rows,
|
rows,
|
||||||
message_transform=data_transform(train_on_input=self._training_on_input),
|
message_transform=data_transform(train_on_input=self._train_on_input),
|
||||||
model_transform=tokenizer,
|
model_transform=tokenizer,
|
||||||
dataset_type=self._data_format.value,
|
dataset_type=self._data_format.value,
|
||||||
)
|
)
|
||||||
|
|
|
@ -29,8 +29,8 @@ providers:
|
||||||
provider_type: inline::basic
|
provider_type: inline::basic
|
||||||
config: {}
|
config: {}
|
||||||
datasetio:
|
datasetio:
|
||||||
- provider_id: huggingface-0
|
- provider_id: localfs
|
||||||
provider_type: remote::huggingface
|
provider_type: inline::localfs
|
||||||
config: {}
|
config: {}
|
||||||
telemetry:
|
telemetry:
|
||||||
- provider_id: meta-reference
|
- provider_id: meta-reference
|
||||||
|
@ -68,23 +68,6 @@ metadata_store:
|
||||||
models: []
|
models: []
|
||||||
shields: []
|
shields: []
|
||||||
memory_banks: []
|
memory_banks: []
|
||||||
datasets:
|
datasets: []
|
||||||
- dataset_id: alpaca
|
|
||||||
provider_id: huggingface-0
|
|
||||||
url:
|
|
||||||
uri: https://huggingface.co/datasets/tatsu-lab/alpaca
|
|
||||||
metadata:
|
|
||||||
path: tatsu-lab/alpaca
|
|
||||||
name:
|
|
||||||
split: train
|
|
||||||
dataset_schema:
|
|
||||||
instruction:
|
|
||||||
type: string
|
|
||||||
input:
|
|
||||||
type: string
|
|
||||||
output:
|
|
||||||
type: string
|
|
||||||
text:
|
|
||||||
type: string
|
|
||||||
scoring_fns: []
|
scoring_fns: []
|
||||||
eval_tasks: []
|
eval_tasks: []
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue