From 06465441f23a3a6959ef8ec80013a6557744d703 Mon Sep 17 00:00:00 2001
From: James Kunstle
Date: Thu, 13 Mar 2025 01:21:27 -0700
Subject: [PATCH] implemented data loading, preprocessing, and docstrings for
 FullPrecisionFineTuning

Signed-off-by: James Kunstle
---
 .../huggingface_ilab/post_training.py          |  10 +-
 .../fullprecision_finetuning_multi_device.py   | 150 +++++++++++++++---
 2 files changed, 134 insertions(+), 26 deletions(-)

diff --git a/llama_stack/providers/inline/post_training/huggingface_ilab/post_training.py b/llama_stack/providers/inline/post_training/huggingface_ilab/post_training.py
index 00045d7c1..6f8f39e4b 100644
--- a/llama_stack/providers/inline/post_training/huggingface_ilab/post_training.py
+++ b/llama_stack/providers/inline/post_training/huggingface_ilab/post_training.py
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from asyncio.subprocess import Process
+from asyncio import subprocess
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Dict, Optional
@@ -40,7 +40,7 @@ class TuningJob(pydantic.BaseModel):
     scheduled_at: datetime | None = None
     completed_at: datetime | None = None
 
-    background_proc_pid: Process | None = None
+    subproc_ref: subprocess.Process | None = None
 
 
 class HFilabPostTrainingImpl:
@@ -73,6 +73,10 @@ class HFilabPostTrainingImpl:
         if self.current_job is not None:
             self.current_job.status.append(new_status)
 
+    def __set_subproc_ref_callback(self, subproc_ref: subprocess.Process):
+        if self.current_job is not None:
+            self.current_job.subproc_ref = subproc_ref
+
     async def supervised_fine_tune(
         self,
         job_uuid: str,
@@ -100,6 +104,7 @@ class HFilabPostTrainingImpl:
             datasetsio_api=self.datasetio_api,
         )
 
+        # This is not a reliable or tidy way to implement the behavior that we want.
         tasks = BackgroundTasks()
         tasks.add_task(
             recipe.load_dataset_from_datasetsio,  # asynchronous request
@@ -114,6 +119,7 @@ class HFilabPostTrainingImpl:
         tasks.add_task(
             recipe.train,  # asynchronous request
             set_status_callback=self.__set_status_callback,
+            set_subproc_ref_callback=self.__set_subproc_ref_callback,
         )
 
         self.current_job = TuningJob(job_uuid=job_uuid, status=[JobStatus.scheduled])
diff --git a/llama_stack/providers/inline/post_training/huggingface_ilab/recipes/fullprecision_finetuning_multi_device.py b/llama_stack/providers/inline/post_training/huggingface_ilab/recipes/fullprecision_finetuning_multi_device.py
index 665236b0c..c810ae5b7 100644
--- a/llama_stack/providers/inline/post_training/huggingface_ilab/recipes/fullprecision_finetuning_multi_device.py
+++ b/llama_stack/providers/inline/post_training/huggingface_ilab/recipes/fullprecision_finetuning_multi_device.py
@@ -1,8 +1,11 @@
+import asyncio
 import tempfile
 import typing
+from asyncio import subprocess
 from pathlib import Path
 from typing import Callable
 
+import datasets
 import transformers
 from transformers.configuration_utils import PretrainedConfig
 from transformers.tokenization_utils import PreTrainedTokenizer
@@ -15,13 +18,20 @@ from llama_stack.apis.post_training.post_training import AlgorithmConfig, Traini
 STORAGE_SUBDIRS = ["checkpoints", "data", "logs", "hf_cache"]
 VALIDATED_MODEL_ARCHS = ["LlamaForCausalLM", "GraniteForCausalLM"]
+TMP_DATA_FILE_NAME = "data.jsonl"
 
 SomePretrainedTokenizer: typing.TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast
 
+# TODO: set HF storage in HF utilities to this object's storage so that we can clean it up automatically
+# TODO: set up logging utils, replace print statements with logging with the right level
+
 
 class FullPrecisionFineTuning:
-    # TODO: set HF storage in HF utilities to this object's storage so that we can clean it up automatically
-    # TODO: set up logging utils, replace print statements with logging with the right level
+    """Implement full-precision (bfloat16) supervised fine-tuning.
+
+    Launches `torchrun` in a subprocess to run model tuning as SPMD processes.
+    """
+
     def __init__(
         self,
         model: str,
@@ -64,6 +74,14 @@ class FullPrecisionFineTuning:
 
     @staticmethod
     def check_model_arch_validated(model_config: PretrainedConfig) -> bool:
+        """Check whether the model architecture named in the config is among the pre-validated architectures.
+
+        Args:
+            model_config (PretrainedConfig): input model config object
+
+        Returns:
+            bool: whether the model architecture is known to work with this training implementation.
+        """
         if model_config.architectures is None:
             return False
 
@@ -74,6 +92,11 @@ class FullPrecisionFineTuning:
         return False
 
     def __try_load_config(self) -> PretrainedConfig:
+        """Attempt to load the model config by the model's name or path.
+
+        Returns:
+            PretrainedConfig: config associated with the model.
+        """
         try:
             model_config: PretrainedConfig = transformers.AutoConfig.from_pretrained(self.model_name_or_path)
         except OSError:
@@ -85,6 +108,11 @@ class FullPrecisionFineTuning:
         return model_config
 
     def __try_load_tokenizer(self) -> SomePretrainedTokenizer:
+        """Attempt to load the tokenizer by the model's name or path.
+
+        Returns:
+            SomePretrainedTokenizer: tokenizer associated with the input model name.
+ """ try: tokenizer: SomePretrainedTokenizer = transformers.AutoTokenizer.from_pretrained( self.model_name_or_path, use_fast=True @@ -99,6 +127,14 @@ class FullPrecisionFineTuning: @staticmethod def check_tokenizer_has_chat_template(tokenizer: SomePretrainedTokenizer) -> bool: + """Checks for existence of chat template on tokenizer object. + + Args: + tokenizer (SomePretrainedTokenizer): Model tokenizer + + Returns: + bool: Whether 'chat_template' instance member exists and is not None. + """ if not hasattr(tokenizer, "chat_template"): return False @@ -108,25 +144,36 @@ class FullPrecisionFineTuning: return True async def load_dataset_from_datasetsio(self): + """Loads all dataset rows from datasetio API. Sets 'loaded_dataset' in object.""" + dataset = await self.datasetio_api.get_rows_paginated( dataset_id=self.training_config.data_config.dataset_id, rows_in_page=-1 ) self.loaded_dataset = dataset.rows def preflight(self, set_status_callback: Callable[[JobStatus], None]): - """ - A synchronous "preflight" operation from the recipe runs and does the following checks: - 1. (future) validates that the host has access to sufficient hardware. For now, assume that an administrator has "cleared" the deployment for any requests it could get. In the future, check: + """Set of checks that should run before any heavier-weight preprocessing runs to validate starting state. + + Checks the following: + 1. Model config is available from Huggingface by the model's name. + 2. Model's architecture (from config) is among "validated" architectures. + 3. Model's tokenizer can be downloaded from Huggingface by the model's name. + 4. Tokenizer has a 'chat_template' available (we don't currently support BYO chat template). + 5. A single data sample can successfully be rendered without raising an error. + + In the future, it's be great for this method to also do some system-checking further up the call stack, like: 1. Cards exist 2. Cards have enough memory 3. Cards are idle 4. Cards have silicon for functions (bfloat16 tensor cores, support FA) 5. Cards are functional - 2. Validates that model is available from HF. - 3. Validates that model is a verified architecture (warns user if not). - 4. Validates that model's tokenizer exists. - 5. Validates that the model's tokenizer has a chat template. - 6. Validates that the model's chat template can render a sample from the dataset. + + Args: + set_status_callback (Callable[[JobStatus], None]): Sets job status in calling 'Impl' class' ref to this job. + + Raises: + RuntimeError: If tokenizer doesn't have chat template available. + OSError: Can be raised via this function if config or tokenizer not available via model's name. """ model_config = self.__try_load_config() @@ -143,7 +190,7 @@ class FullPrecisionFineTuning: ) try: - rendered_sample = model_tokenizer.apply_chat_template(self.loaded_dataset[0]["messages"]) + _ = model_tokenizer.apply_chat_template(self.loaded_dataset[0]["messages"]) except Exception: # catching / raising bare exception because 'apply_chat_template' can raise ValueError or TypeError; want to report the same thing regardless. print( @@ -154,20 +201,75 @@ class FullPrecisionFineTuning: # Success! Preflight checks haven't caught any immediate problems. set_status_callback(JobStatus.scheduled) - def setup(self): - """ - A synchronous data preprocessing operation that runs in a kernel-scheduled background thread and does the following: - 1. Requests all rows of data from datasetsio API - 2. Ports data into a `huggingface.datasets` object - 3. 
-        4. `dataset.map`'s the input data into chat template format
-        5. generates labels, masks for each sample
-        6. writes dataset to temporary storage
-        """
-        pass
+    @staticmethod
+    def __tokenize_and_generate_labels_and_mask(
+        tokenizer: SomePretrainedTokenizer,
+        sample: list[dict[typing.Any, typing.Any]],  # TODO: type dict correctly.
+    ):
+        """Helper method for preparing a single chat sample for model training.
 
-    async def train(self, set_status_callback: Callable[[JobStatus], None]):
+        Assumed (but not required) to be called from a `dataset.map()` call.
+        Tokenizes the sample using `tokenizer.apply_chat_template()` and uses that output
+        for the associated labels.
+
+        Creates 'attention_mask' and 'loss_mask' of ones (doesn't mask out non-assistant messages).
+
+        Args:
+            tokenizer (SomePretrainedTokenizer): Tokenizer associated with the model.
+            sample (list[dict[typing.Any, typing.Any]]): Input OpenAI chat conversation dataset sample.
+
+        Returns:
+            dict[str, list[int]]: with keys {input_ids, labels, attention_mask, loss_mask}.
         """
-        An asynchronous instance method that creates and watches a `torchrun` subprocess that's training the input model.
+
+        input_ids = tokenizer.apply_chat_template(conversation=sample, tokenize=True)
+        input_ids = typing.cast(
+            list[int], input_ids
+        )  # I know what the output will be, and this makes the typing system happy.
+        labels = input_ids[:]
+        attention_mask = [1] * len(labels)  # == [1 for _ in range(len(labels))]
+        loss_mask = [1] * len(labels)
+
+        return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask, "loss_mask": loss_mask}
+
+    def __setup_data(self):
+        """Helper method for tokenizing the dataset rows previously fetched via the .get_rows_paginated() API.
+
+        Tokenizes data using the special tokens and chat template from the model tokenizer.
+        Doesn't do specialized per-message-type preprocessing for loss calculation.
+
+        Returns:
+            datasets.Dataset: Dataset object w/ data prepared for training.
         """
+
+        dataset = datasets.Dataset.from_list(self.loaded_dataset)
+        model_tok = self.__try_load_tokenizer()
+
+        # NOTE: not implementing as batched for the moment; need to know how batching impacts memory usage on machine.
+        dataset = dataset.map(lambda x: self.__tokenize_and_generate_labels_and_mask(tokenizer=model_tok, sample=x["messages"]))
+        return dataset
+
+    def setup(self):
+        """Data preprocessing to prepare for model training. Writes data to the local cache dir to be read by the SPMD training processes later."""
+
+        dataset = self.__setup_data()
+        dataset.to_json(path_or_buf=self.data_dir / TMP_DATA_FILE_NAME)
+
+    async def train(
+        self,
+        set_status_callback: Callable[[JobStatus], None],
+        set_subproc_ref_callback: Callable[[subprocess.Process], None],
+    ):
+        """Launches the training subprocess (eventually `torchrun`) asynchronously and updates the state of the calling `Impl` class.
+
+        Args:
+            set_status_callback (Callable[[JobStatus], None]): Sets the job status on the calling 'Impl' class's reference to this job.
+            set_subproc_ref_callback (Callable[[subprocess.Process], None]): Sets the subprocess reference on the calling 'Impl' class's reference to this job.
+        """
+
+        training_subproc = await asyncio.create_subprocess_shell(
+            "echo 'yay Im running in a subprocess: $$'; sleep 30; echo 'exiting process $$'"
+        )  # placeholder command; the real `torchrun` launch is still TODO
+        set_subproc_ref_callback(training_subproc)
+        await training_subproc.wait()
         set_status_callback(JobStatus.completed)
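
A plausible follow-up to the placeholder in `train` (not part of this patch) is swapping the shell command for a real `torchrun` launch. The sketch below is illustrative only: the entrypoint script name `training_entrypoint.py`, its `--data-path`/`--output-dir` flags, and the default process count are assumptions, while `torchrun`'s `--nnodes`/`--nproc-per-node` options and the asyncio subprocess API are standard.

import asyncio
from pathlib import Path


async def launch_torchrun(data_dir: Path, ckpt_dir: Path, nproc_per_node: int = 8) -> asyncio.subprocess.Process:
    """Start `torchrun` as a background subprocess and return a handle to it."""
    return await asyncio.create_subprocess_exec(
        "torchrun",
        "--nnodes=1",
        f"--nproc-per-node={nproc_per_node}",
        "training_entrypoint.py",  # hypothetical SPMD worker script
        f"--data-path={data_dir / 'data.jsonl'}",  # file written by setup()
        f"--output-dir={ckpt_dir}",  # hypothetical worker-script flag
        stdout=asyncio.subprocess.PIPE,  # capture worker logs for later streaming
        stderr=asyncio.subprocess.STDOUT,
    )

Inside `train`, the handle returned here would take the place of `training_subproc`, so `set_subproc_ref_callback` and `wait()` would keep working unchanged.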