Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-06 02:32:40 +00:00
implemented data loading and preflight of FullPrecisionFineTuning
Signed-off-by: James Kunstle <jkunstle@redhat.com>
This commit is contained in: parent 9698c14e07, commit ddea2aa74f
3 changed files with 246 additions and 47 deletions
@@ -3,9 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from asyncio.subprocess import Process
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional

import fastapi
import fastapi.concurrency
import pydantic
from starlette.background import BackgroundTasks
from starlette.responses import JSONResponse

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
@@ -13,16 +21,28 @@ from llama_stack.apis.post_training import (
    DPOAlignmentConfig,
    JobStatus,
    ListPostTrainingJobsResponse,
    LoraFinetuningConfig,
    PostTrainingJob,
    PostTrainingJobArtifactsResponse,
    PostTrainingJobStatusResponse,
    TrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface_ilab.config import HFilabPostTrainingConfig
from llama_stack.providers.inline.post_training.huggingface_ilab.recipes import FullPrecisionFineTuning
from llama_stack.schema_utils import webmethod


class TuningJob(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
    job_uuid: str
    status: list[JobStatus] = []

    created_at: datetime | None = None
    scheduled_at: datetime | None = None
    completed_at: datetime | None = None

    background_proc_pid: Process | None = None


class HFilabPostTrainingImpl:
    def __init__(
        self,
@@ -34,13 +54,25 @@
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets

        # TODO: assume sync job, will need jobs API for async scheduling
        self.jobs = {}
        self.checkpoints_dict = {}
        self.current_job: TuningJob | None = None

    async def shutdown(self):
        pass

    async def can_schedule_new_job(self) -> bool:
        if self.current_job is None:
            return True

        finalized_job_states = [JobStatus.completed.value, JobStatus.failed.value]
        if self.current_job.status in finalized_job_states:
            return True

        return False

    def __set_status_callback(self, new_status: JobStatus):
        if self.current_job is not None:
            self.current_job.status.append(new_status)

    async def supervised_fine_tune(
        self,
        job_uuid: str,
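The single-slot scheduling above hinges on the TuningJob model accumulating a status history through the private callback. A minimal standalone sketch of that bookkeeping (illustrative only, not part of the diff; it assumes the TuningJob model defined above is importable alongside JobStatus, and the "job-1234" id is made up):

    # Illustrative sketch: how __set_status_callback grows a TuningJob's status history.
    from llama_stack.apis.post_training import JobStatus

    job = TuningJob(job_uuid="job-1234")        # starts with an empty status list
    job.status.append(JobStatus.scheduled)      # each callback invocation appends one entry
    job.status.append(JobStatus.in_progress)
    job.status.append(JobStatus.completed)

    assert job.status[-1] == JobStatus.completed  # the latest entry is the job's current state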
@@ -50,53 +82,46 @@
        model: str,
        checkpoint_dir: Optional[str],
        algorithm_config: Optional[AlgorithmConfig],
-    ) -> PostTrainingJob:
-        if job_uuid in self.jobs:
-            raise ValueError(f"Job {job_uuid} already exists")
-
-        post_training_job = PostTrainingJob(job_uuid=job_uuid)
-
-        job_status_response = PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
-            status=JobStatus.scheduled,
-            scheduled_at=datetime.now(),
-        )
-        self.jobs[job_uuid] = job_status_response
-
-        if isinstance(algorithm_config, LoraFinetuningConfig):
-            try:
-                recipe = LoraFinetuningSingleDevice(
-                    self.config,
-                    job_uuid,
-                    training_config,
-                    hyperparam_search_config,
-                    logger_config,
-                    model,
-                    checkpoint_dir,
-                    algorithm_config,
-                    self.datasetio_api,
-                    self.datasets_api,
-                )
-
-                job_status_response.status = JobStatus.in_progress
-                job_status_response.started_at = datetime.now()
-
-                await recipe.setup()
-                resources_allocated, checkpoints = await recipe.train()
-
-                self.checkpoints_dict[job_uuid] = checkpoints
-                job_status_response.resources_allocated = resources_allocated
-                job_status_response.checkpoints = checkpoints
-                job_status_response.status = JobStatus.completed
-                job_status_response.completed_at = datetime.now()
-
-            except Exception:
-                job_status_response.status = JobStatus.failed
-                raise
-        else:
-            raise NotImplementedError()
-
-        return post_training_job
+    ) -> JSONResponse:
+        if not self.can_schedule_new_job():
+            raise fastapi.HTTPException(
+                status_code=503,  # service unavailable, try again later.
+                detail="A tuning job is currently running; this could take a while.",
+                headers={"Retry-After": "3600"},  # 60sec * 60min = 3600 seconds
+            )
+
+        recipe = FullPrecisionFineTuning(
+            model=model,
+            training_config=training_config,
+            logger_config=logger_config,
+            storage_dir=Path(checkpoint_dir) if checkpoint_dir else None,
+            algorithm_config=algorithm_config,
+            datasets_api=self.datasets_api,
+            datasetsio_api=self.datasetio_api,
+        )
+
+        tasks = BackgroundTasks()
+        tasks.add_task(
+            recipe.load_dataset_from_datasetsio,  # asynchronous request
+        )
+        tasks.add_task(
+            recipe.preflight,  # synchronous request
+            set_status_callback=self.__set_status_callback,
+        )
+        tasks.add_task(
+            recipe.setup,  # synchronous request
+        )
+        tasks.add_task(
+            recipe.train,  # asynchronous request
+            set_status_callback=self.__set_status_callback,
+        )
+
+        self.current_job = TuningJob(job_uuid=job_uuid, status=[JobStatus.scheduled])
+        resp_object = PostTrainingJob(job_uuid=job_uuid)
+        return JSONResponse(
+            content=resp_object.model_dump(),
+            background=tasks,
+        )

    async def preference_optimize(
        self,
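The rewritten supervised_fine_tune endpoint leans on Starlette's background-task mechanism: the JSONResponse is sent to the client first, and the queued callables (dataset load, preflight, setup, train) run afterwards in the order they were added, with synchronous ones executed in a worker thread. A self-contained sketch of that pattern, separate from the commit (the app, route, and *_stub names here are illustrative):

    # Minimal sketch of the JSONResponse + BackgroundTasks pattern used above.
    import asyncio

    import fastapi
    from starlette.background import BackgroundTasks
    from starlette.responses import JSONResponse

    app = fastapi.FastAPI()


    def preflight_stub() -> None:
        # synchronous task: Starlette runs it in a worker thread after the response is sent
        print("preflight ran after the response was returned")


    async def train_stub() -> None:
        # asynchronous task: awaited on the server's event loop after the response is sent
        await asyncio.sleep(0.1)
        print("train ran after the response was returned")


    @app.post("/jobs/{job_uuid}")
    async def start_job(job_uuid: str) -> JSONResponse:
        tasks = BackgroundTasks()
        tasks.add_task(preflight_stub)  # tasks execute sequentially, in add order
        tasks.add_task(train_stub)
        return JSONResponse(content={"job_uuid": job_uuid}, background=tasks)

Because the tasks run only after the response goes out, the client gets the PostTrainingJob payload immediately while the tuning pipeline proceeds in the background.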
@@ -0,0 +1 @@
from .fullprecision_finetuning_multi_device import FullPrecisionFineTuning
@@ -0,0 +1,173 @@
import tempfile
import typing
from pathlib import Path
from typing import Callable

import transformers
from transformers.configuration_utils import PretrainedConfig
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training.post_training import AlgorithmConfig, TrainingConfig

STORAGE_SUBDIRS = ["checkpoints", "data", "logs", "hf_cache"]
VALIDATED_MODEL_ARCHS = ["LlamaForCausalLM", "GraniteForCausalLM"]

SomePretrainedTokenizer: typing.TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast


class FullPrecisionFineTuning:
    # TODO: set HF storage in HF utilities to this object's storage so that we can clean it up automatically
    # TODO: set up logging utils, replace print statements with logging with the right level
    def __init__(
        self,
        model: str,
        training_config: TrainingConfig,
        logger_config: dict[str, typing.Any],
        storage_dir: Path | None,
        algorithm_config: AlgorithmConfig,  # type: ignore
        datasets_api: Datasets,
        datasetsio_api: DatasetIO,
    ):
        self.model_name_or_path = model
        self.training_config = training_config
        self.logger_config = logger_config
        if storage_dir:
            self.storage_dir = storage_dir
        else:
            self.storage_dir = Path(tempfile.mkdtemp())
        self.__setup_storage()

        self.datasets_api = datasets_api
        self.datasetio_api = datasetsio_api
        self.loaded_dataset: typing.Any = None  # should be a list of dicts but shape can be weird

    def __setup_storage(self):
        for subdir in STORAGE_SUBDIRS:
            new_subdir = self.storage_dir / subdir
            new_subdir.mkdir(exist_ok=True, parents=True)

    @property
    def checkpoint_dir(self):
        return self.storage_dir / "checkpoints"

    @property
    def data_dir(self):
        return self.storage_dir / "data"

    @property
    def logs_dir(self):
        return self.storage_dir / "logs"

    @staticmethod
    def check_model_arch_validated(model_config: PretrainedConfig) -> bool:
        if model_config.architectures is None:
            return False

        for arch in model_config.architectures:
            if arch in VALIDATED_MODEL_ARCHS:
                return True

        return False

    def __try_load_config(self) -> PretrainedConfig:
        try:
            model_config: PretrainedConfig = transformers.AutoConfig.from_pretrained(self.model_name_or_path)
        except OSError:
            print(
                f"Attempted to load model config for ({self.model_name_or_path}) but failed. Model config will be loaded by `AutoConfig.from_pretrained()`"
            )
            raise

        return model_config

    def __try_load_tokenizer(self) -> SomePretrainedTokenizer:
        try:
            tokenizer: SomePretrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
                self.model_name_or_path, use_fast=True
            )
        except OSError:
            print(
                f"Attempted to load model tokenizer for ({self.model_name_or_path}) but failed. Model tokenizer will be loaded by `AutoTokenizer.from_pretrained()`"
            )
            raise

        return tokenizer

    @staticmethod
    def check_tokenizer_has_chat_template(tokenizer: SomePretrainedTokenizer) -> bool:
        if not hasattr(tokenizer, "chat_template"):
            return False

        if tokenizer.chat_template is None:
            return False

        return True

    async def load_dataset_from_datasetsio(self):
        dataset = await self.datasetio_api.get_rows_paginated(
            dataset_id=self.training_config.data_config.dataset_id, rows_in_page=-1
        )
        self.loaded_dataset = dataset.rows

    def preflight(self, set_status_callback: Callable[[JobStatus], None]):
        """
        A synchronous "preflight" operation from the recipe runs and does the following checks:
        1. (future) validates that the host has access to sufficient hardware. For now, assume that an administrator has "cleared" the deployment for any requests it could get. In the future, check:
            1. Cards exist
            2. Cards have enough memory
            3. Cards are idle
            4. Cards have silicon for functions (bfloat16 tensor cores, support FA)
            5. Cards are functional
        2. Validates that model is available from HF.
        3. Validates that model is a verified architecture (warns user if not).
        4. Validates that model's tokenizer exists.
        5. Validates that the model's tokenizer has a chat template.
        6. Validates that the model's chat template can render a sample from the dataset.
        """

        model_config = self.__try_load_config()
        if not self.check_model_arch_validated(model_config=model_config):
            # could raise Error if we need a strong check against this.
            print(
                f"Input model ({self.model_name_or_path}) architecture ({model_config.architectures}) is not among validated architectures."
            )

        model_tokenizer = self.__try_load_tokenizer()
        if not self.check_tokenizer_has_chat_template(model_tokenizer):
            raise RuntimeError(
                f"Input model ({self.model_name_or_path})'s tokenizer ({model_tokenizer.__name__}) has no chat template from associated `tokenizer_config.json`"
            )

        try:
            rendered_sample = model_tokenizer.apply_chat_template(self.loaded_dataset[0]["messages"])
        except Exception:
            # catching / raising bare exception because 'apply_chat_template' can raise ValueError or TypeError; want to report the same thing regardless.
            print(
                f"Input model ({self.model_name_or_path})'s tokenizer ({model_tokenizer.__name__}) could not tokenize dataset sample. Please make sure that sample is OpenAI 'chat' formatted."
            )
            raise

        # Success! Preflight checks haven't caught any immediate problems.
        set_status_callback(JobStatus.scheduled)

    def setup(self):
        """
        A synchronous data preprocessing operation that runs in a kernel-scheduled background thread and does the following:
        1. Requests all rows of data from datasetsio API
        2. Ports data into a `huggingface.datasets` object
        3. Instantiates the model tokenizer
        4. `dataset.map`'s the input data into chat template format
        5. generates labels, masks for each sample
        6. writes dataset to temporary storage
        """
        pass

    async def train(self, set_status_callback: Callable[[JobStatus], None]):
        """
        An asynchronous instance method that creates and watches a `torchrun` subprocess that's training the input model.
        """
        set_status_callback(JobStatus.completed)
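FullPrecisionFineTuning.setup() is left as a stub in this commit; only its docstring describes the intended preprocessing. For orientation, one possible shape of steps 2-4 (porting rows into a `datasets` object and rendering the chat template) is sketched below. This is a hedged illustration, not the commit's implementation: tokenize_chat_rows is a made-up helper, it assumes the rows loaded by load_dataset_from_datasetsio are OpenAI-style {"messages": [...]} dicts, and the naive label copy skips the masking described in step 5.

    # Illustrative sketch only; not part of commit ddea2aa74f.
    import datasets
    import transformers


    def tokenize_chat_rows(rows: list[dict], model_name_or_path: str) -> datasets.Dataset:
        # step 3: instantiate the model tokenizer
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

        # step 2: port the datasetio rows into a `datasets` object
        ds = datasets.Dataset.from_list(rows)

        def _render(sample: dict) -> dict:
            # step 4: render the chat template and tokenize each sample
            input_ids = tokenizer.apply_chat_template(sample["messages"], tokenize=True)
            # naive labels: a plain copy of input_ids; masking of non-assistant tokens is omitted
            return {"input_ids": input_ids, "labels": list(input_ids)}

        return ds.map(_render, remove_columns=ds.column_names)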