Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 18:50:44 +00:00)
implemented data loading and preflight of FullPrecisionFineTuning
Signed-off-by: James Kunstle <jkunstle@redhat.com>
This commit is contained in: parent 9698c14e07, commit ddea2aa74f
3 changed files with 246 additions and 47 deletions
@@ -3,9 +3,17 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from asyncio.subprocess import Process
 from datetime import datetime
+from pathlib import Path
 from typing import Any, Dict, Optional
 
+import fastapi
+import fastapi.concurrency
+import pydantic
+from starlette.background import BackgroundTasks
+from starlette.responses import JSONResponse
+
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
@@ -13,16 +21,28 @@ from llama_stack.apis.post_training import (
     DPOAlignmentConfig,
     JobStatus,
     ListPostTrainingJobsResponse,
-    LoraFinetuningConfig,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
     TrainingConfig,
 )
 from llama_stack.providers.inline.post_training.huggingface_ilab.config import HFilabPostTrainingConfig
+from llama_stack.providers.inline.post_training.huggingface_ilab.recipes import FullPrecisionFineTuning
 from llama_stack.schema_utils import webmethod
 
 
+class TuningJob(pydantic.BaseModel):
+    model_config = pydantic.ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
+    job_uuid: str
+    status: list[JobStatus] = []
+
+    created_at: datetime | None = None
+    scheduled_at: datetime | None = None
+    completed_at: datetime | None = None
+
+    background_proc_pid: Process | None = None
+
+
 class HFilabPostTrainingImpl:
     def __init__(
         self,
@@ -34,13 +54,25 @@ class HFilabPostTrainingImpl:
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets
 
-        # TODO: assume sync job, will need jobs API for async scheduling
-        self.jobs = {}
-        self.checkpoints_dict = {}
+        self.current_job: TuningJob | None = None
 
     async def shutdown(self):
         pass
 
+    async def can_schedule_new_job(self) -> bool:
+        if self.current_job is None:
+            return True
+
+        finalized_job_states = [JobStatus.completed.value, JobStatus.failed.value]
+        if self.current_job.status in finalized_job_states:
+            return True
+
+        return False
+
+    def __set_status_callback(self, new_status: JobStatus):
+        if self.current_job is not None:
+            self.current_job.status.append(new_status)
+
     async def supervised_fine_tune(
         self,
         job_uuid: str,
@@ -50,53 +82,46 @@ class HFilabPostTrainingImpl:
         model: str,
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig],
-    ) -> PostTrainingJob:
-        if job_uuid in self.jobs:
-            raise ValueError(f"Job {job_uuid} already exists")
+    ) -> JSONResponse:
+        if not self.can_schedule_new_job():
+            raise fastapi.HTTPException(
+                status_code=503,  # service unavailable, try again later.
+                detail="A tuning job is currently running; this could take a while.",
+                headers={"Retry-After": "3600"},  # 60sec * 60min = 3600 seconds
+            )
 
-        post_training_job = PostTrainingJob(job_uuid=job_uuid)
-
-        job_status_response = PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
-            status=JobStatus.scheduled,
-            scheduled_at=datetime.now(),
+        recipe = FullPrecisionFineTuning(
+            model=model,
+            training_config=training_config,
+            logger_config=logger_config,
+            storage_dir=Path(checkpoint_dir) if checkpoint_dir else None,
+            algorithm_config=algorithm_config,
+            datasets_api=self.datasets_api,
+            datasetsio_api=self.datasetio_api,
         )
-        self.jobs[job_uuid] = job_status_response
 
-        if isinstance(algorithm_config, LoraFinetuningConfig):
-            try:
-                recipe = LoraFinetuningSingleDevice(
-                    self.config,
-                    job_uuid,
-                    training_config,
-                    hyperparam_search_config,
-                    logger_config,
-                    model,
-                    checkpoint_dir,
-                    algorithm_config,
-                    self.datasetio_api,
-                    self.datasets_api,
-                )
-
-                job_status_response.status = JobStatus.in_progress
-                job_status_response.started_at = datetime.now()
-
-                await recipe.setup()
-                resources_allocated, checkpoints = await recipe.train()
-
-                self.checkpoints_dict[job_uuid] = checkpoints
-                job_status_response.resources_allocated = resources_allocated
-                job_status_response.checkpoints = checkpoints
-                job_status_response.status = JobStatus.completed
-                job_status_response.completed_at = datetime.now()
-
-            except Exception:
-                job_status_response.status = JobStatus.failed
-                raise
-        else:
-            raise NotImplementedError()
-
-        return post_training_job
+        tasks = BackgroundTasks()
+        tasks.add_task(
+            recipe.load_dataset_from_datasetsio,  # asynchronous request
+        )
+        tasks.add_task(
+            recipe.preflight,  # synchronous request
+            set_status_callback=self.__set_status_callback,
+        )
+        tasks.add_task(
+            recipe.setup,  # synchronous request
+        )
+        tasks.add_task(
+            recipe.train,  # asynchronous request
+            set_status_callback=self.__set_status_callback,
+        )
+
+        self.current_job = TuningJob(job_uuid=job_uuid, status=[JobStatus.scheduled])
+        resp_object = PostTrainingJob(job_uuid=job_uuid)
+        return JSONResponse(
+            content=resp_object.model_dump(),
+            background=tasks,
+        )
 
     async def preference_optimize(
         self,
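The scheduling rework above leans on starlette's BackgroundTasks: the endpoint now returns the PostTrainingJob payload immediately, and the recipe's load/preflight/setup/train steps run afterwards in the order they were added. A minimal, self-contained sketch of that pattern follows; the toy app, route path, and stand-in task functions are illustrative assumptions, not part of this commit.

import asyncio

from starlette.applications import Starlette
from starlette.background import BackgroundTasks
from starlette.responses import JSONResponse
from starlette.routing import Route


async def load_data():  # stand-in for recipe.load_dataset_from_datasetsio
    await asyncio.sleep(0.1)


def preflight():  # stand-in for recipe.preflight; sync tasks run in a threadpool
    print("preflight checks passed")


async def submit_job(request):
    tasks = BackgroundTasks()
    tasks.add_task(load_data)  # runs first, only after the response has been sent
    tasks.add_task(preflight)  # runs second, in order
    return JSONResponse({"job_uuid": "example-uuid"}, background=tasks)


app = Starlette(routes=[Route("/jobs", submit_job, methods=["POST"])])

Because these tasks run on the serving event loop (or its threadpool for sync callables), the heavy training work itself is still expected to be pushed out to a subprocess by the recipe's train step, as described in the new recipe file below.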
@@ -0,0 +1 @@
+from .fullprecision_finetuning_multi_device import FullPrecisionFineTuning
@@ -0,0 +1,173 @@
+import tempfile
+import typing
+from pathlib import Path
+from typing import Callable
+
+import transformers
+from transformers.configuration_utils import PretrainedConfig
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+
+from llama_stack.apis.common.job_types import JobStatus
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training.post_training import AlgorithmConfig, TrainingConfig
+
+STORAGE_SUBDIRS = ["checkpoints", "data", "logs", "hf_cache"]
+VALIDATED_MODEL_ARCHS = ["LlamaForCausalLM", "GraniteForCausalLM"]
+
+SomePretrainedTokenizer: typing.TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast
+
+
+class FullPrecisionFineTuning:
+    # TODO: set HF storage in HF utilities to this object's storage so that we can clean it up automatically
+    # TODO: set up logging utils, replace print statements with logging with the right level
+    def __init__(
+        self,
+        model: str,
+        training_config: TrainingConfig,
+        logger_config: dict[str, typing.Any],
+        storage_dir: Path | None,
+        algorithm_config: AlgorithmConfig,  # type: ignore
+        datasets_api: Datasets,
+        datasetsio_api: DatasetIO,
+    ):
+        self.model_name_or_path = model
+        self.training_config = training_config
+        self.logger_config = logger_config
+        if storage_dir:
+            self.storage_dir = storage_dir
+        else:
+            self.storage_dir = Path(tempfile.mkdtemp())
+        self.__setup_storage()
+
+        self.datasets_api = datasets_api
+        self.datasetio_api = datasetsio_api
+        self.loaded_dataset: typing.Any = None  # should be a list of dicts but shape can be weird
+
+    def __setup_storage(self):
+        for subdir in STORAGE_SUBDIRS:
+            new_subdir = self.storage_dir / subdir
+            new_subdir.mkdir(exist_ok=True, parents=True)
+
+    @property
+    def checkpoint_dir(self):
+        return self.storage_dir / "checkpoints"
+
+    @property
+    def data_dir(self):
+        return self.storage_dir / "data"
+
+    @property
+    def logs_dir(self):
+        return self.storage_dir / "logs"
+
+    @staticmethod
+    def check_model_arch_validated(model_config: PretrainedConfig) -> bool:
+        if model_config.architectures is None:
+            return False
+
+        for arch in model_config.architectures:
+            if arch in VALIDATED_MODEL_ARCHS:
+                return True
+
+        return False
+
+    def __try_load_config(self) -> PretrainedConfig:
+        try:
+            model_config: PretrainedConfig = transformers.AutoConfig.from_pretrained(self.model_name_or_path)
+        except OSError:
+            print(
+                f"Attempted to load model config for ({self.model_name_or_path}) but failed. Model config will be loaded by `AutoConfig.from_pretrained()`"
+            )
+            raise
+
+        return model_config
+
+    def __try_load_tokenizer(self) -> SomePretrainedTokenizer:
+        try:
+            tokenizer: SomePretrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
+                self.model_name_or_path, use_fast=True
+            )
+        except OSError:
+            print(
+                f"Attempted to load model tokenizer for ({self.model_name_or_path}) but failed. Model tokenizer will be loaded by `AutoTokenizer.from_pretrained()`"
+            )
+            raise
+
+        return tokenizer
+
+    @staticmethod
+    def check_tokenizer_has_chat_template(tokenizer: SomePretrainedTokenizer) -> bool:
+        if not hasattr(tokenizer, "chat_template"):
+            return False
+
+        if tokenizer.chat_template is None:
+            return False
+
+        return True
+
+    async def load_dataset_from_datasetsio(self):
+        dataset = await self.datasetio_api.get_rows_paginated(
+            dataset_id=self.training_config.data_config.dataset_id, rows_in_page=-1
+        )
+        self.loaded_dataset = dataset.rows
+
+    def preflight(self, set_status_callback: Callable[[JobStatus], None]):
+        """
+        A synchronous "preflight" operation from the recipe runs and does the following checks:
+        1. (future) validates that the host has access to sufficient hardware. For now, assume that an
+           administrator has "cleared" the deployment for any requests it could get. In the future, check:
+            1. Cards exist
+            2. Cards have enough memory
+            3. Cards are idle
+            4. Cards have silicon for functions (bfloat16 tensor cores, support FA)
+            5. Cards are functional
+        2. Validates that model is available from HF.
+        3. Validates that model is a verified architecture (warns user if not).
+        4. Validates that model's tokenizer exists.
+        5. Validates that the model's tokenizer has a chat template.
+        6. Validates that the model's chat template can render a sample from the dataset.
+        """
+
+        model_config = self.__try_load_config()
+        if not self.check_model_arch_validated(model_config=model_config):
+            # could raise Error if we need a strong check against this.
+            print(
+                f"Input model ({self.model_name_or_path}) architecture ({model_config.architectures}) is not among validated architectures."
+            )
+
+        model_tokenizer = self.__try_load_tokenizer()
+        if not self.check_tokenizer_has_chat_template(model_tokenizer):
+            raise RuntimeError(
+                f"Input model ({self.model_name_or_path})'s tokenizer ({model_tokenizer.__name__}) has no chat template from associated `tokenizer_config.json`"
+            )
+
+        try:
+            rendered_sample = model_tokenizer.apply_chat_template(self.loaded_dataset[0]["messages"])
+        except Exception:
+            # catching / raising bare exception because 'apply_chat_template' can raise ValueError or TypeError; want to report the same thing regardless.
+            print(
+                f"Input model ({self.model_name_or_path})'s tokenizer ({model_tokenizer.__name__}) could not tokenize dataset sample. Please make sure that sample is OpenAI 'chat' formatted."
+            )
+            raise
+
+        # Success! Preflight checks haven't caught any immediate problems.
+        set_status_callback(JobStatus.scheduled)
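Check 6 above boils down to `apply_chat_template` succeeding on the first loaded row. For reference, a sketch of the kind of OpenAI-"chat"-formatted sample that should render cleanly; the model id and messages are placeholders, not taken from this commit.

from transformers import AutoTokenizer

# any HF model id whose tokenizer ships a chat template would do here
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct", use_fast=True)

sample = {
    "messages": [
        {"role": "user", "content": "Summarize what a preflight check does."},
        {"role": "assistant", "content": "It validates the model, tokenizer, and data before training starts."},
    ]
}

# tokenize=False returns the rendered prompt string; the preflight check only cares
# that rendering the first dataset row does not raise.
rendered = tokenizer.apply_chat_template(sample["messages"], tokenize=False)
print(rendered)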
+
+    def setup(self):
+        """
+        A synchronous data preprocessing operation that runs in a kernel-scheduled background thread and does the following:
+        1. Requests all rows of data from datasetsio API
+        2. Ports data into a `huggingface.datasets` object
+        3. Instantiates the model tokenizer
+        4. `dataset.map`'s the input data into chat template format
+        5. generates labels, masks for each sample
+        6. writes dataset to temporary storage
+        """
+        pass
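setup() is still a stub; its docstring describes the intended pipeline. One way those steps could look, sketched with the `datasets` library; the helper names and the naive labeling scheme are assumptions, not the commit's implementation.

import datasets


def _render_and_label(row, tokenizer):
    # steps 3-5: render the chat-formatted row and derive labels; a real recipe would
    # mask out non-assistant tokens instead of learning on every token.
    input_ids = tokenizer.apply_chat_template(row["messages"], tokenize=True)
    return {"input_ids": input_ids, "labels": list(input_ids)}


def build_training_dataset(loaded_rows, tokenizer, data_dir):
    ds = datasets.Dataset.from_list(loaded_rows)  # step 2: rows fetched from the datasetio API
    ds = ds.map(lambda row: _render_and_label(row, tokenizer), remove_columns=ds.column_names)
    ds.save_to_disk(str(data_dir / "tokenized"))  # step 6: persist under the recipe's data dir
    return ds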
+
+    async def train(self, set_status_callback: Callable[[JobStatus], None]):
+        """
+        An asynchronous instance method that creates and watches a `torchrun` subprocess that's training the input model.
+        """
+        set_status_callback(JobStatus.completed)
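train() currently only flips the job status; the docstring says it will create and watch a `torchrun` subprocess. A sketch of what that could look like with asyncio's subprocess API; the script path, arguments, and status transitions here are assumptions, not the commit's code.

import asyncio

from llama_stack.apis.common.job_types import JobStatus


async def run_torchrun(train_script: str, nproc_per_node: int, set_status_callback) -> int:
    # launch the distributed trainer as a child process
    proc = await asyncio.create_subprocess_exec(
        "torchrun", f"--nproc_per_node={nproc_per_node}", train_script,
    )
    set_status_callback(JobStatus.in_progress)
    returncode = await proc.wait()  # yields to the event loop while training runs
    set_status_callback(JobStatus.completed if returncode == 0 else JobStatus.failed)
    return returncode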