mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
refine
This commit is contained in:
parent
f39dcdec9d
commit
49e1a25343
7 changed files with 134 additions and 110 deletions
|
@ -29,10 +29,8 @@ class OptimizerType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class DatasetFormat(Enum):
|
||||
alpaca = "alpaca"
|
||||
instruct = "instruct"
|
||||
chat_sharegpt = "chat_sharegpt"
|
||||
chat_openai = "chat_openai"
|
||||
chat = "chat"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -44,7 +42,6 @@ class DataConfig(BaseModel):
|
|||
validation_dataset_id: Optional[str] = None
|
||||
packed: Optional[bool] = False
|
||||
train_on_input: Optional[bool] = False
|
||||
column_map: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.common.type_system import StringType
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
ColumnName,
|
||||
validate_dataset_schema,
|
||||
)
|
||||
|
||||
EXPECTED_DATASET_SCHEMA = {
|
||||
"instruct": [
|
||||
{
|
||||
ColumnName.chat_completion_input.value: StringType(),
|
||||
ColumnName.expected_answer.value: StringType(),
|
||||
}
|
||||
],
|
||||
"chat": [
|
||||
{
|
||||
ColumnName.dialog.value: StringType(),
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def validate_input_dataset_schema(
|
||||
datasets_api: Datasets,
|
||||
dataset_id: str,
|
||||
dataset_type: str,
|
||||
) -> None:
|
||||
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||
|
||||
if dataset_type not in EXPECTED_DATASET_SCHEMA:
|
||||
raise ValueError(f"Dataset type {dataset_type} is not supported.")
|
||||
|
||||
validate_dataset_schema(
|
||||
dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]
|
||||
)
|
|
@ -10,27 +10,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import torch
|
||||
from llama_models.datatypes import Model
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.common.type_system import ParamType, StringType
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import DatasetFormat
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
ColumnName,
|
||||
validate_dataset_schema,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from torchtune.data._messages import (
|
||||
AlpacaToMessages,
|
||||
InputOutputToMessages,
|
||||
OpenAIToMessages,
|
||||
ShareGPTToMessages,
|
||||
)
|
||||
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
|
||||
|
||||
from torchtune.models.llama3 import llama3_tokenizer
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
|
@ -45,13 +34,6 @@ class ModelConfig(BaseModel):
|
|||
checkpoint_type: str
|
||||
|
||||
|
||||
class DatasetSchema(BaseModel):
|
||||
alpaca: List[Dict[str, ParamType]]
|
||||
instruct: List[Dict[str, ParamType]]
|
||||
chat_sharegpt: List[Dict[str, ParamType]]
|
||||
chat_openai: List[Dict[str, ParamType]]
|
||||
|
||||
|
||||
MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
||||
"Llama3.2-3B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_2_3b,
|
||||
|
@ -66,49 +48,11 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
|||
}
|
||||
|
||||
DATA_FORMATS: Dict[str, Transform] = {
|
||||
"alpaca": AlpacaToMessages,
|
||||
"instruct": InputOutputToMessages,
|
||||
"chat_sharegpt": ShareGPTToMessages,
|
||||
"chat_openai": OpenAIToMessages,
|
||||
"chat": ShareGPTToMessages,
|
||||
}
|
||||
|
||||
|
||||
EXPECTED_DATASET_SCHEMA = DatasetSchema(
|
||||
alpaca=[
|
||||
{
|
||||
ColumnName.instruction.value: StringType(),
|
||||
ColumnName.input.value: StringType(),
|
||||
ColumnName.output.value: StringType(),
|
||||
ColumnName.text.value: StringType(),
|
||||
},
|
||||
{
|
||||
ColumnName.instruction.value: StringType(),
|
||||
ColumnName.input.value: StringType(),
|
||||
ColumnName.output.value: StringType(),
|
||||
},
|
||||
{
|
||||
ColumnName.instruction.value: StringType(),
|
||||
ColumnName.output.value: StringType(),
|
||||
},
|
||||
],
|
||||
instruct=[
|
||||
{
|
||||
ColumnName.input.value: StringType(),
|
||||
ColumnName.output.value: StringType(),
|
||||
}
|
||||
],
|
||||
chat_sharegpt=[
|
||||
{
|
||||
ColumnName.conversations.value: StringType(),
|
||||
}
|
||||
],
|
||||
chat_openai=[
|
||||
{
|
||||
ColumnName.messages.value: StringType(),
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
||||
|
@ -156,35 +100,3 @@ async def get_checkpointer_model_type(
|
|||
|
||||
async def get_data_transform(data_format: DatasetFormat) -> Transform:
|
||||
return DATA_FORMATS[data_format.value]
|
||||
|
||||
|
||||
async def validate_input_dataset_schema(
|
||||
datasets_api: Datasets,
|
||||
dataset_id: str,
|
||||
dataset_type: str,
|
||||
column_map: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||
|
||||
if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
|
||||
raise ValueError(f"Dataset type {dataset_type} is not supported.")
|
||||
|
||||
dataset_schema = {}
|
||||
|
||||
if column_map:
|
||||
for old_col_name in dataset_def.dataset_schema.keys():
|
||||
if old_col_name in column_map.values():
|
||||
new_col_name = next(
|
||||
k for k, v in column_map.items() if v == old_col_name
|
||||
)
|
||||
dataset_schema[new_col_name] = dataset_def.dataset_schema[old_col_name]
|
||||
else:
|
||||
dataset_schema[old_col_name] = dataset_def.dataset_schema[old_col_name]
|
||||
else:
|
||||
dataset_schema = dataset_def.dataset_schema
|
||||
|
||||
validate_dataset_schema(
|
||||
dataset_schema, getattr(EXPECTED_DATASET_SCHEMA, dataset_type)
|
||||
)
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Mapping
|
||||
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
|
||||
|
||||
def llama_stack_instruct_to_torchtune_instruct(
|
||||
sample: Mapping[str, Any]
|
||||
) -> Mapping[str, Any]:
|
||||
assert (
|
||||
ColumnName.chat_completion_input.value in sample
|
||||
and ColumnName.expected_answer.value in sample
|
||||
), "Invalid input row"
|
||||
input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
|
||||
|
||||
assert (
|
||||
len(input_messages) == 1
|
||||
), "llama stack intruct dataset format only supports 1 user message"
|
||||
input_message = input_messages[0]
|
||||
|
||||
assert "content" in input_message, "content not found in input message"
|
||||
output = sample[ColumnName.expected_answer.value]
|
||||
|
||||
return {
|
||||
"input": input,
|
||||
"output": output,
|
||||
}
|
||||
|
||||
|
||||
def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
assert ColumnName.dialog.value in sample, "Invalid input row"
|
||||
role_map = {"user": "human", "assistant": "gpt"}
|
||||
dialog = eval(str(sample[ColumnName.dialog.value]))
|
||||
|
||||
assert len(dialog) > 1, "dialog must have at least 2 messagse"
|
||||
roles = []
|
||||
conversations = []
|
||||
for message in dialog:
|
||||
assert (
|
||||
"role" in message and "content" in message
|
||||
), "role and content must in message"
|
||||
roles.append(message["role"])
|
||||
conversations.append(
|
||||
{"from": role_map[message["role"]], "value": message["content"]}
|
||||
)
|
||||
|
||||
assert roles[0] == "user", "first message must be from user"
|
||||
assert "assistant" in roles, "at least 1 message should be from assistant"
|
||||
|
||||
return {"conversations": conversations}
|
|
@ -13,6 +13,10 @@
|
|||
from typing import Any, Dict, List, Mapping
|
||||
|
||||
import numpy as np
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapter import (
|
||||
llama_stack_chat_to_torchtune_chat,
|
||||
llama_stack_instruct_to_torchtune_instruct,
|
||||
)
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
|
||||
|
@ -26,10 +30,12 @@ class SFTDataset(Dataset):
|
|||
rows: List[Dict[str, Any]],
|
||||
message_transform: Transform,
|
||||
model_transform: Transform,
|
||||
dataset_type: str,
|
||||
) -> None:
|
||||
self._rows = rows
|
||||
self._message_transform = message_transform
|
||||
self._model_transform = model_transform
|
||||
self._dataset_type = dataset_type
|
||||
|
||||
def __len__(self):
|
||||
return len(self._rows)
|
||||
|
@ -39,6 +45,13 @@ class SFTDataset(Dataset):
|
|||
return self._prepare_sample(sample)
|
||||
|
||||
def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
|
||||
if self._dataset_type == "instruct":
|
||||
sample = llama_stack_instruct_to_torchtune_instruct(sample)
|
||||
elif self._dataset_type == "chat":
|
||||
sample = llama_stack_chat_to_torchtune_chat(sample)
|
||||
else:
|
||||
raise ValueError(f"Invalid dataset type: {self._dataset_type}")
|
||||
|
||||
transformed_sample = self._message_transform(sample)
|
||||
if "messages" in transformed_sample:
|
||||
validate_messages(transformed_sample["messages"])
|
||||
|
|
|
@ -29,6 +29,9 @@ from llama_stack.apis.post_training import (
|
|||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.inline.post_training.common.validator import (
|
||||
validate_input_dataset_schema,
|
||||
)
|
||||
|
||||
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
||||
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
|
||||
|
@ -132,7 +135,7 @@ class LoraFinetuningSingleDevice:
|
|||
self._data_format = training_config.data_config.data_format
|
||||
self._shuffle = training_config.data_config.shuffle
|
||||
self._batch_size = training_config.data_config.batch_size
|
||||
self._column_map = training_config.data_config.column_map
|
||||
self._training_on_input = training_config.data_config.training_on_input
|
||||
|
||||
# this is important for debugging purpose
|
||||
self.max_steps_per_epoch = training_config.max_steps_per_epoch
|
||||
|
@ -356,22 +359,17 @@ class LoraFinetuningSingleDevice:
|
|||
all_rows = await fetch_rows(dataset_id)
|
||||
rows = all_rows.rows
|
||||
|
||||
# Curretly only support alpaca instruct dataset
|
||||
# TODO @SLR722 make the message_transform swappable and support more dataset types
|
||||
# TODO @SLR722 make the input dataset schema more flexible by exposing column_map
|
||||
await utils.validate_input_dataset_schema(
|
||||
await validate_input_dataset_schema(
|
||||
datasets_api=self.datasets_api,
|
||||
dataset_id=dataset_id,
|
||||
dataset_type=self._data_format.value,
|
||||
column_map=self._column_map,
|
||||
)
|
||||
data_transform = await utils.get_data_transform(self._data_format)
|
||||
ds = SFTDataset(
|
||||
rows,
|
||||
message_transform=data_transform(
|
||||
train_on_input=False, column_map=self._column_map
|
||||
),
|
||||
message_transform=data_transform(train_on_input=self._training_on_input),
|
||||
model_transform=tokenizer,
|
||||
dataset_type=self._data_format.value,
|
||||
)
|
||||
|
||||
sampler = DistributedSampler(
|
||||
|
|
|
@ -23,12 +23,7 @@ class ColumnName(Enum):
|
|||
completion_input = "completion_input"
|
||||
generated_answer = "generated_answer"
|
||||
context = "context"
|
||||
instruction = "instruction"
|
||||
input = "input"
|
||||
output = "output"
|
||||
text = "text"
|
||||
conversations = "conversations"
|
||||
messages = "messages"
|
||||
dialog = "dialog"
|
||||
|
||||
|
||||
VALID_SCHEMAS_FOR_SCORING = [
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue