diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py index 2c129ef41..cc1a6a5fe 100644 --- a/llama_stack/providers/inline/post_training/torchtune/post_training.py +++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py @@ -3,13 +3,14 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from datetime import datetime, timezone +from enum import Enum from typing import Any, Dict, Optional from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.post_training import ( AlgorithmConfig, + Checkpoint, DPOAlignmentConfig, JobStatus, ListPostTrainingJobsResponse, @@ -25,9 +26,19 @@ from llama_stack.providers.inline.post_training.torchtune.config import ( from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import ( LoraFinetuningSingleDevice, ) +from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler +from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus from llama_stack.schema_utils import webmethod +class TrainingArtifactType(Enum): + CHECKPOINT = "checkpoint" + RESOURCES_STATS = "resources_stats" + + +_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune" + + class TorchtunePostTrainingImpl: def __init__( self, @@ -38,13 +49,27 @@ class TorchtunePostTrainingImpl: self.config = config self.datasetio_api = datasetio_api self.datasets_api = datasets + self._scheduler = Scheduler() - # TODO: assume sync job, will need jobs API for async scheduling - self.jobs = {} - self.checkpoints_dict = {} + async def shutdown(self) -> None: + await self._scheduler.shutdown() - async def shutdown(self): - pass + @staticmethod + def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact: + return JobArtifact( + type=TrainingArtifactType.CHECKPOINT.value, + name=checkpoint.identifier, + uri=checkpoint.path, + metadata=dict(checkpoint), + ) + + @staticmethod + def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact: + return JobArtifact( + type=TrainingArtifactType.RESOURCES_STATS.value, + name=TrainingArtifactType.RESOURCES_STATS.value, + metadata=resources_stats, + ) async def supervised_fine_tune( self, @@ -56,20 +81,11 @@ class TorchtunePostTrainingImpl: checkpoint_dir: Optional[str], algorithm_config: Optional[AlgorithmConfig], ) -> PostTrainingJob: - if job_uuid in self.jobs: - raise ValueError(f"Job {job_uuid} already exists") - - post_training_job = PostTrainingJob(job_uuid=job_uuid) - - job_status_response = PostTrainingJobStatusResponse( - job_uuid=job_uuid, - status=JobStatus.scheduled, - scheduled_at=datetime.now(timezone.utc), - ) - self.jobs[job_uuid] = job_status_response - if isinstance(algorithm_config, LoraFinetuningConfig): - try: + + async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb): + on_log_message_cb("Starting Lora finetuning") + recipe = LoraFinetuningSingleDevice( self.config, job_uuid, @@ -82,26 +98,22 @@ class TorchtunePostTrainingImpl: self.datasetio_api, self.datasets_api, ) - - job_status_response.status = JobStatus.in_progress - job_status_response.started_at = datetime.now(timezone.utc) - await recipe.setup() + resources_allocated, checkpoints = await recipe.train() - self.checkpoints_dict[job_uuid] = checkpoints - job_status_response.resources_allocated = resources_allocated - 
job_status_response.checkpoints = checkpoints - job_status_response.status = JobStatus.completed - job_status_response.completed_at = datetime.now(timezone.utc) + on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated)) + for checkpoint in checkpoints: + artifact = self._checkpoint_to_artifact(checkpoint) + on_artifact_collected_cb(artifact) - except Exception: - job_status_response.status = JobStatus.failed - raise + on_status_change_cb(SchedulerJobStatus.completed) + on_log_message_cb("Lora finetuning completed") else: raise NotImplementedError() - return post_training_job + job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler) + return PostTrainingJob(job_uuid=job_uuid) async def preference_optimize( self, @@ -114,19 +126,55 @@ class TorchtunePostTrainingImpl: ) -> PostTrainingJob: ... async def get_training_jobs(self) -> ListPostTrainingJobsResponse: - return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs]) + return ListPostTrainingJobsResponse( + data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()] + ) + + @staticmethod + def _get_artifacts_metadata_by_type(job, artifact_type): + return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type] + + @classmethod + def _get_checkpoints(cls, job): + return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value) + + @classmethod + def _get_resources_allocated(cls, job): + data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value) + return data[0] if data else None @webmethod(route="/post-training/job/status") async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: - return self.jobs.get(job_uuid, None) + job = self._scheduler.get_job(job_uuid) + + match job.status: + # TODO: Add support for other statuses to API + case SchedulerJobStatus.new | SchedulerJobStatus.scheduled: + status = JobStatus.scheduled + case SchedulerJobStatus.running: + status = JobStatus.in_progress + case SchedulerJobStatus.completed: + status = JobStatus.completed + case SchedulerJobStatus.failed: + status = JobStatus.failed + case _: + raise NotImplementedError() + + return PostTrainingJobStatusResponse( + job_uuid=job_uuid, + status=status, + scheduled_at=job.scheduled_at, + started_at=job.started_at, + completed_at=job.completed_at, + checkpoints=self._get_checkpoints(job), + resources_allocated=self._get_resources_allocated(job), + ) @webmethod(route="/post-training/job/cancel") async def cancel_training_job(self, job_uuid: str) -> None: - raise NotImplementedError("Job cancel is not implemented yet") + self._scheduler.cancel(job_uuid) @webmethod(route="/post-training/job/artifacts") async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: - if job_uuid in self.checkpoints_dict: - checkpoints = self.checkpoints_dict.get(job_uuid, []) - return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints) - return None + job = self._scheduler.get_job(job_uuid) + return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job)) diff --git a/llama_stack/providers/utils/scheduler.py b/llama_stack/providers/utils/scheduler.py new file mode 100644 index 000000000..d4cffe605 --- /dev/null +++ b/llama_stack/providers/utils/scheduler.py @@ -0,0 +1,265 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import abc +import asyncio +import functools +import threading +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Callable, Coroutine, Dict, Iterable, Tuple, TypeAlias + +from pydantic import BaseModel + +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="scheduler") + + +# TODO: revisit the list of possible statuses when defining a more coherent +# Jobs API for all API flows; e.g. do we need new vs scheduled? +class JobStatus(Enum): + new = "new" + scheduled = "scheduled" + running = "running" + failed = "failed" + completed = "completed" + + +JobID: TypeAlias = str +JobType: TypeAlias = str + + +class JobArtifact(BaseModel): + type: JobType + name: str + # TODO: uri should be a reference to /files API; revisit when /files is implemented + uri: str | None = None + metadata: Dict[str, Any] + + +JobHandler = Callable[ + [Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None] +] + + +LogMessage: TypeAlias = Tuple[datetime, str] + + +_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed} + + +class Job: + def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler): + super().__init__() + self.id = job_id + self._type = job_type + self._handler = handler + self._artifacts: list[JobArtifact] = [] + self._logs: list[LogMessage] = [] + self._state_transitions: list[Tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)] + + @property + def handler(self) -> JobHandler: + return self._handler + + @property + def status(self) -> JobStatus: + return self._state_transitions[-1][1] + + @status.setter + def status(self, status: JobStatus): + if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES: + raise ValueError(f"Job is already in a completed state ({self.status})") + if self.status == status: + return + self._state_transitions.append((datetime.now(timezone.utc), status)) + + @property + def artifacts(self) -> list[JobArtifact]: + return self._artifacts + + def register_artifact(self, artifact: JobArtifact) -> None: + self._artifacts.append(artifact) + + def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None: + for date, s in reversed(self._state_transitions): + if s in status: + return date + return None + + @property + def scheduled_at(self) -> datetime | None: + return self._find_state_transition_date([JobStatus.scheduled]) + + @property + def started_at(self) -> datetime | None: + return self._find_state_transition_date([JobStatus.running]) + + @property + def completed_at(self) -> datetime | None: + return self._find_state_transition_date(_COMPLETED_STATUSES) + + @property + def logs(self) -> list[LogMessage]: + return self._logs[:] + + def append_log(self, message: LogMessage) -> None: + self._logs.append(message) + + # TODO: implement + def cancel(self) -> None: + raise NotImplementedError + + +class _SchedulerBackend(abc.ABC): + @abc.abstractmethod + def on_log_message_cb(self, job: Job, message: LogMessage) -> None: + raise NotImplementedError + + @abc.abstractmethod + def on_status_change_cb(self, job: Job, status: JobStatus) -> None: + raise NotImplementedError + + @abc.abstractmethod + def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None: + raise NotImplementedError + + @abc.abstractmethod + async def shutdown(self) -> 
None:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def schedule(
+        self,
+        job: Job,
+        on_log_message_cb: Callable[[str], None],
+        on_status_change_cb: Callable[[JobStatus], None],
+        on_artifact_collected_cb: Callable[[JobArtifact], None],
+    ) -> None:
+        raise NotImplementedError
+
+
+class _NaiveSchedulerBackend(_SchedulerBackend):
+    def __init__(self, timeout: int = 5):
+        self._timeout = timeout
+        self._loop = asyncio.new_event_loop()
+        # There may be performance implications of using threads due to the
+        # Python GIL; measure if it ever becomes a real problem
+        self._thread = threading.Thread(target=self._run_loop, daemon=True)
+        self._thread.start()
+
+    def _run_loop(self) -> None:
+        asyncio.set_event_loop(self._loop)
+        self._loop.run_forever()
+
+        # When stopping the loop, give tasks a chance to finish
+        # TODO: should we explicitly inform jobs of pending stoppage?
+        for task in asyncio.all_tasks(self._loop):
+            self._loop.run_until_complete(task)
+        self._loop.close()
+
+    async def shutdown(self) -> None:
+        self._loop.call_soon_threadsafe(self._loop.stop)
+        self._thread.join()
+
+    # TODO: decouple scheduling and running the job
+    def schedule(
+        self,
+        job: Job,
+        on_log_message_cb: Callable[[str], None],
+        on_status_change_cb: Callable[[JobStatus], None],
+        on_artifact_collected_cb: Callable[[JobArtifact], None],
+    ) -> None:
+        async def do():
+            try:
+                job.status = JobStatus.running
+                await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
+            except Exception as e:
+                on_log_message_cb(str(e))
+                job.status = JobStatus.failed
+                logger.exception(f"Job {job.id} failed.")
+
+        asyncio.run_coroutine_threadsafe(do(), self._loop)
+
+    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
+        pass
+
+    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
+        pass
+
+    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
+        pass
+
+
+_BACKENDS = {
+    "naive": _NaiveSchedulerBackend,
+}
+
+
+def _get_backend_impl(backend: str) -> _SchedulerBackend:
+    try:
+        return _BACKENDS[backend]()
+    except KeyError as e:
+        raise ValueError(f"Unknown backend {backend}") from e
+
+
+class Scheduler:
+    def __init__(self, backend: str = "naive"):
+        # TODO: if server crashes, job states are lost; we need to persist jobs on disk
+        self._jobs: dict[JobID, Job] = {}
+        self._backend = _get_backend_impl(backend)
+
+    def _on_log_message_cb(self, job: Job, message: str) -> None:
+        msg = (datetime.now(timezone.utc), message)
+        # Until there is a better way to expose logs to users, also write
+        # log messages to the console
+        logger.info(f"Job {job.id}: {message}")
+        job.append_log(msg)
+        self._backend.on_log_message_cb(job, msg)
+
+    def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
+        job.status = status
+        self._backend.on_status_change_cb(job, status)
+
+    def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
+        job.register_artifact(artifact)
+        self._backend.on_artifact_collected_cb(job, artifact)
+
+    def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
+        job = Job(type_, job_id, handler)
+        if job.id in self._jobs:
+            raise ValueError(f"Job {job.id} already exists")
+
+        self._jobs[job.id] = job
+        job.status = JobStatus.scheduled
+        self._backend.schedule(
+            job,
+            functools.partial(self._on_log_message_cb, job),
+            functools.partial(self._on_status_change_cb, job),
+            functools.partial(self._on_artifact_collected_cb, job),
+        )
+
+        return job.id
+
+    def 
cancel(self, job_id: JobID) -> None: + self.get_job(job_id).cancel() + + def get_job(self, job_id: JobID) -> Job: + try: + return self._jobs[job_id] + except KeyError as e: + raise ValueError(f"Job {job_id} not found") from e + + def get_jobs(self, type_: JobType | None = None) -> list[Job]: + jobs = list(self._jobs.values()) + if type_: + jobs = [job for job in jobs if job._type == type_] + return jobs + + async def shutdown(self): + # TODO: also cancel jobs once implemented + await self._backend.shutdown() diff --git a/tests/unit/providers/utils/test_scheduler.py b/tests/unit/providers/utils/test_scheduler.py new file mode 100644 index 000000000..76f0da8ce --- /dev/null +++ b/tests/unit/providers/utils/test_scheduler.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio + +import pytest + +from llama_stack.providers.utils.scheduler import JobStatus, Scheduler + + +@pytest.mark.asyncio +async def test_scheduler_unknown_backend(): + with pytest.raises(ValueError): + Scheduler(backend="unknown") + + +@pytest.mark.asyncio +async def test_scheduler_naive(): + sched = Scheduler() + + # make sure the scheduler starts empty + with pytest.raises(ValueError): + sched.get_job("unknown") + assert sched.get_jobs() == [] + + called = False + + # schedule a job that will exercise the handlers + async def job_handler(on_log, on_status, on_artifact): + nonlocal called + called = True + # exercise the handlers + on_log("test log1") + on_log("test log2") + on_artifact({"type": "type1", "path": "path1"}) + on_artifact({"type": "type2", "path": "path2"}) + on_status(JobStatus.completed) + + job_id = "test_job_id" + job_type = "test_job_type" + sched.schedule(job_type, job_id, job_handler) + + # make sure the job was properly registered + with pytest.raises(ValueError): + sched.get_job("unknown") + assert sched.get_job(job_id) is not None + assert sched.get_jobs() == [sched.get_job(job_id)] + + assert sched.get_jobs("unknown") == [] + assert sched.get_jobs(job_type) == [sched.get_job(job_id)] + + # now shut the scheduler down and make sure the job ran + await sched.shutdown() + + assert called + + job = sched.get_job(job_id) + assert job is not None + + assert job.status == JobStatus.completed + + assert job.scheduled_at is not None + assert job.started_at is not None + assert job.completed_at is not None + assert job.scheduled_at < job.started_at < job.completed_at + + assert job.artifacts == [ + {"type": "type1", "path": "path1"}, + {"type": "type2", "path": "path2"}, + ] + assert [msg[1] for msg in job.logs] == ["test log1", "test log2"] + assert job.logs[0][0] < job.logs[1][0] + + +@pytest.mark.asyncio +async def test_scheduler_naive_handler_raises(): + sched = Scheduler() + + async def failing_job_handler(on_log, on_status, on_artifact): + on_status(JobStatus.running) + raise ValueError("test error") + + job_id = "test_job_id1" + job_type = "test_job_type" + sched.schedule(job_type, job_id, failing_job_handler) + + job = sched.get_job(job_id) + assert job is not None + + # confirm the exception made the job transition to failed state, even + # though it was set to `running` before the error + for _ in range(10): + if job.status == JobStatus.failed: + break + await asyncio.sleep(0.1) + assert job.status == JobStatus.failed + + # confirm that the raised error got registered in log + assert job.logs[0][1] == "test error" 
+ + # even after failed job, we can schedule another one + called = False + + async def successful_job_handler(on_log, on_status, on_artifact): + nonlocal called + called = True + on_status(JobStatus.completed) + + job_id = "test_job_id2" + sched.schedule(job_type, job_id, successful_job_handler) + + await sched.shutdown() + + assert called + job = sched.get_job(job_id) + assert job is not None + assert job.status == JobStatus.completed
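
Usage note (not part of the patch): below is a minimal sketch of driving the new `Scheduler` end to end, mirroring what the torchtune provider does in the first hunk. The job type, job id, and artifact values are invented for illustration; everything imported comes from the new `llama_stack/providers/utils/scheduler.py`.

```python
import asyncio

from llama_stack.providers.utils.scheduler import JobArtifact, JobStatus, Scheduler


async def main() -> None:
    sched = Scheduler()  # defaults to the "naive" background-thread backend

    # A JobHandler is an async callable that receives three callbacks:
    # on_log_message_cb(str), on_status_change_cb(JobStatus),
    # on_artifact_collected_cb(JobArtifact).
    async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
        on_log_message_cb("doing some work")
        on_artifact_collected_cb(
            JobArtifact(type="demo", name="result", metadata={"answer": 42})
        )
        on_status_change_cb(JobStatus.completed)

    # Job type and id are caller-chosen strings (hypothetical values here).
    job_id = sched.schedule("demo-job", "job-1", handler)

    # The naive backend runs handlers on a background event-loop thread;
    # shutdown() stops that loop after letting pending tasks finish.
    await sched.shutdown()

    job = sched.get_job(job_id)
    assert job.status is JobStatus.completed
    print([a.name for a in job.artifacts], [msg for _, msg in job.logs])


if __name__ == "__main__":
    asyncio.run(main())
```

Statuses, logs, and artifacts accumulate on the in-memory `Job` object, so they survive until the process exits but not across restarts, per the persistence TODO in `Scheduler.__init__`.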
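Design note (not part of the patch): the backend seam is the five methods on `_SchedulerBackend`. As a sketch of what a second backend could look like, here is a hypothetical one that runs handlers as plain tasks on the caller's event loop instead of a dedicated thread. `_SchedulerBackend` and `_BACKENDS` are module-private today, so this is illustrative only, not a supported extension point.

```python
import asyncio
from typing import Callable

from llama_stack.providers.utils import scheduler as scheduler_mod
from llama_stack.providers.utils.scheduler import Job, JobArtifact, JobStatus, LogMessage


class _InlineSchedulerBackend(scheduler_mod._SchedulerBackend):
    """Hypothetical backend: runs each handler as a task on the current loop."""

    def __init__(self) -> None:
        self._tasks: list[asyncio.Task] = []

    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        async def do() -> None:
            try:
                job.status = JobStatus.running
                await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
            except Exception as e:
                on_log_message_cb(str(e))
                job.status = JobStatus.failed

        # Assumes schedule() is called from inside a running event loop,
        # as the async provider methods above do.
        self._tasks.append(asyncio.get_running_loop().create_task(do()))

    async def shutdown(self) -> None:
        # Wait for all jobs; failures were already recorded on the Job.
        await asyncio.gather(*self._tasks, return_exceptions=True)

    # Like the naive backend, treat the notification hooks as no-ops here.
    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        pass

    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        pass

    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        pass


# Registering under a new key would make it selectable as Scheduler(backend="inline"):
scheduler_mod._BACKENDS["inline"] = _InlineSchedulerBackend
```

A persistent backend would presumably use the three `on_*_cb` hooks to write logs, status transitions, and artifacts through to storage, which the in-memory naive backend currently skips.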