Mirror of https://github.com/meta-llama/llama-stack.git (synced 2026-01-04 16:51:59 +00:00)
feat(api): define a more coherent jobs api across different flows
Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in: parent 71ed47ea76, commit 0f50cfa561
15 changed files with 1864 additions and 1670 deletions
@@ -10,14 +10,10 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
    AlgorithmConfig,
    Checkpoint,
    DPOAlignmentConfig,
    JobStatus,
    ListPostTrainingJobsResponse,
    LoraFinetuningConfig,
    PostTrainingJob,
    PostTrainingJobArtifactsResponse,
    PostTrainingJobStatusResponse,
    TrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.config import (
@@ -54,15 +50,6 @@ class TorchtunePostTrainingImpl:
    async def shutdown(self) -> None:
        await self._scheduler.shutdown()

    @staticmethod
    def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
        return JobArtifact(
            type=TrainingArtifactType.CHECKPOINT.value,
            name=checkpoint.identifier,
            uri=checkpoint.path,
            metadata=dict(checkpoint),
        )

    @staticmethod
    def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
        return JobArtifact(
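Note: the conversion above relies on a generic JobArtifact record. As a reference, here is a minimal sketch of the shape this diff appears to assume, with field names taken from the calls in the hunks (the real model lives in llama_stack.apis.common.job_types and may differ):

# Hypothetical sketch of the JobArtifact shape used in this commit.
from typing import Any, Dict, Optional
from pydantic import BaseModel

class JobArtifact(BaseModel):
    type: str                      # e.g. "checkpoint" or a resources-stats marker
    name: str                      # human-readable identifier
    uri: Optional[str] = None      # where the artifact can be fetched from
    metadata: Dict[str, Any] = {}  # free-form details (epoch, created_at, metrics, ...)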
@@ -98,14 +85,14 @@ class TorchtunePostTrainingImpl:
                self.datasetio_api,
                self.datasets_api,
            )

            await recipe.setup()

            resources_allocated, checkpoints = await recipe.train()

            on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
            for checkpoint in checkpoints:
                artifact = self._checkpoint_to_artifact(checkpoint)
                on_artifact_collected_cb(artifact)
                on_artifact_collected_cb(checkpoint)

            on_status_change_cb(SchedulerJobStatus.completed)
            on_log_message_cb("Lora finetuning completed")
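Note: the handler above reports progress exclusively through callbacks handed to it by the scheduler. A rough sketch of that contract, assuming the scheduler injects three callables per job (names taken from this diff; exact signatures are not shown here):

# Hypothetical handler skeleton mirroring the callback-driven contract used above.
from typing import Any, Callable

async def example_handler(
    on_log_message_cb: Callable[[str], None],
    on_status_change_cb: Callable[[Any], None],      # receives a scheduler job status
    on_artifact_collected_cb: Callable[[Any], None],  # receives a JobArtifact
) -> None:
    on_log_message_cb("starting work")
    # ... produce artifacts as they become available ...
    # on_artifact_collected_cb(artifact)
    # on_status_change_cb(SchedulerJobStatus.completed)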
@@ -113,6 +100,8 @@ class TorchtunePostTrainingImpl:
            raise NotImplementedError()

        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)

        # TODO: initialize with more data from scheduler
        return PostTrainingJob(job_uuid=job_uuid)

    async def preference_optimize(
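Note: a minimal usage sketch of the flow above, assuming an already-constructed TorchtunePostTrainingImpl and valid config objects; only method names visible in this diff are used, and the keyword names for supervised_fine_tune follow the post-training API as I understand it and may differ:

# Hypothetical caller; `impl` and the config objects are assumed to already exist.
async def kick_off_and_list(impl, job_uuid, model_id, training_config, algorithm_config):
    job = await impl.supervised_fine_tune(
        job_uuid=job_uuid,
        model=model_id,
        checkpoint_dir=None,
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
        algorithm_config=algorithm_config,
    )
    # The returned handle currently carries only the job id (see TODO above).
    jobs = await impl.list_post_training_jobs()
    return job.job_uuid in [j.job_uuid for j in jobs.data]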
@@ -125,56 +114,31 @@ class TorchtunePostTrainingImpl:
        logger_config: Dict[str, Any],
    ) -> PostTrainingJob: ...

    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
    # TODO: should these be under post-training/supervised-fine-tune/?
    # CRUD operations on running jobs
    @webmethod(route="/post-training/jobs/{job_id:path}", method="GET")
    async def get_post_training_job(self, job_id: str) -> PostTrainingJob:
        # TODO: implement
        raise NotImplementedError

    @webmethod(route="/post-training/jobs", method="GET")
    async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse:
        # TODO: populate other data
        return ListPostTrainingJobsResponse(
            data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
        )

    @staticmethod
    def _get_artifacts_metadata_by_type(job, artifact_type):
        return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
    @webmethod(route="/post-training/jobs/{job_id:path}", method="POST")
    async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob:
        # TODO: implement
        raise NotImplementedError

    @classmethod
    def _get_checkpoints(cls, job):
        return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
    @webmethod(route="/post-training/job/{job_id:path}", method="DELETE")
    async def delete_post_training_job(self, job_id: str) -> None:
        # TODO: implement
        raise NotImplementedError

    @classmethod
    def _get_resources_allocated(cls, job):
        data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
        return data[0] if data else None

    @webmethod(route="/post-training/job/status")
    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
        job = self._scheduler.get_job(job_uuid)

        match job.status:
            # TODO: Add support for other statuses to API
            case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
                status = JobStatus.scheduled
            case SchedulerJobStatus.running:
                status = JobStatus.in_progress
            case SchedulerJobStatus.completed:
                status = JobStatus.completed
            case SchedulerJobStatus.failed:
                status = JobStatus.failed
            case _:
                raise NotImplementedError()

        return PostTrainingJobStatusResponse(
            job_uuid=job_uuid,
            status=status,
            scheduled_at=job.scheduled_at,
            started_at=job.started_at,
            completed_at=job.completed_at,
            checkpoints=self._get_checkpoints(job),
            resources_allocated=self._get_resources_allocated(job),
        )

    @webmethod(route="/post-training/job/cancel")
    async def cancel_training_job(self, job_uuid: str) -> None:
        self._scheduler.cancel(job_uuid)

    @webmethod(route="/post-training/job/artifacts")
    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
        job = self._scheduler.get_job(job_uuid)
        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

    # Note: pause/resume/cancel are achieved as follows:
    # - POST with status=paused
    # - POST with status=resuming
    # - POST with status=cancelled
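Note: the closing comment implies that pause/resume/cancel will eventually be expressed by POSTing a job with an updated status. No such status field appears on PostTrainingJob in this diff, so the following is purely a sketch of the intended call pattern, not the implemented behavior:

# Purely illustrative; assumes a future `status` field on PostTrainingJob
# and an HTTP client pointed at the stack. Neither is defined in this diff.
import httpx

def cancel_job(base_url: str, job_id: str) -> None:
    # POST to the jobs resource with status=cancelled, per the note above.
    resp = httpx.post(
        f"{base_url}/post-training/jobs/{job_id}",
        json={"job_uuid": job_id, "status": "cancelled"},
    )
    resp.raise_for_status()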
@@ -33,11 +33,11 @@ from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm

from llama_stack.apis.common.job_types import JobArtifact
from llama_stack.apis.common.training_types import PostTrainingMetric
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
    Checkpoint,
    DataConfig,
    EfficiencyConfig,
    LoraFinetuningConfig,
@@ -457,7 +457,7 @@ class LoraFinetuningSingleDevice:
        return loss

    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
    async def train(self) -> Tuple[Dict[str, Any], List[JobArtifact]]:
        """
        The core training loop.
        """
@@ -543,13 +543,18 @@ class LoraFinetuningSingleDevice:
            self.epochs_run += 1
            log.info("Starting checkpoint save...")
            checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
            checkpoint = Checkpoint(
                identifier=f"{self.model_id}-sft-{curr_epoch}",
                created_at=datetime.now(timezone.utc),
                epoch=curr_epoch,
                post_training_job_id=self.job_uuid,
                path=checkpoint_path,

            checkpoint = JobArtifact(
                name=f"{self.model_id}-sft-{curr_epoch}",
                type="checkpoint",
                # TODO: this should be exposed via /files instead
                uri=checkpoint_path,
            )

            metadata = {
                "created_at": datetime.now(timezone.utc),
                "epoch": curr_epoch,
            }
            if self.training_config.data_config.validation_dataset_id:
                validation_loss, perplexity = await self.validation()
                training_metrics = PostTrainingMetric(
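Note: with this change, details that used to be first-class Checkpoint fields (epoch, created_at, and below, training metrics) travel in the artifact's free-form metadata. A small sketch of reading them back, assuming the metadata dict shape constructed above:

# Hypothetical consumer of a checkpoint artifact produced by the loop above.
def describe_checkpoint(artifact) -> str:
    epoch = artifact.metadata.get("epoch")
    created_at = artifact.metadata.get("created_at")
    return f"{artifact.name}: epoch={epoch}, created_at={created_at}, uri={artifact.uri}"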
@@ -558,7 +563,9 @@ class LoraFinetuningSingleDevice:
                    validation_loss=validation_loss,
                    perplexity=perplexity,
                )
                checkpoint.training_metrics = training_metrics
                metadata["training_metrics"] = training_metrics
                checkpoint.metadata = metadata

            checkpoints.append(checkpoint)

        # clean up the memory after training finishes
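Note: validation metrics likewise move from a dedicated field to a metadata entry. A hedged sketch of filtering checkpoint artifacts and pulling the metrics back out, using only the type string and metadata keys that appear in this diff:

# Hypothetical helper; assumes a list of JobArtifact objects with the
# "checkpoint" type and the metadata layout built in the training loop above.
from typing import Any, Dict, List

def checkpoint_metrics(artifacts: List[Any]) -> List[Dict[str, Any]]:
    return [
        a.metadata.get("training_metrics")
        for a in artifacts
        if a.type == "checkpoint" and "training_metrics" in a.metadata
    ]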