# What does this PR do?

Adds an inline HuggingFace SFTTrainer provider. Alongside torchtune, this is a very popular option for running training jobs. The provider config lets a user specify key fields such as the model, chat_template, device, etc. The provider ships with one recipe, `finetune_single_device`, which works both with and without LoRA. Any model that is a valid HF identifier can be given, and the model will be pulled. This has been tested so far with the CPU and MPS device types, but should be compatible with CUDA out of the box.

The provider processes the given dataset into the proper format, establishes the steps per epoch, steps per save, and steps per eval, sets a sane `SFTConfig`, and runs `n_epochs` of training. If `checkpoint_dir` is None, no model is saved; if a checkpoint dir is provided, a model is saved every `save_steps` and at the end of training.

## Test Plan

Re-enabled the post_training integration test suite with a single test that loads the simpleqa dataset (https://huggingface.co/datasets/llamastack/simpleqa) and a tiny Granite model (https://huggingface.co/ibm-granite/granite-3.3-2b-instruct). The test now uses the llama stack client and the proper post_training API, and runs one step with a batch_size of 1. The test runs on CPU on the Ubuntu runner, so it needs a small batch and a single step. A hedged usage sketch follows below.

---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>
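For orientation, here is a minimal sketch of driving the new provider through the post_training API with the llama-stack client, mirroring the single-step CPU run from the test plan. Method and field names follow the server-side `supervised_fine_tune` signature and `TrainingConfig` in this provider; the endpoint URL, the `simpleqa` dataset registration, the plain-dict config shape, and the `client.post_training.job.status` accessor are illustrative assumptions, not the exact test code.

```python
# Rough sketch of exercising the provider end-to-end, modeled on the
# integration test described above. The base_url, dataset_id, and the
# plain-dict shape of training_config are illustrative assumptions.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local stack endpoint

job = client.post_training.supervised_fine_tune(
    job_uuid="hf-sft-demo",                           # caller-chosen job id
    model="ibm-granite/granite-3.3-2b-instruct",      # any valid HF identifier; pulled on demand
    checkpoint_dir=None,                              # None -> no checkpoints are written
    algorithm_config=None,                            # pass a LoRA config here to enable LoRA
    training_config={
        "n_epochs": 1,
        "max_steps_per_epoch": 1,                     # single step, as in the CPU integration test
        "gradient_accumulation_steps": 1,
        "data_config": {
            "dataset_id": "simpleqa",                 # assumes the dataset is already registered
            "batch_size": 1,
            "shuffle": False,
            "data_format": "instruct",
        },
    },
    hyperparam_search_config={},
    logger_config={},
)

# Poll the job through the /post-training/job/status route (accessor name assumed).
status = client.post_training.job.status(job_uuid=job.job_uuid)
print(status.status, status.checkpoints)
```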
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
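"""Inline HuggingFace post-training provider.

Schedules supervised fine-tuning jobs that run the single-device SFTTrainer
recipe (HFFinetuningSingleDevice) and exposes their checkpoints and resource
statistics as job artifacts through the post_training API.
"""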
from enum import Enum
from typing import Any

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
    AlgorithmConfig,
    Checkpoint,
    DPOAlignmentConfig,
    JobStatus,
    ListPostTrainingJobsResponse,
    PostTrainingJob,
    PostTrainingJobArtifactsResponse,
    PostTrainingJobStatusResponse,
    TrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.config import (
    HuggingFacePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
    HFFinetuningSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod


class TrainingArtifactType(Enum):
    CHECKPOINT = "checkpoint"
    RESOURCES_STATS = "resources_stats"


_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"

class HuggingFacePostTrainingImpl:
    def __init__(
        self,
        config: HuggingFacePostTrainingConfig,
        datasetio_api: DatasetIO,
        datasets: Datasets,
    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets
        self._scheduler = Scheduler()

    async def shutdown(self) -> None:
        await self._scheduler.shutdown()

    @staticmethod
    def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
        return JobArtifact(
            type=TrainingArtifactType.CHECKPOINT.value,
            name=checkpoint.identifier,
            uri=checkpoint.path,
            metadata=dict(checkpoint),
        )

    @staticmethod
    def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
        return JobArtifact(
            type=TrainingArtifactType.RESOURCES_STATS.value,
            name=TrainingArtifactType.RESOURCES_STATS.value,
            metadata=resources_stats,
        )

    async def supervised_fine_tune(
        self,
        job_uuid: str,
        training_config: TrainingConfig,
        hyperparam_search_config: dict[str, Any],
        logger_config: dict[str, Any],
        model: str,
        checkpoint_dir: str | None = None,
        algorithm_config: AlgorithmConfig | None = None,
    ) -> PostTrainingJob:
        async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
            on_log_message_cb("Starting HF finetuning")

            recipe = HFFinetuningSingleDevice(
                job_uuid=job_uuid,
                datasetio_api=self.datasetio_api,
                datasets_api=self.datasets_api,
            )

            resources_allocated, checkpoints = await recipe.train(
                model=model,
                output_dir=checkpoint_dir,
                job_uuid=job_uuid,
                lora_config=algorithm_config,
                config=training_config,
                provider_config=self.config,
            )

            # Surface resource usage and any saved checkpoints as job artifacts.
            on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
            if checkpoints:
                for checkpoint in checkpoints:
                    artifact = self._checkpoint_to_artifact(checkpoint)
                    on_artifact_collected_cb(artifact)

            on_status_change_cb(SchedulerJobStatus.completed)
            on_log_message_cb("HF finetuning completed")

        # The scheduler runs the handler asynchronously and tracks its status and artifacts.
        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
        return PostTrainingJob(job_uuid=job_uuid)

    async def preference_optimize(
        self,
        job_uuid: str,
        finetuned_model: str,
        algorithm_config: DPOAlignmentConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: dict[str, Any],
        logger_config: dict[str, Any],
    ) -> PostTrainingJob:
        raise NotImplementedError("DPO alignment is not implemented yet")

    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
        return ListPostTrainingJobsResponse(
            data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
        )

    @staticmethod
    def _get_artifacts_metadata_by_type(job, artifact_type):
        return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]

    @classmethod
    def _get_checkpoints(cls, job):
        return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)

    @classmethod
    def _get_resources_allocated(cls, job):
        data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
        return data[0] if data else None

    @webmethod(route="/post-training/job/status")
    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
        job = self._scheduler.get_job(job_uuid)

        match job.status:
            # TODO: Add support for other statuses to API
            case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
                status = JobStatus.scheduled
            case SchedulerJobStatus.running:
                status = JobStatus.in_progress
            case SchedulerJobStatus.completed:
                status = JobStatus.completed
            case SchedulerJobStatus.failed:
                status = JobStatus.failed
            case _:
                raise NotImplementedError()

        return PostTrainingJobStatusResponse(
            job_uuid=job_uuid,
            status=status,
            scheduled_at=job.scheduled_at,
            started_at=job.started_at,
            completed_at=job.completed_at,
            checkpoints=self._get_checkpoints(job),
            resources_allocated=self._get_resources_allocated(job),
        )

    @webmethod(route="/post-training/job/cancel")
    async def cancel_training_job(self, job_uuid: str) -> None:
        self._scheduler.cancel(job_uuid)

    @webmethod(route="/post-training/job/artifacts")
    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
        job = self._scheduler.get_job(job_uuid)
        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))