llama-stack/llama_stack/providers/utils/common/data_schema_validator.py
Botao Chen 25c1d9b037
[post training] define llama stack post training dataset format (#717)
## context
In this PR, we define two Llama Stack dataset formats: instruct and dialog.

- For the instruct dataset format, the column schema is
[chat_completion_input, expected_answer], which is consistent with the
eval data format. This format is an abstraction of single-turn, QA-style
post-training data.
- For the dialog dataset format, the column schema is [dialog], a list
of interleaved user and assistant messages. During training, the whole
list is the model input, and the loss is calculated on assistant
messages only. This format is an abstraction of multi-turn, chat-style
post-training data.
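
To make the two column schemas concrete, here is a minimal, self-contained sketch with one example row per format. Only the column sets come from this PR. The row contents, the `role`/`content` message shape, and the helpers `detect_format` and `is_interleaved` are hypothetical illustrations, not part of the Llama Stack API.

```python
from typing import Any, Dict, List

# Column sets for the two formats, as described in this PR.
INSTRUCT_COLUMNS = {"chat_completion_input", "expected_answer"}
DIALOG_COLUMNS = {"dialog"}

# Hypothetical example rows; the role/content message shape is an assumption.
instruct_row: Dict[str, Any] = {
    "chat_completion_input": [{"role": "user", "content": "What is 2 + 2?"}],
    "expected_answer": "4",
}
dialog_row: Dict[str, Any] = {
    "dialog": [
        {"role": "user", "content": "Hi!"},
        {"role": "assistant", "content": "Hello! How can I help?"},
        {"role": "user", "content": "Name a prime number."},
        {"role": "assistant", "content": "7"},
    ]
}


def detect_format(row: Dict[str, Any]) -> str:
    """Hypothetical helper: classify a row by its exact column set."""
    columns = set(row)
    if columns == INSTRUCT_COLUMNS:
        return "instruct"
    if columns == DIALOG_COLUMNS:
        return "dialog"
    raise ValueError(f"Unrecognized columns: {sorted(columns)}")


def is_interleaved(dialog: List[Dict[str, str]]) -> bool:
    """Hypothetical check: only user/assistant roles, alternating turn by turn."""
    roles = [message["role"] for message in dialog]
    if any(role not in ("user", "assistant") for role in roles):
        return False
    return all(prev != curr for prev, curr in zip(roles, roles[1:]))


print(detect_format(instruct_row))  # instruct
print(detect_format(dialog_row))  # dialog
print(is_interleaved(dialog_row["dialog"]))  # True
```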

## changes
- define the two Llama Stack dataset formats
- add an adapter to convert the Llama Stack dataset formats to the
torchtune dataset format
- move dataset format validation to the post-training level instead of
the torchtune level, since it is not specific to torchtune
- add localfs as a datasetio provider
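
The adapter's job can be illustrated with a hypothetical sketch. This is not the PR's actual adapter. It converts either format into a flat list of messages, each carrying a `masked` flag, where `masked=True` means the message is excluded from the loss. The flag is loosely modeled on the `masked` field of torchtune's `Message`, but the dict shape here is an assumption, as is the premise that `chat_completion_input` is already a list of `role`/`content` dicts.

```python
from typing import Any, Dict, List

# Hypothetical output shape: role/content plus a "masked" flag
# (masked=True -> excluded from the loss). Not torchtune's actual API.
Message = Dict[str, Any]


def instruct_to_messages(row: Dict[str, Any]) -> List[Message]:
    """Prompt messages are masked; the expected answer is the trained assistant turn."""
    # Assumption: chat_completion_input is already a list of role/content dicts.
    messages = [
        {"role": m["role"], "content": m["content"], "masked": True}
        for m in row["chat_completion_input"]
    ]
    messages.append(
        {"role": "assistant", "content": row["expected_answer"], "masked": False}
    )
    return messages


def dialog_to_messages(row: Dict[str, Any]) -> List[Message]:
    """The whole dialog is model input; only assistant turns contribute to the loss."""
    return [
        {"role": m["role"], "content": m["content"], "masked": m["role"] != "assistant"}
        for m in row["dialog"]
    ]


def convert_row(row: Dict[str, Any]) -> List[Message]:
    """Dispatch on the column set that defines each format."""
    if set(row) == {"chat_completion_input", "expected_answer"}:
        return instruct_to_messages(row)
    if set(row) == {"dialog"}:
        return dialog_to_messages(row)
    raise ValueError(f"Unsupported columns: {sorted(row)}")


example = {
    "chat_completion_input": [{"role": "user", "content": "2 + 2?"}],
    "expected_answer": "4",
}
print([m["masked"] for m in convert_row(example)])  # [True, False]
```

Both formats end up in the same shape, so a single training loop can compute the loss only on unmasked messages.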


## test
instruct format
- use https://huggingface.co/datasets/llamastack/evals as the dataset;
training works as expected
<img width="1443" alt="Screenshot 2025-01-09 at 5 15 14 PM"
src="https://github.com/user-attachments/assets/2c37a936-c67a-4726-90e0-23fa0ba7000f"
/>

- use my locally generated dataset; training works as expected

<img width="1617" alt="Screenshot 2025-01-09 at 5 19 11 PM"
src="https://github.com/user-attachments/assets/0bdccbbf-bac2-472a-a365-15213e49bbfa"
/>


dialog format
- use my locally generated dataset; training works as expected
<img width="1588" alt="Screenshot 2025-01-09 at 5 23 16 PM"
src="https://github.com/user-attachments/assets/893915ba-41a3-4d51-948b-e872060ecede"
/>
2025-01-14 12:48:49 -08:00

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Any, Dict, List

from llama_stack.apis.common.type_system import (
    ChatCompletionInputType,
    CompletionInputType,
    StringType,
)

from llama_stack.distribution.datatypes import Api


class ColumnName(Enum):
    """Canonical column names from which the dataset schemas are built."""

    input_query = "input_query"
    expected_answer = "expected_answer"
    chat_completion_input = "chat_completion_input"
    completion_input = "completion_input"
    generated_answer = "generated_answer"
    context = "context"
    dialog = "dialog"


# Each schema maps column names to their expected column types. A dataset is
# valid for an API if its schema exactly equals one of that API's schemas.
VALID_SCHEMAS_FOR_SCORING = [
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.generated_answer.value: StringType(),
    },
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.generated_answer.value: StringType(),
        ColumnName.context.value: StringType(),
    },
]

VALID_SCHEMAS_FOR_EVAL = [
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.chat_completion_input.value: ChatCompletionInputType(),
    },
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.completion_input.value: CompletionInputType(),
    },
]


def get_valid_schemas(api_str: str) -> List[Dict[str, Any]]:
    """Return the accepted dataset schemas for the given API name."""
    if api_str == Api.scoring.value:
        return VALID_SCHEMAS_FOR_SCORING
    elif api_str == Api.eval.value:
        return VALID_SCHEMAS_FOR_EVAL
    else:
        raise ValueError(f"Invalid API string: {api_str}")


def validate_dataset_schema(
    dataset_schema: Dict[str, Any],
    expected_schemas: List[Dict[str, Any]],
) -> None:
    """Raise ValueError unless dataset_schema exactly equals one of expected_schemas."""
    if dataset_schema not in expected_schemas:
        raise ValueError(
            f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}"
        )


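def validate_row_schema(
    input_row: Dict[str, Any],
    expected_schemas: List[Dict[str, Any]],
) -> None:
    """Raise ValueError unless input_row contains every column of at least one schema.

    Only column presence is checked: value types are not validated, and extra
    columns in the row are allowed.
    """
    for schema in expected_schemas:
        if all(key in input_row for key in schema):
            return

    raise ValueError(
        f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
    )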
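
The two validators apply different rules. `validate_dataset_schema` requires the dataset schema to equal one of the expected dicts exactly, including column names and types. `validate_row_schema` only checks that a row contains every column of at least one schema.

The sketch below copies both function bodies so that it runs without `llama_stack` installed. As a hypothetical stand-in for the `StringType()`/`ChatCompletionInputType()`/`CompletionInputType()` objects, it uses plain strings. The helper `passes` is also illustrative.

```python
from typing import Any, Callable, Dict, List

# Hypothetical stand-ins: plain strings replace the llama_stack type objects.
EVAL_SCHEMAS: List[Dict[str, Any]] = [
    {"input_query": "string", "expected_answer": "string",
     "chat_completion_input": "chat_completion_input"},
    {"input_query": "string", "expected_answer": "string",
     "completion_input": "completion_input"},
]


def validate_dataset_schema(dataset_schema: Dict[str, Any],
                            expected_schemas: List[Dict[str, Any]]) -> None:
    # Copied logic: the schema must equal one of the expected dicts exactly.
    if dataset_schema not in expected_schemas:
        raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")


def validate_row_schema(input_row: Dict[str, Any],
                        expected_schemas: List[Dict[str, Any]]) -> None:
    # Copied logic: the row must contain every column of at least one schema.
    for schema in expected_schemas:
        if all(key in input_row for key in schema):
            return
    raise ValueError(f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}")


def passes(check: Callable[[Dict[str, Any], List[Dict[str, Any]]], None],
           value: Dict[str, Any]) -> bool:
    """Return True if `check` accepts `value` against EVAL_SCHEMAS."""
    try:
        check(value, EVAL_SCHEMAS)
        return True
    except ValueError:
        return False


# Rows: extra columns are tolerated and value types are never inspected.
print(passes(validate_row_schema,
             {"input_query": "q", "expected_answer": "a", "completion_input": 123, "extra": 1}))  # True
# Rows: a missing required column fails.
print(passes(validate_row_schema, {"input_query": "q", "expected_answer": "a"}))  # False
# Dataset schemas: a single extra column breaks the exact match.
print(passes(validate_dataset_schema, {**EVAL_SCHEMAS[1], "extra": "string"}))  # False
print(passes(validate_dataset_schema, dict(EVAL_SCHEMAS[0])))  # True
```

In practice, the row check is the looser of the two. A row that passes it may still carry mistyped values, so any type checking would have to happen elsewhere.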