huggingface_ilab set up; removed most of torchtune impl

Signed-off-by: James Kunstle <jkunstle@redhat.com>
James Kunstle 2025-03-12 15:53:47 -07:00
parent 5b9c366614
commit 9698c14e07
4 changed files with 175 additions and 0 deletions

llama_stack/providers/inline/post_training/huggingface_ilab/__init__.py

@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any

from llama_stack.distribution.datatypes import Api

from .config import HFilabPostTrainingConfig


async def get_provider_impl(
    config: HFilabPostTrainingConfig,
    deps: dict[Api, Any],
):
    # Import lazily so the provider's heavier dependencies load only when requested.
    from .post_training import HFilabPostTrainingImpl

    impl = HFilabPostTrainingImpl(
        config,
        deps[Api.datasetio],
        deps[Api.datasets],
    )
    return impl
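
A minimal wiring sketch for this entry point (hypothetical usage, not part of the commit; the object() placeholders stand in for real DatasetIO/Datasets implementations that a distribution would normally inject):

import asyncio

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.inline.post_training.huggingface_ilab import get_provider_impl
from llama_stack.providers.inline.post_training.huggingface_ilab.config import HFilabPostTrainingConfig


async def main() -> None:
    config = HFilabPostTrainingConfig(torch_seed=42, checkpoint_format="huggingface")
    # Placeholders only; real deps must implement the DatasetIO/Datasets APIs.
    deps = {Api.datasetio: object(), Api.datasets: object()}
    impl = await get_provider_impl(config, deps)
    print(type(impl).__name__)  # -> HFilabPostTrainingImpl


asyncio.run(main())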

llama_stack/providers/inline/post_training/huggingface_ilab/config.py

@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Optional

from pydantic import BaseModel


class HFilabPostTrainingConfig(BaseModel):
    # Seed for torch RNGs; None leaves seeding to the framework default.
    torch_seed: Optional[int] = None
    # On-disk layout for saved checkpoints: Meta-native or Hugging Face format.
    checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
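
A quick sketch of the validation this pydantic model gives for free ("safetensors" is just an arbitrary example of a rejected value):

from pydantic import ValidationError

from llama_stack.providers.inline.post_training.huggingface_ilab.config import HFilabPostTrainingConfig

cfg = HFilabPostTrainingConfig()  # defaults: torch_seed=None, checkpoint_format="meta"
cfg = HFilabPostTrainingConfig(checkpoint_format="huggingface")

try:
    HFilabPostTrainingConfig(checkpoint_format="safetensors")  # outside the Literal
except ValidationError as err:
    print(err)  # rejects anything but "meta", "huggingface", or None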

llama_stack/providers/inline/post_training/huggingface_ilab/post_training.py

@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import Any, Dict, Optional

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
    AlgorithmConfig,
    DPOAlignmentConfig,
    JobStatus,
    ListPostTrainingJobsResponse,
    LoraFinetuningConfig,
    PostTrainingJob,
    PostTrainingJobArtifactsResponse,
    PostTrainingJobStatusResponse,
    TrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface_ilab.config import (
    HFilabPostTrainingConfig,
)

# FIXME: LoraFinetuningSingleDevice is referenced below but no HF-ilab recipe
# exists yet; the torchtune single-device recipe is assumed here as a stand-in.
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
    LoraFinetuningSingleDevice,
)
from llama_stack.schema_utils import webmethod


class HFilabPostTrainingImpl:
    def __init__(
        self,
        config: HFilabPostTrainingConfig,
        datasetio_api: DatasetIO,
        datasets: Datasets,
    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets

        # TODO: assumes synchronous jobs; a jobs API is needed for async scheduling.
        self.jobs: Dict[str, PostTrainingJobStatusResponse] = {}
        self.checkpoints_dict: Dict[str, Any] = {}

    async def shutdown(self) -> None:
        pass
    async def supervised_fine_tune(
        self,
        job_uuid: str,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
        model: str,
        checkpoint_dir: Optional[str],
        algorithm_config: Optional[AlgorithmConfig],
    ) -> PostTrainingJob:
        if job_uuid in self.jobs:
            raise ValueError(f"Job {job_uuid} already exists")

        post_training_job = PostTrainingJob(job_uuid=job_uuid)
        job_status_response = PostTrainingJobStatusResponse(
            job_uuid=job_uuid,
            status=JobStatus.scheduled,
            scheduled_at=datetime.now(),
        )
        self.jobs[job_uuid] = job_status_response

        if isinstance(algorithm_config, LoraFinetuningConfig):
            try:
                recipe = LoraFinetuningSingleDevice(
                    self.config,
                    job_uuid,
                    training_config,
                    hyperparam_search_config,
                    logger_config,
                    model,
                    checkpoint_dir,
                    algorithm_config,
                    self.datasetio_api,
                    self.datasets_api,
                )

                job_status_response.status = JobStatus.in_progress
                job_status_response.started_at = datetime.now()

                # Run the job inline; status and artifacts are recorded in memory.
                await recipe.setup()
                resources_allocated, checkpoints = await recipe.train()

                self.checkpoints_dict[job_uuid] = checkpoints
                job_status_response.resources_allocated = resources_allocated
                job_status_response.checkpoints = checkpoints
                job_status_response.status = JobStatus.completed
                job_status_response.completed_at = datetime.now()
            except Exception:
                job_status_response.status = JobStatus.failed
                raise
        else:
            raise NotImplementedError("only LoRA fine-tuning is supported so far")

        return post_training_job
    async def preference_optimize(
        self,
        job_uuid: str,
        finetuned_model: str,
        algorithm_config: DPOAlignmentConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
    ) -> PostTrainingJob:
        raise NotImplementedError("preference optimization is not implemented yet")

    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
        raise NotImplementedError("'get training jobs' is not implemented yet")

    @webmethod(route="/post-training/job/status")  # type: ignore
    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
        raise NotImplementedError("'get training job status' is not implemented yet")

    @webmethod(route="/post-training/job/cancel")  # type: ignore
    async def cancel_training_job(self, job_uuid: str) -> None:
        raise NotImplementedError("'cancel training job' is not implemented yet")

    @webmethod(route="/post-training/job/artifacts")  # type: ignore
    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
        raise NotImplementedError("'get training job artifacts' is not implemented yet")
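
The status and artifacts stubs could be backed directly by the in-memory bookkeeping that supervised_fine_tune already maintains. A possible follow-up sketch (assumption: PostTrainingJobArtifactsResponse accepts job_uuid and checkpoints, mirroring the torchtune provider):

    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
        # self.jobs maps job_uuid -> PostTrainingJobStatusResponse, so a plain
        # dict lookup is enough for synchronous inline jobs.
        return self.jobs.get(job_uuid)

    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
        checkpoints = self.checkpoints_dict.get(job_uuid)
        if checkpoints is None:
            return None
        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)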

llama_stack/providers/registry/post_training.py

@@ -22,4 +22,15 @@ def available_providers() -> List[ProviderSpec]:
                Api.datasets,
            ],
        ),
        InlineProviderSpec(
            api=Api.post_training,
            provider_type="inline::huggingface-ilab",
            pip_packages=["torch", "transformers", "datasets", "numpy"],
            module="llama_stack.providers.inline.post_training.huggingface_ilab",
            config_class="llama_stack.providers.inline.post_training.huggingface_ilab.HFilabPostTrainingConfig",
            api_dependencies=[
                Api.datasetio,
                Api.datasets,
            ],
        ),
    ]
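
For reference, a sketch of how an InlineProviderSpec like this one is typically resolved at runtime (assumption: this mirrors the stack's importlib-based resolution; exact internals may differ):

import importlib


def resolve_inline_provider(spec):
    # Import the provider package named in the spec and hand back its
    # async get_provider_impl(config, deps) entry point; the stack then
    # awaits it with the config_class instance and the declared api deps.
    module = importlib.import_module(spec.module)
    return module.get_provider_impl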