Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 02:32:40 +00:00)
huggingface_ilab set up; removed most of torchtune impl

Signed-off-by: James Kunstle <jkunstle@redhat.com>

commit 9698c14e07 (parent 5b9c366614)
4 changed files with 175 additions and 0 deletions
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.distribution.datatypes import Api

from .config import HFilabPostTrainingConfig


async def get_provider_impl(
    config: HFilabPostTrainingConfig,
    deps: dict[Api, Any],
):
    from .post_training import HFilabPostTrainingImpl

    impl = HFilabPostTrainingImpl(
        config,
        deps[Api.datasetio],
        deps[Api.datasets],
    )
    return impl
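Assuming the file above is the provider package's __init__.py (the module path given in the registry entry below), here is a minimal sketch of exercising this entry point directly; _StubDatasetIO and _StubDatasets are hypothetical stand-ins for the resolved datasetio and datasets providers:

import asyncio

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.inline.post_training.huggingface_ilab import get_provider_impl
from llama_stack.providers.inline.post_training.huggingface_ilab.config import HFilabPostTrainingConfig


class _StubDatasetIO:
    """Hypothetical placeholder for the resolved datasetio provider."""


class _StubDatasets:
    """Hypothetical placeholder for the resolved datasets provider."""


async def main() -> None:
    # get_provider_impl only stores the deps on the impl, so stubs are enough here.
    deps = {Api.datasetio: _StubDatasetIO(), Api.datasets: _StubDatasets()}
    impl = await get_provider_impl(HFilabPostTrainingConfig(), deps)
    print(type(impl).__name__)  # HFilabPostTrainingImpl


asyncio.run(main())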
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Literal, Optional

from pydantic import BaseModel


class HFilabPostTrainingConfig(BaseModel):
    torch_seed: Optional[int] = None
    checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
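A quick sketch of constructing this config (illustrative values, assuming pydantic v2):

from llama_stack.providers.inline.post_training.huggingface_ilab.config import HFilabPostTrainingConfig

# Both fields are optional; checkpoint_format is constrained to "meta" or "huggingface".
cfg = HFilabPostTrainingConfig(torch_seed=42, checkpoint_format="huggingface")
print(cfg.model_dump())  # {'torch_seed': 42, 'checkpoint_format': 'huggingface'}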
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import Any, Dict, Optional

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
    AlgorithmConfig,
    DPOAlignmentConfig,
    JobStatus,
    ListPostTrainingJobsResponse,
    LoraFinetuningConfig,
    PostTrainingJob,
    PostTrainingJobArtifactsResponse,
    PostTrainingJobStatusResponse,
    TrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface_ilab.config import HFilabPostTrainingConfig
from llama_stack.schema_utils import webmethod


class HFilabPostTrainingImpl:
    def __init__(
        self,
        config: HFilabPostTrainingConfig,
        datasetio_api: DatasetIO,
        datasets: Datasets,
    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets

        # TODO: assume sync job, will need jobs API for async scheduling
        self.jobs = {}
        self.checkpoints_dict = {}

    async def shutdown(self):
        pass

    async def supervised_fine_tune(
        self,
        job_uuid: str,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
        model: str,
        checkpoint_dir: Optional[str],
        algorithm_config: Optional[AlgorithmConfig],
    ) -> PostTrainingJob:
        if job_uuid in self.jobs:
            raise ValueError(f"Job {job_uuid} already exists")

        post_training_job = PostTrainingJob(job_uuid=job_uuid)

        job_status_response = PostTrainingJobStatusResponse(
            job_uuid=job_uuid,
            status=JobStatus.scheduled,
            scheduled_at=datetime.now(),
        )
        self.jobs[job_uuid] = job_status_response

        if isinstance(algorithm_config, LoraFinetuningConfig):
            try:
                # NOTE: LoraFinetuningSingleDevice is not imported in this file; the
                # torchtune recipe it referred to was removed in this commit, so this
                # call will fail until a replacement recipe is wired in.
                recipe = LoraFinetuningSingleDevice(
                    self.config,
                    job_uuid,
                    training_config,
                    hyperparam_search_config,
                    logger_config,
                    model,
                    checkpoint_dir,
                    algorithm_config,
                    self.datasetio_api,
                    self.datasets_api,
                )

                job_status_response.status = JobStatus.in_progress
                job_status_response.started_at = datetime.now()

                await recipe.setup()
                resources_allocated, checkpoints = await recipe.train()

                self.checkpoints_dict[job_uuid] = checkpoints
                job_status_response.resources_allocated = resources_allocated
                job_status_response.checkpoints = checkpoints
                job_status_response.status = JobStatus.completed
                job_status_response.completed_at = datetime.now()

            except Exception:
                job_status_response.status = JobStatus.failed
                raise
        else:
            raise NotImplementedError()

        return post_training_job

    async def preference_optimize(
        self,
        job_uuid: str,
        finetuned_model: str,
        algorithm_config: DPOAlignmentConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
    ) -> PostTrainingJob:
        raise NotImplementedError("preference optimization is not implemented yet")

    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
        raise NotImplementedError("'get training jobs' is not implemented yet")

    @webmethod(route="/post-training/job/status")  # type: ignore
    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
        raise NotImplementedError("'get training job status' is not implemented yet")

    @webmethod(route="/post-training/job/cancel")  # type: ignore
    async def cancel_training_job(self, job_uuid: str) -> None:
        raise NotImplementedError("Job cancel is not implemented yet")

    @webmethod(route="/post-training/job/artifacts")  # type: ignore
    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
        raise NotImplementedError("'get training job artifacts' is not implemented yet")
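Note that supervised_fine_tune still constructs a LoraFinetuningSingleDevice recipe, which this provider no longer imports. Going only by the call sites above, a replacement recipe would need roughly the following shape; the class name and method bodies are hypothetical:

from typing import Any, Dict, List, Optional, Tuple


class HFilabLoraRecipe:
    """Hypothetical skeleton matching how supervised_fine_tune drives a recipe."""

    def __init__(
        self,
        config,                      # HFilabPostTrainingConfig
        job_uuid: str,
        training_config,             # TrainingConfig
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
        model: str,
        checkpoint_dir: Optional[str],
        algorithm_config,            # LoraFinetuningConfig
        datasetio_api,               # DatasetIO
        datasets_api,                # Datasets
    ) -> None:
        self.job_uuid = job_uuid
        self.model = model

    async def setup(self) -> None:
        # Load the base model, tokenizer, and training dataset here.
        ...

    async def train(self) -> Tuple[Any, List[Any]]:
        # Run the fine-tuning loop and return (resources_allocated, checkpoints),
        # which the impl copies onto the job status response.
        return None, []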
@@ -22,4 +22,15 @@ def available_providers() -> List[ProviderSpec]:
                Api.datasets,
            ],
        ),
        InlineProviderSpec(
            api=Api.post_training,
            provider_type="inline::huggingface-ilab",
            pip_packages=["torch", "transformers", "datasets", "numpy"],
            module="llama_stack.providers.inline.post_training.huggingface_ilab",
            config_class="llama_stack.providers.inline.post_training.huggingface_ilab.HFilabPostTrainingConfig",
            api_dependencies=[
                Api.datasetio,
                Api.datasets,
            ],
        ),
    ]
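With this entry registered, the provider is discoverable next to the existing post-training providers. A small lookup sketch, assuming this hunk belongs to llama_stack/providers/registry/post_training.py:

from llama_stack.providers.registry.post_training import available_providers

# Index the registered post-training providers by provider_type.
specs = {spec.provider_type: spec for spec in available_providers()}
hf_ilab = specs["inline::huggingface-ilab"]
print(hf_ilab.module)        # llama_stack.providers.inline.post_training.huggingface_ilab
print(hf_ilab.pip_packages)  # ['torch', 'transformers', 'datasets', 'numpy']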