Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
feat: Implement async job execution for torchtune training (#1437)
# What does this PR do?

A separate thread is now started to execute training jobs. Training requests return the job ID before the job completes, which fixes API timeouts for any job that takes longer than a minute.

Note: the scheduler code is meant to be spun out in the future into a common provider service that can be reused for different APIs and providers. It is also expected to back the /jobs API proposed here: https://github.com/meta-llama/llama-stack/discussions/1238. Hence its somewhat generalized form, which should simplify its adoption elsewhere.

Note: this patch doesn't attempt to implement the missing APIs (e.g. cancel or job removal). That work belongs to follow-up PRs.

## Test Plan

Added unit tests for the scheduler module. For the API coverage, did manual testing and was able to run a training cycle on GPU. The initial call returned the job ID before the training completed, as (now) expected. Artifacts are returned as expected:

```
JobArtifactsResponse(checkpoints=[{'identifier': 'meta-llama/Llama-3.2-3B-Instruct-sft-0', 'created_at': '2025-03-07T22:45:19.892714', 'epoch': 0, 'post_training_job_id': 'test-job2ee77104-2fd3-4a4e-84cf-f83f8b8f1f50', 'path': '/home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0', 'training_metrics': None}], job_uuid='test-job2ee77104-2fd3-4a4e-84cf-f83f8b8f1f50')
```

The integration test is currently disabled for the provider. I will look into how it can be enabled in a different PR / issue context.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
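To make the behavior change concrete, here is a minimal sketch of the new fire-and-forget flow, built directly on the `Scheduler` added in this patch. The `train` handler, job names, and the polling loop are illustrative stand-ins, not part of the change itself:

```python
import asyncio

from llama_stack.providers.utils.scheduler import JobStatus, Scheduler


async def main():
    sched = Scheduler()

    # stand-in for the real training handler (recipe.setup() + recipe.train())
    async def train(on_log, on_status, on_artifact):
        on_log("pretend this is a long training run")
        await asyncio.sleep(1)
        on_status(JobStatus.completed)

    # schedule() returns the job ID immediately; the handler runs on the
    # scheduler's background thread, so the caller is never blocked
    job_id = sched.schedule("supervised-fine-tune", "my-job", train)

    # callers poll for a terminal status instead of waiting on the request
    while sched.get_job(job_id).status not in (JobStatus.completed, JobStatus.failed):
        await asyncio.sleep(0.2)

    await sched.shutdown()


asyncio.run(main())
```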
This commit is contained in:
parent: 7641a5cd0b
commit: 3ed4316ed5

3 changed files with 472 additions and 39 deletions
llama_stack/providers/inline/post_training/torchtune/post_training.py

```diff
@@ -3,13 +3,14 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from datetime import datetime, timezone
+from enum import Enum
 from typing import Any, Dict, Optional
 
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     AlgorithmConfig,
+    Checkpoint,
     DPOAlignmentConfig,
     JobStatus,
     ListPostTrainingJobsResponse,
@@ -25,9 +26,19 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
 from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
     LoraFinetuningSingleDevice,
 )
+from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
+from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack.schema_utils import webmethod
 
 
+class TrainingArtifactType(Enum):
+    CHECKPOINT = "checkpoint"
+    RESOURCES_STATS = "resources_stats"
+
+
+_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
+
+
 class TorchtunePostTrainingImpl:
     def __init__(
         self,
@@ -38,13 +49,27 @@ class TorchtunePostTrainingImpl:
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets
+        self._scheduler = Scheduler()
 
-        # TODO: assume sync job, will need jobs API for async scheduling
-        self.jobs = {}
-        self.checkpoints_dict = {}
-
-    async def shutdown(self):
-        pass
+    async def shutdown(self) -> None:
+        await self._scheduler.shutdown()
+
+    @staticmethod
+    def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
+        return JobArtifact(
+            type=TrainingArtifactType.CHECKPOINT.value,
+            name=checkpoint.identifier,
+            uri=checkpoint.path,
+            metadata=dict(checkpoint),
+        )
+
+    @staticmethod
+    def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
+        return JobArtifact(
+            type=TrainingArtifactType.RESOURCES_STATS.value,
+            name=TrainingArtifactType.RESOURCES_STATS.value,
+            metadata=resources_stats,
+        )
 
     async def supervised_fine_tune(
         self,
@@ -56,20 +81,11 @@
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig],
     ) -> PostTrainingJob:
-        if job_uuid in self.jobs:
-            raise ValueError(f"Job {job_uuid} already exists")
-
-        post_training_job = PostTrainingJob(job_uuid=job_uuid)
-
-        job_status_response = PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
-            status=JobStatus.scheduled,
-            scheduled_at=datetime.now(timezone.utc),
-        )
-        self.jobs[job_uuid] = job_status_response
-
         if isinstance(algorithm_config, LoraFinetuningConfig):
-            try:
+
+            async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
+                on_log_message_cb("Starting Lora finetuning")
+
                 recipe = LoraFinetuningSingleDevice(
                     self.config,
                     job_uuid,
@@ -82,26 +98,22 @@
                     self.datasetio_api,
                     self.datasets_api,
                 )
-
-                job_status_response.status = JobStatus.in_progress
-                job_status_response.started_at = datetime.now(timezone.utc)
 
                 await recipe.setup()
 
                 resources_allocated, checkpoints = await recipe.train()
 
-                self.checkpoints_dict[job_uuid] = checkpoints
-                job_status_response.resources_allocated = resources_allocated
-                job_status_response.checkpoints = checkpoints
-                job_status_response.status = JobStatus.completed
-                job_status_response.completed_at = datetime.now(timezone.utc)
-
-            except Exception:
-                job_status_response.status = JobStatus.failed
-                raise
+                on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
+                for checkpoint in checkpoints:
+                    artifact = self._checkpoint_to_artifact(checkpoint)
+                    on_artifact_collected_cb(artifact)
+
+                on_status_change_cb(SchedulerJobStatus.completed)
+                on_log_message_cb("Lora finetuning completed")
         else:
             raise NotImplementedError()
 
-        return post_training_job
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        return PostTrainingJob(job_uuid=job_uuid)
 
     async def preference_optimize(
         self,
@@ -114,19 +126,55 @@
     ) -> PostTrainingJob: ...
 
     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
-        return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
+        return ListPostTrainingJobsResponse(
+            data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
+        )
+
+    @staticmethod
+    def _get_artifacts_metadata_by_type(job, artifact_type):
+        return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
+
+    @classmethod
+    def _get_checkpoints(cls, job):
+        return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
+
+    @classmethod
+    def _get_resources_allocated(cls, job):
+        data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
+        return data[0] if data else None
 
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
-        return self.jobs.get(job_uuid, None)
+        job = self._scheduler.get_job(job_uuid)
+
+        match job.status:
+            # TODO: Add support for other statuses to API
+            case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
+                status = JobStatus.scheduled
+            case SchedulerJobStatus.running:
+                status = JobStatus.in_progress
+            case SchedulerJobStatus.completed:
+                status = JobStatus.completed
+            case SchedulerJobStatus.failed:
+                status = JobStatus.failed
+            case _:
+                raise NotImplementedError()
+
+        return PostTrainingJobStatusResponse(
+            job_uuid=job_uuid,
+            status=status,
+            scheduled_at=job.scheduled_at,
+            started_at=job.started_at,
+            completed_at=job.completed_at,
+            checkpoints=self._get_checkpoints(job),
+            resources_allocated=self._get_resources_allocated(job),
+        )
 
     @webmethod(route="/post-training/job/cancel")
     async def cancel_training_job(self, job_uuid: str) -> None:
-        raise NotImplementedError("Job cancel is not implemented yet")
+        self._scheduler.cancel(job_uuid)
 
     @webmethod(route="/post-training/job/artifacts")
     async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
-        if job_uuid in self.checkpoints_dict:
-            checkpoints = self.checkpoints_dict.get(job_uuid, [])
-            return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
-        return None
+        job = self._scheduler.get_job(job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
```
llama_stack/providers/utils/scheduler.py (new file, 265 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import abc
import asyncio
import functools
import threading
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Callable, Coroutine, Dict, Iterable, Tuple, TypeAlias

from pydantic import BaseModel

from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="scheduler")


# TODO: revisit the list of possible statuses when defining a more coherent
# Jobs API for all API flows; e.g. do we need new vs scheduled?
class JobStatus(Enum):
    new = "new"
    scheduled = "scheduled"
    running = "running"
    failed = "failed"
    completed = "completed"


JobID: TypeAlias = str
JobType: TypeAlias = str


class JobArtifact(BaseModel):
    type: JobType
    name: str
    # TODO: uri should be a reference to /files API; revisit when /files is implemented
    uri: str | None = None
    metadata: Dict[str, Any]


JobHandler = Callable[
    [Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
]


LogMessage: TypeAlias = Tuple[datetime, str]


_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}


class Job:
    def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
        super().__init__()
        self.id = job_id
        self._type = job_type
        self._handler = handler
        self._artifacts: list[JobArtifact] = []
        self._logs: list[LogMessage] = []
        self._state_transitions: list[Tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]

    @property
    def handler(self) -> JobHandler:
        return self._handler

    @property
    def status(self) -> JobStatus:
        return self._state_transitions[-1][1]

    @status.setter
    def status(self, status: JobStatus):
        if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
            raise ValueError(f"Job is already in a completed state ({self.status})")
        if self.status == status:
            return
        self._state_transitions.append((datetime.now(timezone.utc), status))

    @property
    def artifacts(self) -> list[JobArtifact]:
        return self._artifacts

    def register_artifact(self, artifact: JobArtifact) -> None:
        self._artifacts.append(artifact)

    def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
        for date, s in reversed(self._state_transitions):
            if s in status:
                return date
        return None

    @property
    def scheduled_at(self) -> datetime | None:
        return self._find_state_transition_date([JobStatus.scheduled])

    @property
    def started_at(self) -> datetime | None:
        return self._find_state_transition_date([JobStatus.running])

    @property
    def completed_at(self) -> datetime | None:
        return self._find_state_transition_date(_COMPLETED_STATUSES)

    @property
    def logs(self) -> list[LogMessage]:
        return self._logs[:]

    def append_log(self, message: LogMessage) -> None:
        self._logs.append(message)

    # TODO: implement
    def cancel(self) -> None:
        raise NotImplementedError


class _SchedulerBackend(abc.ABC):
    @abc.abstractmethod
    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    async def shutdown(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        raise NotImplementedError


class _NaiveSchedulerBackend(_SchedulerBackend):
    def __init__(self, timeout: int = 5):
        self._timeout = timeout
        self._loop = asyncio.new_event_loop()
        # There may be performance implications of using threads due to Python
        # GIL; may need to measure if it's a real problem though
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

    def _run_loop(self) -> None:
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

        # When stopping the loop, give tasks a chance to finish
        # TODO: should we explicitly inform jobs of pending stoppage?
        for task in asyncio.all_tasks(self._loop):
            self._loop.run_until_complete(task)
        self._loop.close()

    async def shutdown(self) -> None:
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()

    # TODO: decouple scheduling and running the job
    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        async def do():
            try:
                job.status = JobStatus.running
                await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
            except Exception as e:
                on_log_message_cb(str(e))
                job.status = JobStatus.failed
                logger.exception(f"Job {job.id} failed.")

        asyncio.run_coroutine_threadsafe(do(), self._loop)

    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        pass

    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        pass

    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        pass


_BACKENDS = {
    "naive": _NaiveSchedulerBackend,
}


def _get_backend_impl(backend: str) -> _SchedulerBackend:
    try:
        return _BACKENDS[backend]()
    except KeyError as e:
        raise ValueError(f"Unknown backend {backend}") from e


class Scheduler:
    def __init__(self, backend: str = "naive"):
        # TODO: if server crashes, job states are lost; we need to persist jobs on disc
        self._jobs: dict[JobID, Job] = {}
        self._backend = _get_backend_impl(backend)

    def _on_log_message_cb(self, job: Job, message: str) -> None:
        msg = (datetime.now(timezone.utc), message)
        # At least for the time being, until there's a better way to expose
        # logs to users, log messages on console
        logger.info(f"Job {job.id}: {message}")
        job.append_log(msg)
        self._backend.on_log_message_cb(job, msg)

    def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        job.status = status
        self._backend.on_status_change_cb(job, status)

    def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        job.register_artifact(artifact)
        self._backend.on_artifact_collected_cb(job, artifact)

    def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
        job = Job(type_, job_id, handler)
        if job.id in self._jobs:
            raise ValueError(f"Job {job.id} already exists")

        self._jobs[job.id] = job
        job.status = JobStatus.scheduled
        self._backend.schedule(
            job,
            functools.partial(self._on_log_message_cb, job),
            functools.partial(self._on_status_change_cb, job),
            functools.partial(self._on_artifact_collected_cb, job),
        )

        return job.id

    def cancel(self, job_id: JobID) -> None:
        self.get_job(job_id).cancel()

    def get_job(self, job_id: JobID) -> Job:
        try:
            return self._jobs[job_id]
        except KeyError as e:
            raise ValueError(f"Job {job_id} not found") from e

    def get_jobs(self, type_: JobType | None = None) -> list[Job]:
        jobs = list(self._jobs.values())
        if type_:
            jobs = [job for job in jobs if job._type == type_]
        return jobs

    async def shutdown(self):
        # TODO: also cancel jobs once implemented
        await self._backend.shutdown()
```
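The `_BACKENDS` registry above is the extension point for the common provider service mentioned in the PR description. A hedged sketch of how an alternative backend could be plugged in, assuming the module-private names remain importable; `ProcessPoolSchedulerBackend` is hypothetical and not part of this patch:

```python
# Hypothetical sketch only: registering a second backend in the _BACKENDS
# registry. Only the "naive" backend ships in this patch.
from llama_stack.providers.utils.scheduler import (
    _BACKENDS,
    Scheduler,
    _NaiveSchedulerBackend,
)


class ProcessPoolSchedulerBackend(_NaiveSchedulerBackend):
    """A real implementation would override schedule() to dispatch jobs to
    worker processes instead of the in-process event loop."""


_BACKENDS["process-pool"] = ProcessPoolSchedulerBackend

# after registration, providers can opt in by name
sched = Scheduler(backend="process-pool")
```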
tests/unit/providers/utils/test_scheduler.py (new file, 120 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio

import pytest

from llama_stack.providers.utils.scheduler import JobStatus, Scheduler


@pytest.mark.asyncio
async def test_scheduler_unknown_backend():
    with pytest.raises(ValueError):
        Scheduler(backend="unknown")


@pytest.mark.asyncio
async def test_scheduler_naive():
    sched = Scheduler()

    # make sure the scheduler starts empty
    with pytest.raises(ValueError):
        sched.get_job("unknown")
    assert sched.get_jobs() == []

    called = False

    # schedule a job that will exercise the handlers
    async def job_handler(on_log, on_status, on_artifact):
        nonlocal called
        called = True
        # exercise the handlers
        on_log("test log1")
        on_log("test log2")
        on_artifact({"type": "type1", "path": "path1"})
        on_artifact({"type": "type2", "path": "path2"})
        on_status(JobStatus.completed)

    job_id = "test_job_id"
    job_type = "test_job_type"
    sched.schedule(job_type, job_id, job_handler)

    # make sure the job was properly registered
    with pytest.raises(ValueError):
        sched.get_job("unknown")
    assert sched.get_job(job_id) is not None
    assert sched.get_jobs() == [sched.get_job(job_id)]

    assert sched.get_jobs("unknown") == []
    assert sched.get_jobs(job_type) == [sched.get_job(job_id)]

    # now shut the scheduler down and make sure the job ran
    await sched.shutdown()

    assert called

    job = sched.get_job(job_id)
    assert job is not None

    assert job.status == JobStatus.completed

    assert job.scheduled_at is not None
    assert job.started_at is not None
    assert job.completed_at is not None
    assert job.scheduled_at < job.started_at < job.completed_at

    assert job.artifacts == [
        {"type": "type1", "path": "path1"},
        {"type": "type2", "path": "path2"},
    ]
    assert [msg[1] for msg in job.logs] == ["test log1", "test log2"]
    assert job.logs[0][0] < job.logs[1][0]


@pytest.mark.asyncio
async def test_scheduler_naive_handler_raises():
    sched = Scheduler()

    async def failing_job_handler(on_log, on_status, on_artifact):
        on_status(JobStatus.running)
        raise ValueError("test error")

    job_id = "test_job_id1"
    job_type = "test_job_type"
    sched.schedule(job_type, job_id, failing_job_handler)

    job = sched.get_job(job_id)
    assert job is not None

    # confirm the exception made the job transition to failed state, even
    # though it was set to `running` before the error
    for _ in range(10):
        if job.status == JobStatus.failed:
            break
        await asyncio.sleep(0.1)
    assert job.status == JobStatus.failed

    # confirm that the raised error got registered in log
    assert job.logs[0][1] == "test error"

    # even after failed job, we can schedule another one
    called = False

    async def successful_job_handler(on_log, on_status, on_artifact):
        nonlocal called
        called = True
        on_status(JobStatus.completed)

    job_id = "test_job_id2"
    sched.schedule(job_type, job_id, successful_job_handler)

    await sched.shutdown()

    assert called
    job = sched.get_job(job_id)
    assert job is not None
    assert job.status == JobStatus.completed
```