Mirror of https://github.com/meta-llama/llama-stack.git

Commit 06465441f2 (parent ddea2aa74f)

implemented data loading, preprocessing, and docstrings for FullPrecisionFineTuning

Signed-off-by: James Kunstle <jkunstle@redhat.com>

2 changed files with 134 additions and 26 deletions
@@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from asyncio.subprocess import Process
from asyncio import subprocess
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
@@ -40,7 +40,7 @@ class TuningJob(pydantic.BaseModel):
    scheduled_at: datetime | None = None
    completed_at: datetime | None = None

    background_proc_pid: Process | None = None
    subproc_ref: subprocess.Process | None = None


class HFilabPostTrainingImpl:
@@ -73,6 +73,10 @@ class HFilabPostTrainingImpl:
        if self.current_job is not None:
            self.current_job.status.append(new_status)

    def __set_subproc_ref_callback(self, subproc_ref: subprocess.Process):
        if self.current_job is not None:
            self.current_job.subproc_ref = subproc_ref

    async def supervised_fine_tune(
        self,
        job_uuid: str,
@@ -100,6 +104,7 @@ class HFilabPostTrainingImpl:
            datasetsio_api=self.datasetio_api,
        )

        # This is not a reliable or tidy way to implement the behavior that we want.
        tasks = BackgroundTasks()
        tasks.add_task(
            recipe.load_dataset_from_datasetsio,  # asynchronous request
@@ -114,6 +119,7 @@ class HFilabPostTrainingImpl:
        tasks.add_task(
            recipe.train,  # asynchronous request
            set_status_callback=self.__set_status_callback,
            set_subproc_ref_callback=self.__set_subproc_ref_callback,
        )

        self.current_job = TuningJob(job_uuid=job_uuid, status=[JobStatus.scheduled])
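The scheduling above leans on BackgroundTasks executing its queued tasks one after another, in insertion order, so the dataset fetch completes before preprocessing and training begin. A minimal sketch of that ordering, assuming Starlette/FastAPI BackgroundTasks semantics and with toy coroutines standing in for the recipe methods:

import asyncio

from starlette.background import BackgroundTasks  # fastapi.BackgroundTasks subclasses this


async def load_rows():
    # Stand-in for recipe.load_dataset_from_datasetsio.
    print("fetching dataset rows")


async def train(set_status_callback):
    # Stand-in for recipe.train.
    set_status_callback("completed")


async def main():
    tasks = BackgroundTasks()
    tasks.add_task(load_rows)  # queued first
    tasks.add_task(train, set_status_callback=print)  # runs only after load_rows returns
    await tasks()  # Starlette awaits each queued task sequentially, in order


asyncio.run(main())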
Second changed file:

@@ -1,8 +1,11 @@
import asyncio
import tempfile
import typing
from asyncio import subprocess
from pathlib import Path
from typing import Callable

import datasets
import transformers
from transformers.configuration_utils import PretrainedConfig
from transformers.tokenization_utils import PreTrainedTokenizer
@@ -15,13 +18,20 @@ from llama_stack.apis.post_training.post_training import AlgorithmConfig, Traini

STORAGE_SUBDIRS = ["checkpoints", "data", "logs", "hf_cache"]
VALIDATED_MODEL_ARCHS = ["LlamaForCausalLM", "GraniteForCausalLM"]
TMP_DATA_FILE_NAME = "data.jsonl"

SomePretrainedTokenizer: typing.TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast

# TODO: set HF storage in HF utilities to this object's storage so that we can clean it up automatically
# TODO: set up logging utils, replace print statements with logging with the right level


class FullPrecisionFineTuning:
    # TODO: set HF storage in HF utilities to this object's storage so that we can clean it up automatically
    # TODO: set up logging utils, replace print statements with logging with the right level
    """Implement full-precision (bfloat16) training.

    Uses subprocessing to launch `torchrun` processes for model tuning, SPMD.
    """

    def __init__(
        self,
        model: str,
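STORAGE_SUBDIRS suggests a per-job working directory split into checkpoint, data, log, and HF-cache areas. A hypothetical sketch of how such a layout could be materialized with pathlib; the make_job_dirs helper and the base path are illustrative, not part of this commit:

from pathlib import Path

STORAGE_SUBDIRS = ["checkpoints", "data", "logs", "hf_cache"]


def make_job_dirs(base_dir: Path) -> dict[str, Path]:
    # Create one subdirectory per storage role and return a name -> path mapping.
    dirs = {name: base_dir / name for name in STORAGE_SUBDIRS}
    for path in dirs.values():
        path.mkdir(parents=True, exist_ok=True)
    return dirs


job_dirs = make_job_dirs(Path("/tmp/full-precision-ft-job"))  # illustrative location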
@@ -64,6 +74,14 @@ class FullPrecisionFineTuning:

    @staticmethod
    def check_model_arch_validated(model_config: PretrainedConfig) -> bool:
        """Check whether input model architecture from config is among the pre-validated architectures.

        Args:
            model_config (PretrainedConfig): input model config object

        Returns:
            bool: whether the model architecture is known to work with this training implementation.
        """
        if model_config.architectures is None:
            return False

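The hunk above only shows the None guard; the rest of the check is elided by the diff context. A plausible completion for illustration, assuming that any declared architecture matching the validated list is sufficient (this is not the committed body):

VALIDATED_MODEL_ARCHS = ["LlamaForCausalLM", "GraniteForCausalLM"]


def check_model_arch_validated(model_config) -> bool:
    # Hypothetical reconstruction of the elided logic.
    if model_config.architectures is None:
        return False
    # A config can declare several architectures; accept the model if any one is validated.
    return any(arch in VALIDATED_MODEL_ARCHS for arch in model_config.architectures)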
@@ -74,6 +92,11 @@ class FullPrecisionFineTuning:
        return False

    def __try_load_config(self) -> PretrainedConfig:
        """Attempt to load model config via model's name or path.

        Returns:
            PretrainedConfig: model config associated with model.
        """
        try:
            model_config: PretrainedConfig = transformers.AutoConfig.from_pretrained(self.model_name_or_path)
        except OSError:
@@ -85,6 +108,11 @@ class FullPrecisionFineTuning:
        return model_config

    def __try_load_tokenizer(self) -> SomePretrainedTokenizer:
        """Attempt to load tokenizer via model's name or path.

        Returns:
            SomePretrainedTokenizer: tokenizer associated with input model name.
        """
        try:
            tokenizer: SomePretrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
                self.model_name_or_path, use_fast=True
@@ -99,6 +127,14 @@ class FullPrecisionFineTuning:

    @staticmethod
    def check_tokenizer_has_chat_template(tokenizer: SomePretrainedTokenizer) -> bool:
        """Checks for existence of chat template on tokenizer object.

        Args:
            tokenizer (SomePretrainedTokenizer): Model tokenizer

        Returns:
            bool: Whether 'chat_template' instance member exists and is not None.
        """
        if not hasattr(tokenizer, "chat_template"):
            return False

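As a concrete illustration of this check: the chat template is just an optional attribute on the tokenizer. The model name below is an example, not one taken from the commit:

import transformers

tok = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", use_fast=True)
print("chat template available:", getattr(tok, "chat_template", None) is not None)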
@@ -108,25 +144,36 @@ class FullPrecisionFineTuning:
        return True

    async def load_dataset_from_datasetsio(self):
        """Loads all dataset rows from datasetio API. Sets 'loaded_dataset' in object."""

        dataset = await self.datasetio_api.get_rows_paginated(
            dataset_id=self.training_config.data_config.dataset_id, rows_in_page=-1
        )
        self.loaded_dataset = dataset.rows

    def preflight(self, set_status_callback: Callable[[JobStatus], None]):
        """
        A synchronous "preflight" operation from the recipe runs and does the following checks:
        1. (future) validates that the host has access to sufficient hardware. For now, assume that an administrator has "cleared" the deployment for any requests it could get. In the future, check:
        """Set of checks that should run before any heavier-weight preprocessing runs to validate starting state.

        Checks the following:
        1. Model config is available from Huggingface by the model's name.
        2. Model's architecture (from config) is among "validated" architectures.
        3. Model's tokenizer can be downloaded from Huggingface by the model's name.
        4. Tokenizer has a 'chat_template' available (we don't currently support BYO chat template).
        5. A single data sample can successfully be rendered without raising an error.

        In the future, it'd be great for this method to also do some system-checking further up the call stack, like:
        1. Cards exist
        2. Cards have enough memory
        3. Cards are idle
        4. Cards have silicon for functions (bfloat16 tensor cores, support FA)
        5. Cards are functional
        2. Validates that model is available from HF.
        3. Validates that model is a verified architecture (warns user if not).
        4. Validates that model's tokenizer exists.
        5. Validates that the model's tokenizer has a chat template.
        6. Validates that the model's chat template can render a sample from the dataset.

        Args:
            set_status_callback (Callable[[JobStatus], None]): Sets job status in calling 'Impl' class' ref to this job.

        Raises:
            RuntimeError: If tokenizer doesn't have chat template available.
            OSError: Can be raised via this function if config or tokenizer not available via model's name.
        """

        model_config = self.__try_load_config()
@@ -143,7 +190,7 @@ class FullPrecisionFineTuning:
        )

        try:
            rendered_sample = model_tokenizer.apply_chat_template(self.loaded_dataset[0]["messages"])
            _ = model_tokenizer.apply_chat_template(self.loaded_dataset[0]["messages"])
        except Exception:
            # catching / raising bare exception because 'apply_chat_template' can raise ValueError or TypeError; want to report the same thing regardless.
            print(
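The rows fetched from datasetio are expected to carry OpenAI-style chat messages under a "messages" key, as the indexing above shows. A small illustration of the render step with an inline sample and an example model name standing in for real dataset rows:

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", use_fast=True)

sample = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]
}

# preflight makes the same call without tokenize=False; either way it raises if the sample can't be rendered.
rendered = tokenizer.apply_chat_template(sample["messages"], tokenize=False)
print(rendered)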
@@ -154,20 +201,75 @@ class FullPrecisionFineTuning:
        # Success! Preflight checks haven't caught any immediate problems.
        set_status_callback(JobStatus.scheduled)

    def setup(self):
        """
        A synchronous data preprocessing operation that runs in a kernel-scheduled background thread and does the following:
        1. Requests all rows of data from datasetsio API
        2. Ports data into a `huggingface.datasets` object
        3. Instantiates the model tokenizer
        4. `dataset.map`'s the input data into chat template format
        5. generates labels, masks for each sample
        6. writes dataset to temporary storage
        """
        pass
    @staticmethod
    def __tokenize_and_generate_labels_and_mask(
        tokenizer: SomePretrainedTokenizer,
        sample: list[dict[typing.Any, typing.Any]],  # TODO: type dict correctly.
    ):
        """Helper method for preparing a single chat sample for model training.

    async def train(self, set_status_callback: Callable[[JobStatus], None]):
        Assumed (but not required) to have been called from a `dataset.map()` call.
        Tokenizes sample using `tokenizer.apply_chat_template()` and uses that output
        for the associated labels.

        Creates 'attention_mask' and 'loss_mask' of ones (doesn't mask out non-assistant messages).

        Args:
            tokenizer (SomePretrainedTokenizer): Tokenizer associated with model
            sample (list[dict[typing.Any, typing.Any]]): Input OpenAI chat conversation dataset sample.

        Returns:
            dict[str, list[int]]: Of shape {input_ids, labels, attention_mask, loss_mask}
        """
        An asynchronous instance method that creates and watches a `torchrun` subprocess that's training the input model.

        input_ids = tokenizer.apply_chat_template(conversation=sample, tokenize=True)
        input_ids = typing.cast(
            list[int], input_ids
        )  # I know what the output will be, and this makes the typing system happy.
        labels = input_ids[:]
        attention_mask = [1] * len(labels)  # == [1 for _ in range(len(labels))]
        loss_mask = [1] * len(labels)

        return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask, "loss_mask": loss_mask}

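For reference, a hedged sketch of the features this helper produces for one conversation, built with the same calls (example model name, tiny inline sample; the real helper is invoked through dataset.map):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", use_fast=True)

conversation = [
    {"role": "user", "content": "Say hi."},
    {"role": "assistant", "content": "Hi!"},
]

input_ids = tokenizer.apply_chat_template(conversation=conversation, tokenize=True)
features = {
    "input_ids": input_ids,
    "labels": list(input_ids),  # labels start as a straight copy of the inputs
    "attention_mask": [1] * len(input_ids),
    "loss_mask": [1] * len(input_ids),  # all-ones: non-assistant tokens are not masked out yet
}
assert all(len(column) == len(input_ids) for column in features.values())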
    def __setup_data(self):
        """Helper method for loading and tokenizing dataset from .get_rows_paginated() API.

        Tokenizes data using special tokens and chat template from model tokenizer.
        Doesn't do specialized sample preprocessing re: chat message types for loss calculation.

        Returns:
            datasets.Dataset: Dataset object w/ data prepared for training.
        """

        dataset = datasets.Dataset.from_list(self.loaded_dataset)
        model_tok = self.__try_load_tokenizer()

        # NOTE: not implementing as batched for the moment; need to know how batching impacts memory usage on machine.
        dataset = dataset.map(lambda x: self.__tokenize_and_generate_labels_and_mask(tokenizer=model_tok, sample=x))
        return dataset

    def setup(self):
        """Data preprocessing to prepare for model training. Writes data to local cache dir to be read by SPMD training processes later."""

        dataset = self.__setup_data()
        dataset.to_json(path_or_buf=self.data_dir / TMP_DATA_FILE_NAME)

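setup() hands the preprocessed rows to the training processes through a JSON-lines file. One way a torchrun worker could read that file back, assuming the same datasets library on the other side; the path is a stand-in for data_dir / TMP_DATA_FILE_NAME:

import datasets

ds = datasets.load_dataset(
    "json",
    data_files="/path/to/job/data/data.jsonl",  # illustrative; mirrors data_dir / TMP_DATA_FILE_NAME
    split="train",
)
print(ds.column_names)  # expect input_ids, labels, attention_mask, loss_mask plus the original columns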
    async def train(
        self,
        set_status_callback: Callable[[JobStatus], None],
        set_subproc_ref_callback: Callable[[subprocess.Process], None],
    ):
        """Subprocesses `torchrun` as async function and updates state of calling `Impl` class.

        Args:
            set_status_callback (Callable[[JobStatus], None]): Sets job status in calling 'Impl' class' ref to this job.
            set_subproc_ref_callback (Callable[[subprocess.Process], None]): Sets subprocess reference in 'Impl' class' ref to this job
        """

        training_subproc = await asyncio.create_subprocess_exec(
            "echo 'yay Im running in a subprocess: $$'; sleep 30; echo 'exiting process $$'"
        )
        set_subproc_ref_callback(training_subproc)
        await training_subproc.wait()
        set_status_callback(JobStatus.completed)
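The command above is a placeholder shell one-liner; asyncio.create_subprocess_exec expects a program and its arguments as separate strings, so a string like that would normally be launched with asyncio.create_subprocess_shell instead. A hedged sketch of what the eventual torchrun launch might look like; the script name, arguments, and process count are assumptions, not from this commit:

import asyncio
from asyncio import subprocess


async def launch_training() -> subprocess.Process:
    # Program-plus-args form, which is what create_subprocess_exec expects.
    return await asyncio.create_subprocess_exec(
        "torchrun",
        "--nproc_per_node", "8",  # assumed number of SPMD workers
        "training_entrypoint.py",  # hypothetical training script
        "--data-path", "/path/to/job/data/data.jsonl",  # illustrative path
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )


async def main():
    proc = await launch_training()
    await proc.wait()


asyncio.run(main())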