mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
refine
This commit is contained in:
parent
f39dcdec9d
commit
49e1a25343
7 changed files with 134 additions and 110 deletions
|
@ -29,10 +29,8 @@ class OptimizerType(Enum):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class DatasetFormat(Enum):
|
class DatasetFormat(Enum):
|
||||||
alpaca = "alpaca"
|
|
||||||
instruct = "instruct"
|
instruct = "instruct"
|
||||||
chat_sharegpt = "chat_sharegpt"
|
chat = "chat"
|
||||||
chat_openai = "chat_openai"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -44,7 +42,6 @@ class DataConfig(BaseModel):
|
||||||
validation_dataset_id: Optional[str] = None
|
validation_dataset_id: Optional[str] = None
|
||||||
packed: Optional[bool] = False
|
packed: Optional[bool] = False
|
||||||
train_on_input: Optional[bool] = False
|
train_on_input: Optional[bool] = False
|
||||||
column_map: Optional[Dict[str, str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
from llama_stack.apis.common.type_system import StringType
|
||||||
|
from llama_stack.apis.datasets import Datasets
|
||||||
|
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||||
|
ColumnName,
|
||||||
|
validate_dataset_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
EXPECTED_DATASET_SCHEMA = {
|
||||||
|
"instruct": [
|
||||||
|
{
|
||||||
|
ColumnName.chat_completion_input.value: StringType(),
|
||||||
|
ColumnName.expected_answer.value: StringType(),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"chat": [
|
||||||
|
{
|
||||||
|
ColumnName.dialog.value: StringType(),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_input_dataset_schema(
|
||||||
|
datasets_api: Datasets,
|
||||||
|
dataset_id: str,
|
||||||
|
dataset_type: str,
|
||||||
|
) -> None:
|
||||||
|
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
||||||
|
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||||
|
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||||
|
|
||||||
|
if dataset_type not in EXPECTED_DATASET_SCHEMA:
|
||||||
|
raise ValueError(f"Dataset type {dataset_type} is not supported.")
|
||||||
|
|
||||||
|
validate_dataset_schema(
|
||||||
|
dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]
|
||||||
|
)
|
|
@ -10,27 +10,16 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from llama_models.datatypes import Model
|
from llama_models.datatypes import Model
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import ParamType, StringType
|
|
||||||
from llama_stack.apis.datasets import Datasets
|
|
||||||
from llama_stack.apis.post_training import DatasetFormat
|
from llama_stack.apis.post_training import DatasetFormat
|
||||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
|
||||||
ColumnName,
|
|
||||||
validate_dataset_schema,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from torchtune.data._messages import (
|
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
|
||||||
AlpacaToMessages,
|
|
||||||
InputOutputToMessages,
|
|
||||||
OpenAIToMessages,
|
|
||||||
ShareGPTToMessages,
|
|
||||||
)
|
|
||||||
|
|
||||||
from torchtune.models.llama3 import llama3_tokenizer
|
from torchtune.models.llama3 import llama3_tokenizer
|
||||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||||
|
@ -45,13 +34,6 @@ class ModelConfig(BaseModel):
|
||||||
checkpoint_type: str
|
checkpoint_type: str
|
||||||
|
|
||||||
|
|
||||||
class DatasetSchema(BaseModel):
|
|
||||||
alpaca: List[Dict[str, ParamType]]
|
|
||||||
instruct: List[Dict[str, ParamType]]
|
|
||||||
chat_sharegpt: List[Dict[str, ParamType]]
|
|
||||||
chat_openai: List[Dict[str, ParamType]]
|
|
||||||
|
|
||||||
|
|
||||||
MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
||||||
"Llama3.2-3B-Instruct": ModelConfig(
|
"Llama3.2-3B-Instruct": ModelConfig(
|
||||||
model_definition=lora_llama3_2_3b,
|
model_definition=lora_llama3_2_3b,
|
||||||
|
@ -66,49 +48,11 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
||||||
}
|
}
|
||||||
|
|
||||||
DATA_FORMATS: Dict[str, Transform] = {
|
DATA_FORMATS: Dict[str, Transform] = {
|
||||||
"alpaca": AlpacaToMessages,
|
|
||||||
"instruct": InputOutputToMessages,
|
"instruct": InputOutputToMessages,
|
||||||
"chat_sharegpt": ShareGPTToMessages,
|
"chat": ShareGPTToMessages,
|
||||||
"chat_openai": OpenAIToMessages,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
EXPECTED_DATASET_SCHEMA = DatasetSchema(
|
|
||||||
alpaca=[
|
|
||||||
{
|
|
||||||
ColumnName.instruction.value: StringType(),
|
|
||||||
ColumnName.input.value: StringType(),
|
|
||||||
ColumnName.output.value: StringType(),
|
|
||||||
ColumnName.text.value: StringType(),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ColumnName.instruction.value: StringType(),
|
|
||||||
ColumnName.input.value: StringType(),
|
|
||||||
ColumnName.output.value: StringType(),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ColumnName.instruction.value: StringType(),
|
|
||||||
ColumnName.output.value: StringType(),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
instruct=[
|
|
||||||
{
|
|
||||||
ColumnName.input.value: StringType(),
|
|
||||||
ColumnName.output.value: StringType(),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
chat_sharegpt=[
|
|
||||||
{
|
|
||||||
ColumnName.conversations.value: StringType(),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
chat_openai=[
|
|
||||||
{
|
|
||||||
ColumnName.messages.value: StringType(),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||||
|
|
||||||
|
@ -156,35 +100,3 @@ async def get_checkpointer_model_type(
|
||||||
|
|
||||||
async def get_data_transform(data_format: DatasetFormat) -> Transform:
|
async def get_data_transform(data_format: DatasetFormat) -> Transform:
|
||||||
return DATA_FORMATS[data_format.value]
|
return DATA_FORMATS[data_format.value]
|
||||||
|
|
||||||
|
|
||||||
async def validate_input_dataset_schema(
|
|
||||||
datasets_api: Datasets,
|
|
||||||
dataset_id: str,
|
|
||||||
dataset_type: str,
|
|
||||||
column_map: Optional[Dict[str, str]] = None,
|
|
||||||
) -> None:
|
|
||||||
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
|
||||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
|
||||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
|
||||||
|
|
||||||
if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
|
|
||||||
raise ValueError(f"Dataset type {dataset_type} is not supported.")
|
|
||||||
|
|
||||||
dataset_schema = {}
|
|
||||||
|
|
||||||
if column_map:
|
|
||||||
for old_col_name in dataset_def.dataset_schema.keys():
|
|
||||||
if old_col_name in column_map.values():
|
|
||||||
new_col_name = next(
|
|
||||||
k for k, v in column_map.items() if v == old_col_name
|
|
||||||
)
|
|
||||||
dataset_schema[new_col_name] = dataset_def.dataset_schema[old_col_name]
|
|
||||||
else:
|
|
||||||
dataset_schema[old_col_name] = dataset_def.dataset_schema[old_col_name]
|
|
||||||
else:
|
|
||||||
dataset_schema = dataset_def.dataset_schema
|
|
||||||
|
|
||||||
validate_dataset_schema(
|
|
||||||
dataset_schema, getattr(EXPECTED_DATASET_SCHEMA, dataset_type)
|
|
||||||
)
|
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Mapping
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||||
|
|
||||||
|
|
||||||
|
def llama_stack_instruct_to_torchtune_instruct(
|
||||||
|
sample: Mapping[str, Any]
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
assert (
|
||||||
|
ColumnName.chat_completion_input.value in sample
|
||||||
|
and ColumnName.expected_answer.value in sample
|
||||||
|
), "Invalid input row"
|
||||||
|
input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
|
||||||
|
|
||||||
|
assert (
|
||||||
|
len(input_messages) == 1
|
||||||
|
), "llama stack intruct dataset format only supports 1 user message"
|
||||||
|
input_message = input_messages[0]
|
||||||
|
|
||||||
|
assert "content" in input_message, "content not found in input message"
|
||||||
|
output = sample[ColumnName.expected_answer.value]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input": input,
|
||||||
|
"output": output,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
assert ColumnName.dialog.value in sample, "Invalid input row"
|
||||||
|
role_map = {"user": "human", "assistant": "gpt"}
|
||||||
|
dialog = eval(str(sample[ColumnName.dialog.value]))
|
||||||
|
|
||||||
|
assert len(dialog) > 1, "dialog must have at least 2 messagse"
|
||||||
|
roles = []
|
||||||
|
conversations = []
|
||||||
|
for message in dialog:
|
||||||
|
assert (
|
||||||
|
"role" in message and "content" in message
|
||||||
|
), "role and content must in message"
|
||||||
|
roles.append(message["role"])
|
||||||
|
conversations.append(
|
||||||
|
{"from": role_map[message["role"]], "value": message["content"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert roles[0] == "user", "first message must be from user"
|
||||||
|
assert "assistant" in roles, "at least 1 message should be from assistant"
|
||||||
|
|
||||||
|
return {"conversations": conversations}
|
|
@ -13,6 +13,10 @@
|
||||||
from typing import Any, Dict, List, Mapping
|
from typing import Any, Dict, List, Mapping
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapter import (
|
||||||
|
llama_stack_chat_to_torchtune_chat,
|
||||||
|
llama_stack_instruct_to_torchtune_instruct,
|
||||||
|
)
|
||||||
|
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
|
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
|
||||||
|
@ -26,10 +30,12 @@ class SFTDataset(Dataset):
|
||||||
rows: List[Dict[str, Any]],
|
rows: List[Dict[str, Any]],
|
||||||
message_transform: Transform,
|
message_transform: Transform,
|
||||||
model_transform: Transform,
|
model_transform: Transform,
|
||||||
|
dataset_type: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._rows = rows
|
self._rows = rows
|
||||||
self._message_transform = message_transform
|
self._message_transform = message_transform
|
||||||
self._model_transform = model_transform
|
self._model_transform = model_transform
|
||||||
|
self._dataset_type = dataset_type
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._rows)
|
return len(self._rows)
|
||||||
|
@ -39,6 +45,13 @@ class SFTDataset(Dataset):
|
||||||
return self._prepare_sample(sample)
|
return self._prepare_sample(sample)
|
||||||
|
|
||||||
def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
|
def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
|
||||||
|
if self._dataset_type == "instruct":
|
||||||
|
sample = llama_stack_instruct_to_torchtune_instruct(sample)
|
||||||
|
elif self._dataset_type == "chat":
|
||||||
|
sample = llama_stack_chat_to_torchtune_chat(sample)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid dataset type: {self._dataset_type}")
|
||||||
|
|
||||||
transformed_sample = self._message_transform(sample)
|
transformed_sample = self._message_transform(sample)
|
||||||
if "messages" in transformed_sample:
|
if "messages" in transformed_sample:
|
||||||
validate_messages(transformed_sample["messages"])
|
validate_messages(transformed_sample["messages"])
|
||||||
|
|
|
@ -29,6 +29,9 @@ from llama_stack.apis.post_training import (
|
||||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||||
|
|
||||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||||
|
from llama_stack.providers.inline.post_training.common.validator import (
|
||||||
|
validate_input_dataset_schema,
|
||||||
|
)
|
||||||
|
|
||||||
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
||||||
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
|
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
|
||||||
|
@ -132,7 +135,7 @@ class LoraFinetuningSingleDevice:
|
||||||
self._data_format = training_config.data_config.data_format
|
self._data_format = training_config.data_config.data_format
|
||||||
self._shuffle = training_config.data_config.shuffle
|
self._shuffle = training_config.data_config.shuffle
|
||||||
self._batch_size = training_config.data_config.batch_size
|
self._batch_size = training_config.data_config.batch_size
|
||||||
self._column_map = training_config.data_config.column_map
|
self._training_on_input = training_config.data_config.training_on_input
|
||||||
|
|
||||||
# this is important for debugging purpose
|
# this is important for debugging purpose
|
||||||
self.max_steps_per_epoch = training_config.max_steps_per_epoch
|
self.max_steps_per_epoch = training_config.max_steps_per_epoch
|
||||||
|
@ -356,22 +359,17 @@ class LoraFinetuningSingleDevice:
|
||||||
all_rows = await fetch_rows(dataset_id)
|
all_rows = await fetch_rows(dataset_id)
|
||||||
rows = all_rows.rows
|
rows = all_rows.rows
|
||||||
|
|
||||||
# Curretly only support alpaca instruct dataset
|
await validate_input_dataset_schema(
|
||||||
# TODO @SLR722 make the message_transform swappable and support more dataset types
|
|
||||||
# TODO @SLR722 make the input dataset schema more flexible by exposing column_map
|
|
||||||
await utils.validate_input_dataset_schema(
|
|
||||||
datasets_api=self.datasets_api,
|
datasets_api=self.datasets_api,
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
dataset_type=self._data_format.value,
|
dataset_type=self._data_format.value,
|
||||||
column_map=self._column_map,
|
|
||||||
)
|
)
|
||||||
data_transform = await utils.get_data_transform(self._data_format)
|
data_transform = await utils.get_data_transform(self._data_format)
|
||||||
ds = SFTDataset(
|
ds = SFTDataset(
|
||||||
rows,
|
rows,
|
||||||
message_transform=data_transform(
|
message_transform=data_transform(train_on_input=self._training_on_input),
|
||||||
train_on_input=False, column_map=self._column_map
|
|
||||||
),
|
|
||||||
model_transform=tokenizer,
|
model_transform=tokenizer,
|
||||||
|
dataset_type=self._data_format.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
sampler = DistributedSampler(
|
sampler = DistributedSampler(
|
||||||
|
|
|
@ -23,12 +23,7 @@ class ColumnName(Enum):
|
||||||
completion_input = "completion_input"
|
completion_input = "completion_input"
|
||||||
generated_answer = "generated_answer"
|
generated_answer = "generated_answer"
|
||||||
context = "context"
|
context = "context"
|
||||||
instruction = "instruction"
|
dialog = "dialog"
|
||||||
input = "input"
|
|
||||||
output = "output"
|
|
||||||
text = "text"
|
|
||||||
conversations = "conversations"
|
|
||||||
messages = "messages"
|
|
||||||
|
|
||||||
|
|
||||||
VALID_SCHEMAS_FOR_SCORING = [
|
VALID_SCHEMAS_FOR_SCORING = [
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue