From 9698c14e072f96765f26e2d4f2daac7e26358268 Mon Sep 17 00:00:00 2001
From: James Kunstle
Date: Wed, 12 Mar 2025 15:53:47 -0700
Subject: [PATCH] huggingface_ilab set up; removed most of torchtune impl

Signed-off-by: James Kunstle
---
 .../huggingface_ilab/__init__.py              |  25 ++++
 .../post_training/huggingface_ilab/config.py  |  14 ++
 .../huggingface_ilab/post_training.py         | 127 ++++++++++++++++++
 .../providers/registry/post_training.py       |  11 ++
 4 files changed, 177 insertions(+)
 create mode 100644 llama_stack/providers/inline/post_training/huggingface_ilab/__init__.py
 create mode 100644 llama_stack/providers/inline/post_training/huggingface_ilab/config.py
 create mode 100644 llama_stack/providers/inline/post_training/huggingface_ilab/post_training.py

diff --git a/llama_stack/providers/inline/post_training/huggingface_ilab/__init__.py b/llama_stack/providers/inline/post_training/huggingface_ilab/__init__.py
new file mode 100644
index 000000000..6f97ebd63
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/huggingface_ilab/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from llama_stack.distribution.datatypes import Api
+
+from .config import HFilabPostTrainingConfig
+
+
+async def get_provider_impl(
+    config: HFilabPostTrainingConfig,
+    deps: dict[Api, Any],
+):
+    from .post_training import HFilabPostTrainingImpl
+
+    impl = HFilabPostTrainingImpl(
+        config,
+        deps[Api.datasetio],
+        deps[Api.datasets],
+    )
+    return impl
diff --git a/llama_stack/providers/inline/post_training/huggingface_ilab/config.py b/llama_stack/providers/inline/post_training/huggingface_ilab/config.py
new file mode 100644
index 000000000..66c7864ca
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/huggingface_ilab/config.py
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Literal, Optional
+
+from pydantic import BaseModel
+
+
+class HFilabPostTrainingConfig(BaseModel):
+    torch_seed: Optional[int] = None
+    checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
diff --git a/llama_stack/providers/inline/post_training/huggingface_ilab/post_training.py b/llama_stack/providers/inline/post_training/huggingface_ilab/post_training.py
new file mode 100644
index 000000000..d15850b29
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/huggingface_ilab/post_training.py
@@ -0,0 +1,127 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from datetime import datetime
+from typing import Any, Dict, Optional
+
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training import (
+    AlgorithmConfig,
+    DPOAlignmentConfig,
+    JobStatus,
+    ListPostTrainingJobsResponse,
+    LoraFinetuningConfig,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobStatusResponse,
+    TrainingConfig,
+)
+from llama_stack.providers.inline.post_training.huggingface_ilab.config import HFilabPostTrainingConfig
+
+# TODO: temporarily reuse the torchtune LoRA recipe until a HuggingFace-native recipe replaces it.
+from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import LoraFinetuningSingleDevice
+from llama_stack.schema_utils import webmethod
+
+
+class HFilabPostTrainingImpl:
+    def __init__(
+        self,
+        config: HFilabPostTrainingConfig,
+        datasetio_api: DatasetIO,
+        datasets: Datasets,
+    ) -> None:
+        self.config = config
+        self.datasetio_api = datasetio_api
+        self.datasets_api = datasets
+
+        # TODO: assume sync job, will need jobs API for async scheduling
+        self.jobs = {}
+        self.checkpoints_dict = {}
+
+    async def shutdown(self):
+        pass
+
+    async def supervised_fine_tune(
+        self,
+        job_uuid: str,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
+        model: str,
+        checkpoint_dir: Optional[str],
+        algorithm_config: Optional[AlgorithmConfig],
+    ) -> PostTrainingJob:
+        if job_uuid in self.jobs:
+            raise ValueError(f"Job {job_uuid} already exists")
+
+        post_training_job = PostTrainingJob(job_uuid=job_uuid)
+
+        job_status_response = PostTrainingJobStatusResponse(
+            job_uuid=job_uuid,
+            status=JobStatus.scheduled,
+            scheduled_at=datetime.now(),
+        )
+        self.jobs[job_uuid] = job_status_response
+
+        if isinstance(algorithm_config, LoraFinetuningConfig):
+            try:
+                recipe = LoraFinetuningSingleDevice(
+                    self.config,
+                    job_uuid,
+                    training_config,
+                    hyperparam_search_config,
+                    logger_config,
+                    model,
+                    checkpoint_dir,
+                    algorithm_config,
+                    self.datasetio_api,
+                    self.datasets_api,
+                )
+
+                job_status_response.status = JobStatus.in_progress
+                job_status_response.started_at = datetime.now()
+
+                await recipe.setup()
+                resources_allocated, checkpoints = await recipe.train()
+
+                self.checkpoints_dict[job_uuid] = checkpoints
+                job_status_response.resources_allocated = resources_allocated
+                job_status_response.checkpoints = checkpoints
+                job_status_response.status = JobStatus.completed
+                job_status_response.completed_at = datetime.now()
+
+            except Exception:
+                job_status_response.status = JobStatus.failed
+                raise
+        else:
+            raise NotImplementedError()
+
+        return post_training_job
+
+    async def preference_optimize(
+        self,
+        job_uuid: str,
+        finetuned_model: str,
+        algorithm_config: DPOAlignmentConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
+    ) -> PostTrainingJob:
+        raise NotImplementedError("preference optimization is not implemented yet")
+
+    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
+        raise NotImplementedError("'get training jobs' is not implemented yet")
+
+    @webmethod(route="/post-training/job/status")  # type: ignore
+    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
+        raise NotImplementedError("'get training job status' is not implemented yet")
+
+    @webmethod(route="/post-training/job/cancel")  # type: ignore
+    async def cancel_training_job(self, job_uuid: str) -> None:
+        raise NotImplementedError("Job cancel is not implemented yet")
+
+    @webmethod(route="/post-training/job/artifacts")  # type: ignore
+    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
+        raise NotImplementedError("'get training job artifacts' is not implemented yet")
diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py
index 3bcda6508..4fd7122af 100644
--- a/llama_stack/providers/registry/post_training.py
+++ b/llama_stack/providers/registry/post_training.py
@@ -22,4 +22,15 @@ def available_providers() -> List[ProviderSpec]:
                 Api.datasets,
             ],
         ),
+        InlineProviderSpec(
+            api=Api.post_training,
+            provider_type="inline::huggingface-ilab",
+            pip_packages=["torch", "transformers", "datasets", "numpy"],
+            module="llama_stack.providers.inline.post_training.huggingface_ilab",
+            config_class="llama_stack.providers.inline.post_training.huggingface_ilab.HFilabPostTrainingConfig",
+            api_dependencies=[
+                Api.datasetio,
+                Api.datasets,
+            ],
+        ),
     ]
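
Usage note (a sketch, not part of the patch): once a distribution supplies DatasetIO and Datasets
providers, the new provider is built through get_provider_impl exactly as the registry entry wires
it. The datasetio_impl and datasets_impl arguments below are hypothetical stand-ins for those
dependency implementations; everything else comes from the files added above.

    from llama_stack.distribution.datatypes import Api
    from llama_stack.providers.inline.post_training.huggingface_ilab import get_provider_impl
    from llama_stack.providers.inline.post_training.huggingface_ilab.config import HFilabPostTrainingConfig

    async def build_hf_ilab_provider(datasetio_impl, datasets_impl):
        # checkpoint_format accepts "meta" (the default) or "huggingface" per HFilabPostTrainingConfig.
        config = HFilabPostTrainingConfig(torch_seed=42, checkpoint_format="huggingface")
        # get_provider_impl expects the API dependencies keyed by the Api enum, matching api_dependencies above.
        return await get_provider_impl(config, {Api.datasetio: datasetio_impl, Api.datasets: datasets_impl})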
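
Call-flow note (a sketch, not part of the patch): supervised_fine_tune records the job as scheduled,
runs the (currently torchtune-based) LoRA recipe, and moves the status through in_progress to
completed, or to failed on an exception. The sketch assumes an already-built impl plus TrainingConfig
and LoraFinetuningConfig instances; the job uuid and model id are hypothetical placeholders.

    async def run_sft(impl, training_config, lora_config):
        job = await impl.supervised_fine_tune(
            job_uuid="sft-job-0001",  # hypothetical id; reusing a uuid raises ValueError
            training_config=training_config,
            hyperparam_search_config={},
            logger_config={},
            model="my-base-model",  # hypothetical model reference
            checkpoint_dir=None,
            algorithm_config=lora_config,  # non-LoRA configs currently raise NotImplementedError
        )
        # The in-memory jobs table keeps the status transitions for later lookup.
        status = impl.jobs[job.job_uuid]
        print(status.status, status.scheduled_at, status.completed_at)
        return job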