mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
refine
This commit is contained in:
parent
49e1a25343
commit
929e7c0e69
6 changed files with 9 additions and 25 deletions
|
@ -30,7 +30,7 @@ class OptimizerType(Enum):
|
|||
@json_schema_type
|
||||
class DatasetFormat(Enum):
|
||||
instruct = "instruct"
|
||||
chat = "chat"
|
||||
chat = "dialog"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -23,7 +23,7 @@ EXPECTED_DATASET_SCHEMA = {
|
|||
ColumnName.expected_answer.value: StringType(),
|
||||
}
|
||||
],
|
||||
"chat": [
|
||||
"dialog": [
|
||||
{
|
||||
ColumnName.dialog.value: StringType(),
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
|||
|
||||
DATA_FORMATS: Dict[str, Transform] = {
|
||||
"instruct": InputOutputToMessages,
|
||||
"chat": ShareGPTToMessages,
|
||||
"dialog": ShareGPTToMessages,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ def llama_stack_instruct_to_torchtune_instruct(
|
|||
input_message = input_messages[0]
|
||||
|
||||
assert "content" in input_message, "content not found in input message"
|
||||
input = input_message["content"]
|
||||
output = sample[ColumnName.expected_answer.value]
|
||||
|
||||
return {
|
||||
|
|
|
@ -135,7 +135,7 @@ class LoraFinetuningSingleDevice:
|
|||
self._data_format = training_config.data_config.data_format
|
||||
self._shuffle = training_config.data_config.shuffle
|
||||
self._batch_size = training_config.data_config.batch_size
|
||||
self._training_on_input = training_config.data_config.training_on_input
|
||||
self._train_on_input = training_config.data_config.train_on_input
|
||||
|
||||
# this is important for debugging purpose
|
||||
self.max_steps_per_epoch = training_config.max_steps_per_epoch
|
||||
|
@ -367,7 +367,7 @@ class LoraFinetuningSingleDevice:
|
|||
data_transform = await utils.get_data_transform(self._data_format)
|
||||
ds = SFTDataset(
|
||||
rows,
|
||||
message_transform=data_transform(train_on_input=self._training_on_input),
|
||||
message_transform=data_transform(train_on_input=self._train_on_input),
|
||||
model_transform=tokenizer,
|
||||
dataset_type=self._data_format.value,
|
||||
)
|
||||
|
|
|
@ -29,8 +29,8 @@ providers:
|
|||
provider_type: inline::basic
|
||||
config: {}
|
||||
datasetio:
|
||||
- provider_id: huggingface-0
|
||||
provider_type: remote::huggingface
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
|
@ -68,23 +68,6 @@ metadata_store:
|
|||
models: []
|
||||
shields: []
|
||||
memory_banks: []
|
||||
datasets:
|
||||
- dataset_id: alpaca
|
||||
provider_id: huggingface-0
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/tatsu-lab/alpaca
|
||||
metadata:
|
||||
path: tatsu-lab/alpaca
|
||||
name:
|
||||
split: train
|
||||
dataset_schema:
|
||||
instruction:
|
||||
type: string
|
||||
input:
|
||||
type: string
|
||||
output:
|
||||
type: string
|
||||
text:
|
||||
type: string
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
eval_tasks: []
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue