diff --git a/llama_stack/apis/common/type_system.py b/llama_stack/apis/common/type_system.py
index a653efef9..e76cfde13 100644
--- a/llama_stack/apis/common/type_system.py
+++ b/llama_stack/apis/common/type_system.py
@@ -54,6 +54,12 @@ class AgentTurnInputType(BaseModel):
     type: Literal["agent_turn_input"] = "agent_turn_input"
 
 
+class DialogType(BaseModel):
+    # expects List[Message] for messages
+    # this type semantically contains the output label whereas ChatCompletionInputType does not
+    type: Literal["dialog"] = "dialog"
+
+
 ParamType = register_schema(
     Annotated[
         Union[
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 8e1edbe87..8841dc1d0 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -27,11 +27,18 @@ class OptimizerType(Enum):
     sgd = "sgd"
 
 
+@json_schema_type
+class DatasetFormat(Enum):
+    instruct = "instruct"
+    dialog = "dialog"
+
+
 @json_schema_type
 class DataConfig(BaseModel):
     dataset_id: str
     batch_size: int
     shuffle: bool
+    data_format: DatasetFormat
     validation_dataset_id: Optional[str] = None
     packed: Optional[bool] = False
     train_on_input: Optional[bool] = False
diff --git a/llama_stack/providers/inline/post_training/common/__init__.py b/llama_stack/providers/inline/post_training/common/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/common/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/inline/post_training/common/validator.py b/llama_stack/providers/inline/post_training/common/validator.py
new file mode 100644
index 000000000..836e20c85
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/common/validator.py
@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.common.type_system import (
+    ChatCompletionInputType,
+    DialogType,
+    StringType,
+)
+from llama_stack.apis.datasets import Datasets
+from llama_stack.providers.utils.common.data_schema_validator import (
+    ColumnName,
+    validate_dataset_schema,
+)
+
+EXPECTED_DATASET_SCHEMA = {
+    "instruct": [
+        {
+            ColumnName.chat_completion_input.value: ChatCompletionInputType(),
+            ColumnName.expected_answer.value: StringType(),
+        }
+    ],
+    "dialog": [
+        {
+            ColumnName.dialog.value: DialogType(),
+        }
+    ],
+}
+
+
+async def validate_input_dataset_schema(
+    datasets_api: Datasets,
+    dataset_id: str,
+    dataset_type: str,
+) -> None:
+    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
+    if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
+
+    if dataset_type not in EXPECTED_DATASET_SCHEMA:
+        raise ValueError(f"Dataset type {dataset_type} is not supported.")
+
+    validate_dataset_schema(
+        dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]
+    )
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
index b4cd43770..88011ead4 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -10,29 +10,22 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from enum import Enum
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict
 
 import torch
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
+from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
 from torchtune.models.llama3 import llama3_tokenizer
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_1 import lora_llama3_1_8b
 from torchtune.models.llama3_2 import lora_llama3_2_3b
+from torchtune.modules.transforms import Transform
 
-from llama_stack.apis.common.type_system import ParamType, StringType
-from llama_stack.apis.datasets import Datasets
-
-
-class ColumnName(Enum):
-    instruction = "instruction"
-    input = "input"
-    output = "output"
-    text = "text"
+from llama_stack.apis.post_training import DatasetFormat
 
 
 class ModelConfig(BaseModel):
@@ -41,10 +34,6 @@ class ModelConfig(BaseModel):
     checkpoint_type: str
 
 
-class DatasetSchema(BaseModel):
-    alpaca: List[Dict[str, ParamType]]
-
-
 MODEL_CONFIGS: Dict[str, ModelConfig] = {
     "Llama3.2-3B-Instruct": ModelConfig(
         model_definition=lora_llama3_2_3b,
@@ -58,26 +47,11 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
     ),
 }
 
+DATA_FORMATS: Dict[str, Transform] = {
+    "instruct": InputOutputToMessages,
+    "dialog": ShareGPTToMessages,
+}
 
-EXPECTED_DATASET_SCHEMA = DatasetSchema(
-    alpaca=[
-        {
-            ColumnName.instruction.value: StringType(),
-            ColumnName.input.value: StringType(),
-            ColumnName.output.value: StringType(),
-            ColumnName.text.value: StringType(),
-        },
-        {
-            ColumnName.instruction.value: StringType(),
-            ColumnName.input.value: StringType(),
-            ColumnName.output.value: StringType(),
-        },
-        {
-            ColumnName.instruction.value: StringType(),
-            ColumnName.output.value: StringType(),
-        },
-    ]
-)
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
@@ -124,19 +98,5 @@ async def get_checkpointer_model_type(
     return model_config.checkpoint_type
 
 
-async def validate_input_dataset_schema(
-    datasets_api: Datasets,
-    dataset_id: str,
-    dataset_type: str,
-) -> None:
-    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
-    if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
-
-    if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
-        raise ValueError(f"Dataset type {dataset_type} is not supported.")
-
-    if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
-        raise ValueError(
-            f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
-        )
+async def get_data_transform(data_format: DatasetFormat) -> Transform:
+    return DATA_FORMATS[data_format.value]
diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py b/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py
new file mode 100644
index 000000000..b4dfbb3c1
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Mapping
+
+from llama_stack.providers.utils.common.data_schema_validator import ColumnName
+
+
+def llama_stack_instruct_to_torchtune_instruct(
+    sample: Mapping[str, Any]
+) -> Mapping[str, Any]:
+    assert (
+        ColumnName.chat_completion_input.value in sample
+        and ColumnName.expected_answer.value in sample
+    ), "Invalid input row"
+    input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
+
+    assert (
+        len(input_messages) == 1
+    ), "llama stack instruct dataset format only supports 1 user message"
+    input_message = input_messages[0]
+
+    assert "content" in input_message, "content not found in input message"
+    input = input_message["content"]
+    output = sample[ColumnName.expected_answer.value]
+
+    return {
+        "input": input,
+        "output": output,
+    }
+
+
+def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
+    assert ColumnName.dialog.value in sample, "Invalid input row"
+    role_map = {"user": "human", "assistant": "gpt"}
+    dialog = eval(str(sample[ColumnName.dialog.value]))
+
+    assert len(dialog) > 1, "dialog must have at least 2 messages"
+    roles = []
+    conversations = []
+    for message in dialog:
+        assert (
+            "role" in message and "content" in message
+        ), "role and content must be in message"
+        roles.append(message["role"])
+        conversations.append(
+            {"from": role_map[message["role"]], "value": message["content"]}
+        )
+
+    assert roles[0] == "user", "first message must be from user"
+    assert "assistant" in roles, "at least 1 message should be from assistant"
+
+    return {"conversations": conversations}
diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
index 1f91dc73f..1a5aade09 100644
--- a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
+++ b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
@@ -19,6 +19,11 @@ from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
 from torchtune.data._messages import validate_messages
 from torchtune.modules.transforms import Transform
 
+from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapter import (
+    llama_stack_chat_to_torchtune_chat,
+    llama_stack_instruct_to_torchtune_instruct,
+)
+
 
 class SFTDataset(Dataset):
     def __init__(
@@ -26,10 +31,12 @@ class SFTDataset(Dataset):
         rows: List[Dict[str, Any]],
         message_transform: Transform,
         model_transform: Transform,
+        dataset_type: str,
     ) -> None:
         self._rows = rows
         self._message_transform = message_transform
         self._model_transform = model_transform
+        self._dataset_type = dataset_type
 
     def __len__(self):
         return len(self._rows)
@@ -39,6 +46,12 @@ class SFTDataset(Dataset):
         return self._prepare_sample(sample)
 
     def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
+        if self._dataset_type == "instruct":
+            sample = llama_stack_instruct_to_torchtune_instruct(sample)
+        elif self._dataset_type == "dialog":
+            sample = llama_stack_chat_to_torchtune_chat(sample)
+        else:
+            raise ValueError(f"Invalid dataset type: {self._dataset_type}")
         transformed_sample = self._message_transform(sample)
         if "messages" in transformed_sample:
             validate_messages(transformed_sample["messages"])
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 6c795d310..7543b1f4e 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -18,7 +18,7 @@ from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training, utils as torchtune_utils
-from torchtune.data import AlpacaToMessages, padded_collate_sft
+from torchtune.data import padded_collate_sft
 
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
@@ -47,6 +47,9 @@ from llama_stack.apis.post_training import (
 
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.providers.inline.post_training.common.validator import (
+    validate_input_dataset_schema,
+)
 from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
@@ -129,8 +132,10 @@ class LoraFinetuningSingleDevice:
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
         self.total_epochs = training_config.n_epochs
+        self._data_format = training_config.data_config.data_format
         self._shuffle = training_config.data_config.shuffle
         self._batch_size = training_config.data_config.batch_size
+        self._train_on_input = training_config.data_config.train_on_input
 
         # this is important for debugging purpose
         self.max_steps_per_epoch = training_config.max_steps_per_epoch
@@ -354,18 +359,17 @@ class LoraFinetuningSingleDevice:
         all_rows = await fetch_rows(dataset_id)
         rows = all_rows.rows
 
-        # Curretly only support alpaca instruct dataset
-        # TODO @SLR722 make the message_transform swappable and support more dataset types
-        # TODO @SLR722 make the input dataset schema more flexible by exposing column_map
-        await utils.validate_input_dataset_schema(
+        await validate_input_dataset_schema(
             datasets_api=self.datasets_api,
             dataset_id=dataset_id,
-            dataset_type="alpaca",
+            dataset_type=self._data_format.value,
         )
+        data_transform = await utils.get_data_transform(self._data_format)
         ds = SFTDataset(
             rows,
-            message_transform=AlpacaToMessages(train_on_input=False),
+            message_transform=data_transform(train_on_input=self._train_on_input),
             model_transform=tokenizer,
+            dataset_type=self._data_format.value,
         )
 
         sampler = DistributedSampler(
diff --git a/llama_stack/providers/utils/common/data_schema_validator.py b/llama_stack/providers/utils/common/data_schema_validator.py
index af58a4592..55f1078a4 100644
--- a/llama_stack/providers/utils/common/data_schema_validator.py
+++ b/llama_stack/providers/utils/common/data_schema_validator.py
@@ -23,6 +23,7 @@ class ColumnName(Enum):
     completion_input = "completion_input"
     generated_answer = "generated_answer"
     context = "context"
+    dialog = "dialog"
 
 
 VALID_SCHEMAS_FOR_SCORING = [
diff --git a/llama_stack/templates/experimental-post-training/build.yaml b/llama_stack/templates/experimental-post-training/build.yaml
index aa7695bca..e04868199 100644
--- a/llama_stack/templates/experimental-post-training/build.yaml
+++ b/llama_stack/templates/experimental-post-training/build.yaml
@@ -13,6 +13,7 @@ distribution_spec:
     post_training:
     - inline::torchtune
     datasetio:
+    - inline::localfs
     - remote::huggingface
     telemetry:
     - inline::meta-reference
@@ -22,4 +23,6 @@ distribution_spec:
     - inline::llama-guard
     memory:
     - inline::faiss
+    tool_runtime:
+    - remote::brave-search
 image_type: conda
diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml
index a654c375e..4a7bb5c47 100644
--- a/llama_stack/templates/experimental-post-training/run.yaml
+++ b/llama_stack/templates/experimental-post-training/run.yaml
@@ -12,6 +12,7 @@ apis:
 - scoring
 - telemetry
 - post_training
+- tool_runtime
 providers:
   inference:
   - provider_id: meta-reference-inference
@@ -32,6 +33,9 @@ providers:
   - provider_id: huggingface-0
     provider_type: remote::huggingface
     config: {}
+  - provider_id: localfs
+    provider_type: inline::localfs
+    config: {}
   telemetry:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -60,6 +64,13 @@ providers:
         type: sqlite
         namespace: null
         db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+
 metadata_store:
   namespace: null
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
 models: []
 shields: []
 memory_banks: []
-datasets:
-  - dataset_id: alpaca
-    provider_id: huggingface-0
-    url:
-      uri: https://huggingface.co/datasets/tatsu-lab/alpaca
-    metadata:
-      path: tatsu-lab/alpaca
-      name:
-      split: train
-    dataset_schema:
-      instruction:
-        type: string
-      input:
-        type: string
-      output:
-        type: string
-      text:
-        type: string
+datasets: []
 scoring_fns: []
 eval_tasks: []
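
Usage sketch (not part of the patch): a minimal illustration of the two data_format values wired up above, assuming the row contents and the "my-dialog-dataset" id are placeholders and that DataConfig is exported from llama_stack.apis.post_training alongside DatasetFormat. An "instruct" row pairs a single user message in chat_completion_input with an expected_answer; a "dialog" row carries the full message list, assistant turns included. The new format adapters map these rows onto the shapes expected by torchtune's InputOutputToMessages and ShareGPTToMessages transforms.

    from llama_stack.apis.post_training import DataConfig, DatasetFormat
    from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapter import (
        llama_stack_chat_to_torchtune_chat,
        llama_stack_instruct_to_torchtune_instruct,
    )

    # "instruct" format: one user message plus the expected answer (hypothetical row).
    instruct_row = {
        "chat_completion_input": '[{"role": "user", "content": "What is 2 + 2?"}]',
        "expected_answer": "4",
    }
    print(llama_stack_instruct_to_torchtune_instruct(instruct_row))
    # {'input': 'What is 2 + 2?', 'output': '4'}

    # "dialog" format: the whole conversation, including assistant turns (hypothetical row).
    dialog_row = {
        "dialog": '[{"role": "user", "content": "Hi"}, '
        '{"role": "assistant", "content": "Hello!"}]',
    }
    print(llama_stack_chat_to_torchtune_chat(dialog_row))
    # {'conversations': [{'from': 'human', 'value': 'Hi'}, {'from': 'gpt', 'value': 'Hello!'}]}

    # Selecting the format in the training request; "my-dialog-dataset" is a placeholder id.
    data_config = DataConfig(
        dataset_id="my-dialog-dataset",
        batch_size=1,
        shuffle=True,
        data_format=DatasetFormat.dialog,
        train_on_input=False,
    )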