## Context

This PR defines two Llama Stack dataset formats (instruct, dialog):

- For the instruct dataset format, the column schema is `[chat_completion_input, expected_answer]`, which is consistent with the eval data format. This format is the abstraction for single-turn, QA-style post-training data.
- For the dialog dataset format, the column schema is `[dialog]`, a list of interleaved user and assistant messages. During training, the whole list is the model input and the loss is calculated on the assistant messages only. This format is the abstraction for multi-turn, chat-style post-training data.

Illustrative example rows for both formats are sketched below, after the test results.

## Changes

- Defined the two Llama Stack dataset formats.
- Added an adapter to convert the Llama Stack dataset format to the torchtune dataset format.
- Moved dataset format validation to the post-training level instead of the torchtune level, since it is not specific to torchtune.
- Added localfs as a datasetio provider.

## Test

Instruct format

- Using https://huggingface.co/datasets/llamastack/evals as the dataset, training works as expected.
  <img width="1443" alt="Screenshot 2025-01-09 at 5 15 14 PM" src="https://github.com/user-attachments/assets/2c37a936-c67a-4726-90e0-23fa0ba7000f" />
- Using my generated local dataset, training works as expected.
  <img width="1617" alt="Screenshot 2025-01-09 at 5 19 11 PM" src="https://github.com/user-attachments/assets/0bdccbbf-bac2-472a-a365-15213e49bbfa" />

Dialog format

- Using my generated local dataset, training works as expected.
  <img width="1588" alt="Screenshot 2025-01-09 at 5 23 16 PM" src="https://github.com/user-attachments/assets/893915ba-41a3-4d51-948b-e872060ecede" />
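For illustration, rows in the two formats could look roughly like this. This is a sketch only: the field values and the exact message encoding are made up for readability and are not taken from the datasets used in the tests above.

```python
# Illustrative rows only -- contents are invented; the real encoding is whatever
# the datasetio provider stores for these column schemas.

instruct_row = {
    "chat_completion_input": [
        {"role": "user", "content": "What is the capital of France?"},
    ],
    "expected_answer": "Paris",
}

dialog_row = {
    "dialog": [
        {"role": "user", "content": "Hi, can you help me plan a trip?"},
        {"role": "assistant", "content": "Sure! Where would you like to go?"},
        {"role": "user", "content": "Somewhere warm in March."},
        {"role": "assistant", "content": "Lisbon or Seville would be good options."},
    ],
}
```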
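The assistant-only loss described for the dialog format can be sketched as follows, assuming the common convention of marking ignored positions with `-100`. `build_labels` and the toy tokenizer are hypothetical names for illustration, not part of the adapter added in this PR.

```python
IGNORE_INDEX = -100  # common convention: positions with this label are excluded from the loss


def build_labels(turns, tokenize):
    """Concatenate all turns into one input sequence; mask non-assistant tokens."""
    input_ids, labels = [], []
    for turn in turns:
        token_ids = tokenize(turn["content"])
        input_ids.extend(token_ids)
        if turn["role"] == "assistant":
            labels.extend(token_ids)  # loss is computed on assistant tokens
        else:
            labels.extend([IGNORE_INDEX] * len(token_ids))  # user tokens are ignored
    return input_ids, labels


# Toy character-level "tokenizer" so the sketch runs on its own.
def fake_tokenize(text):
    return [ord(c) for c in text]


dialog = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello!"},
]
input_ids, labels = build_labels(dialog, fake_tokenize)
print(labels)  # [-100, -100, 104, 101, 108, 108, 111, 33]
```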
The dataset schema validation added at the post-training level:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import (
    ChatCompletionInputType,
    DialogType,
    StringType,
)
from llama_stack.apis.datasets import Datasets
from llama_stack.providers.utils.common.data_schema_validator import (
    ColumnName,
    validate_dataset_schema,
)

# Expected column schemas for the two post-training dataset formats.
EXPECTED_DATASET_SCHEMA = {
    "instruct": [
        {
            ColumnName.chat_completion_input.value: ChatCompletionInputType(),
            ColumnName.expected_answer.value: StringType(),
        }
    ],
    "dialog": [
        {
            ColumnName.dialog.value: DialogType(),
        }
    ],
}


async def validate_input_dataset_schema(
    datasets_api: Datasets,
    dataset_id: str,
    dataset_type: str,
) -> None:
    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
    if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")

    if dataset_type not in EXPECTED_DATASET_SCHEMA:
        raise ValueError(f"Dataset type {dataset_type} is not supported.")

    validate_dataset_schema(
        dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]
    )
```
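A minimal sketch of how this helper might be called from the post-training provider before a fine-tuning run, assuming the module above is in scope. `datasets_api` and the dataset id are placeholders, not names taken from this PR.

```python
# Hypothetical call site: `datasets_api` is a Datasets client handle wired in by
# the provider, and "my_instruct_dataset" is an example registered dataset id.
async def check_dataset(datasets_api: Datasets) -> None:
    # Raises ValueError if the dataset has no schema, the dataset type is
    # unknown, or the schema does not match the expected "instruct" columns.
    await validate_input_dataset_schema(
        datasets_api=datasets_api,
        dataset_id="my_instruct_dataset",
        dataset_type="instruct",
    )


# Inside the provider this would simply be awaited, e.g.:
# await check_dataset(datasets_api)
```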