implemented data loading, preprocessing, and docstrings for FullPrecisionFineTuning

Signed-off-by: James Kunstle <jkunstle@redhat.com>
James Kunstle 2025-03-13 01:21:27 -07:00
parent ddea2aa74f
commit 06465441f2
2 changed files with 134 additions and 26 deletions

@@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from asyncio.subprocess import Process
from asyncio import subprocess
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
@@ -40,7 +40,7 @@ class TuningJob(pydantic.BaseModel):
scheduled_at: datetime | None = None
completed_at: datetime | None = None
background_proc_pid: Process | None = None
subproc_ref: subprocess.Process | None = None
class HFilabPostTrainingImpl:
@@ -73,6 +73,10 @@ class HFilabPostTrainingImpl:
if self.current_job is not None:
self.current_job.status.append(new_status)
def __set_subproc_ref_callback(self, subproc_ref: subprocess.Process):
if self.current_job is not None:
self.current_job.subproc_ref = subproc_ref
async def supervised_fine_tune(
self,
job_uuid: str,
@@ -100,6 +104,7 @@ class HFilabPostTrainingImpl:
datasetsio_api=self.datasetio_api,
)
# This is not a reliable or tidy way to implement the behavior that we want.
tasks = BackgroundTasks()
tasks.add_task(
recipe.load_dataset_from_datasetsio, # asynchronous request
@@ -114,6 +119,7 @@ class HFilabPostTrainingImpl:
tasks.add_task(
recipe.train, # asynchronous request
set_status_callback=self.__set_status_callback,
set_subproc_ref_callback=self.__set_subproc_ref_callback,
)
self.current_job = TuningJob(job_uuid=job_uuid, status=[JobStatus.scheduled])
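The comment above notes that this BackgroundTasks approach is not reliable or tidy; as a hedged sketch (an assumption, not part of this commit), one tidier alternative would be to chain the steps in a single asyncio task so ordering and error propagation are explicit:

import asyncio

async def _run_job(recipe, set_status, set_subproc_ref):
    # Hypothetical helper mirroring the task order registered above.
    await recipe.load_dataset_from_datasetsio()
    # ... preflight() and setup() would run here, in order ...
    await recipe.train(
        set_status_callback=set_status,
        set_subproc_ref_callback=set_subproc_ref,
    )

# Inside supervised_fine_tune(), instead of BackgroundTasks:
# self._job_task = asyncio.create_task(
#     _run_job(recipe, self.__set_status_callback, self.__set_subproc_ref_callback)
# )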

@@ -1,8 +1,11 @@
import asyncio
import tempfile
import typing
from asyncio import subprocess
from pathlib import Path
from typing import Callable
import datasets
import transformers
from transformers.configuration_utils import PretrainedConfig
from transformers.tokenization_utils import PreTrainedTokenizer
@@ -15,13 +18,20 @@ from llama_stack.apis.post_training.post_training import AlgorithmConfig, Traini
STORAGE_SUBDIRS = ["checkpoints", "data", "logs", "hf_cache"]
VALIDATED_MODEL_ARCHS = ["LlamaForCausalLM", "GraniteForCausalLM"]
TMP_DATA_FILE_NAME = "data.jsonl"
SomePretrainedTokenizer: typing.TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast
# TODO: set HF storage in HF utilities to this object's storage so that we can clean it up automatically
# TODO: set up logging utils, replace print statements with logging with the right level
class FullPrecisionFineTuning:
# TODO: set HF storage in HF utilities to this object's storage so that we can clean it up automatically
# TODO: set up logging utils, replace print statements with logging with the right level
"""Implement full-precision (bfloat16) training.
Uses subprocesses to launch `torchrun` for SPMD model tuning.
"""
def __init__(
self,
model: str,
@@ -64,6 +74,14 @@ class FullPrecisionFineTuning:
@staticmethod
def check_model_arch_validated(model_config: PretrainedConfig) -> bool:
"""Check whether input model architecture from config is among the pre-validated architectures.
Args:
model_config (PretrainedConfig): input model config object
Returns:
bool: whether the model architecture is known to work with this training implementation.
"""
if model_config.architectures is None:
return False
@@ -74,6 +92,11 @@ class FullPrecisionFineTuning:
return False
def __try_load_config(self) -> PretrainedConfig:
"""Attempt to load model config via model's name or path.
Returns:
PretrainedConfig: model config associated with model.
"""
try:
model_config: PretrainedConfig = transformers.AutoConfig.from_pretrained(self.model_name_or_path)
except OSError:
@@ -85,6 +108,11 @@ class FullPrecisionFineTuning:
return model_config
def __try_load_tokenizer(self) -> SomePretrainedTokenizer:
"""Attempt to load tokenizer via model's name or path.
Returns:
SomePretrainedTokenizer: tokenizer associated with input model name.
"""
try:
tokenizer: SomePretrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
self.model_name_or_path, use_fast=True
@@ -99,6 +127,14 @@ class FullPrecisionFineTuning:
@staticmethod
def check_tokenizer_has_chat_template(tokenizer: SomePretrainedTokenizer) -> bool:
"""Checks for existence of chat template on tokenizer object.
Args:
tokenizer (SomePretrainedTokenizer): Model tokenizer
Returns:
bool: Whether 'chat_template' instance member exists and is not None.
"""
if not hasattr(tokenizer, "chat_template"):
return False
@@ -108,25 +144,36 @@ class FullPrecisionFineTuning:
return True
async def load_dataset_from_datasetsio(self):
"""Loads all dataset rows from datasetio API. Sets 'loaded_dataset' in object."""
dataset = await self.datasetio_api.get_rows_paginated(
dataset_id=self.training_config.data_config.dataset_id, rows_in_page=-1
)
self.loaded_dataset = dataset.rows
def preflight(self, set_status_callback: Callable[[JobStatus], None]):
"""
A synchronous "preflight" operation from the recipe runs and does the following checks:
1. (future) validates that the host has access to sufficient hardware. For now, assume that an administrator has "cleared" the deployment for any requests it could get. In the future, check:
"""Set of checks that should run before any heavier-weight preprocessing runs to validate starting state.
Checks the following:
1. Model config is available from Huggingface by the model's name.
2. Model's architecture (from config) is among "validated" architectures.
3. Model's tokenizer can be downloaded from Huggingface by the model's name.
4. Tokenizer has a 'chat_template' available (we don't currently support BYO chat template).
5. A single data sample can successfully be rendered without raising an error.
In the future, it'd be great for this method to also do some system-checking further up the call stack, like:
1. Cards exist
2. Cards have enough memory
3. Cards are idle
4. Cards have the required silicon features (bfloat16 tensor cores, FlashAttention support)
5. Cards are functional
2. Validates that model is available from HF.
3. Validates that model is a verified architecture (warns user if not).
4. Validates that model's tokenizer exists.
5. Validates that the model's tokenizer has a chat template.
6. Validates that the model's chat template can render a sample from the dataset.
Args:
set_status_callback (Callable[[JobStatus], None]): Sets job status in calling 'Impl' class' ref to this job.
Raises:
RuntimeError: If tokenizer doesn't have chat template available.
OSError: Can be raised via this function if config or tokenizer not available via model's name.
"""
model_config = self.__try_load_config()
@@ -143,7 +190,7 @@ class FullPrecisionFineTuning:
)
try:
rendered_sample = model_tokenizer.apply_chat_template(self.loaded_dataset[0]["messages"])
_ = model_tokenizer.apply_chat_template(self.loaded_dataset[0]["messages"])
except Exception:
# catching / raising bare exception because 'apply_chat_template' can raise ValueError or TypeError; want to report the same thing regardless.
print(
@@ -154,20 +201,75 @@ class FullPrecisionFineTuning:
# Success! Preflight checks haven't caught any immediate problems.
set_status_callback(JobStatus.scheduled)
def setup(self):
"""
A synchronous data preprocessing operation that runs in a kernel-scheduled background thread and does the following:
1. Requests all rows of data from datasetsio API
2. Ports data into a `huggingface.datasets` object
3. Instantiates the model tokenizer
4. `dataset.map`'s the input data into chat template format
5. generates labels, masks for each sample
6. writes dataset to temporary storage
"""
pass
@staticmethod
def __tokenize_and_generate_labels_and_mask(
tokenizer: SomePretrainedTokenizer,
sample: list[dict[typing.Any, typing.Any]], # TODO: type dict correctly.
):
"""Helper method for preparing a single chat sample for model training.
async def train(self, set_status_callback: Callable[[JobStatus], None]):
Assumed (but not required) to have been called from a `dataset.map()` call.
Tokenizes sample using `tokenizer.apply_chat_template()` and uses that output
for the associated labels.
Creates 'attention_mask' and 'loss_mask' of ones (doesn't mask out non-assistant messages).
Args:
tokenizer (SomePretrainedTokenizer): Tokenizer associated with model
sample (list[dict[typing.Any, typing.Any]]): Input OpenAI chat conversation dataset sample.
Returns:
dict[str, list[int]]: With keys 'input_ids', 'labels', 'attention_mask', and 'loss_mask'.
"""
An asynchronous instance method that creates and watches a `torchrun` subprocess that's training the input model.
input_ids = tokenizer.apply_chat_template(conversation=sample, tokenize=True)
input_ids = typing.cast(
list[int], input_ids
) # I know what the output will be, and this makes the typing system happy.
labels = input_ids[:]
attention_mask = [1] * len(labels) # == [1 for _ in range(len(labels))]
loss_mask = [1] * len(labels)
return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask, "loss_mask": loss_mask}
def __setup_data(self):
"""Helper method for loading and tokenizing dataset from .get_rows_paginated() API.
Tokenizes data using special tokens and chat template from model tokenizer.
Doesn't do specialized sample preprocessing re: chat message types for loss calculation.
Returns:
datasets.Dataset: Dataset object w/ data prepared for training.
"""
dataset = datasets.Dataset.from_list(self.loaded_dataset)
model_tok = self.__try_load_tokenizer()
# NOTE: not implementing as batched for the moment; need to know how batching impacts memory usage on machine.
dataset = dataset.map(lambda x: self.__tokenize_and_generate_labels_and_mask(tokenizer=model_tok, sample=x))
return dataset
def setup(self):
"""Data preprocessing to prepare for model training. Writes data to local cache dir to be read by SPMD training processes later."""
dataset = self.__setup_data()
dataset.to_json(path_or_buf=self.data_dir / TMP_DATA_FILE_NAME)
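As a sketch of the handoff described in the docstring above, and assuming the standard `datasets` JSON loader on the worker side, the SPMD training processes could read the cached file back like this (the path is hypothetical):

from pathlib import Path

import datasets

# Hypothetical worker-side path; setup() writes to <data_dir>/data.jsonl (TMP_DATA_FILE_NAME).
data_file = Path("/path/to/job/storage/data") / "data.jsonl"
train_ds = datasets.load_dataset("json", data_files=str(data_file), split="train")
# Each row carries the input_ids, labels, attention_mask, and loss_mask produced by __setup_data().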
async def train(
self,
set_status_callback: Callable[[JobStatus], None],
set_subproc_ref_callback: Callable[[subprocess.Process], None],
):
"""Subprocesses `torchrun` as async function and updates state of calling `Impl` class.
Args:
set_status_callback (Callable[[JobStatus], None]): Sets job status in calling 'Impl' class' ref to this job.
set_subproc_ref_callback (Callable[[subprocess.Process], None]): Sets subprocess reference in 'Impl' class' ref to this job
"""
# Placeholder command until the real `torchrun` launch is wired up; use a shell
# so the chained `echo` / `sleep` statements are interpreted as intended.
training_subproc = await asyncio.create_subprocess_shell(
"echo 'yay Im running in a subprocess: $$'; sleep 30; echo 'exiting process $$'"
)
set_subproc_ref_callback(training_subproc)
await training_subproc.wait()
set_status_callback(JobStatus.completed)
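Because the command above is still an echo/sleep placeholder, here is a hedged sketch of what the eventual `torchrun` launch might look like; the script name, flag values, and the `self.checkpoint_dir` attribute are illustrative assumptions, not part of this commit:

# Hypothetical torchrun invocation; "finetune_spmd.py" and the flags are illustrative only.
training_subproc = await asyncio.create_subprocess_exec(
    "torchrun",
    "--nnodes=1",
    "--nproc-per-node=8",
    str(Path(__file__).parent / "finetune_spmd.py"),
    "--model-name-or-path", self.model_name_or_path,
    "--data-path", str(self.data_dir / TMP_DATA_FILE_NAME),
    "--output-dir", str(self.checkpoint_dir),  # assumed attribute backed by the "checkpoints" subdir
    stdout=asyncio.subprocess.PIPE,
    stderr=asyncio.subprocess.STDOUT,
)
# The subprocess reference and completion status would then be reported exactly as above.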