forked from phoenix-oss/llama-stack-mirror
		
	# What does this PR do? A separate thread is now started to execute training jobs, so training requests return the job ID before the job completes. (This fixes API timeouts for any job that takes longer than a minute.) Note: the scheduler code is meant to be spun out in the future into a common provider service that can be reused by different APIs and providers. It is also expected to back the /jobs API proposed here: https://github.com/meta-llama/llama-stack/discussions/1238. This explains its somewhat generalized form, which should simplify its adoption elsewhere in the future. Note: this patch doesn't attempt to implement missing APIs (e.g. cancel or job removal); that work is left for follow-up PRs. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] Added unit tests for the scheduler module. For API coverage, I tested manually and was able to run a training cycle on a GPU. The initial call returned the job ID before the training completed, as (now) expected. Artifacts are returned as expected. ``` JobArtifactsResponse(checkpoints=[{'identifier': 'meta-llama/Llama-3.2-3B-Instruct-sft-0', 'created_at': '2025-03-07T22:45:19.892714', 'epoch': 0, 'post_training_job_id': 'test-job2ee77104-2fd3-4a4e-84cf-f83f8b8f1f50', 'path': '/home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0', 'training_metrics': None}], job_uuid='test-job2ee77104-2fd3-4a4e-84cf-f83f8b8f1f50') ``` The integration test is currently disabled for the provider; I will look into how it can be enabled in a separate PR / issue. [//]: # (## Documentation) Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
		
			
				
	
	
		
			265 lines
		
	
	
	
		
			8.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			265 lines
		
	
	
	
		
			8.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import abc
 | |
| import asyncio
 | |
| import functools
 | |
| import threading
 | |
| from datetime import datetime, timezone
 | |
| from enum import Enum
 | |
| from typing import Any, Callable, Coroutine, Dict, Iterable, Tuple, TypeAlias
 | |
| 
 | |
| from pydantic import BaseModel
 | |
| 
 | |
| from llama_stack.log import get_logger
 | |
| 
 | |
# Module logger, obtained under the "scheduler" logging category.
logger = get_logger(category="scheduler", name=__name__)
 | |
| 
 | |
| 
 | |
# TODO: revisit the list of possible statuses when defining a more coherent
# Jobs API for all API flows; e.g. do we need new vs scheduled?
class JobStatus(Enum):
    """Lifecycle state of a job tracked by the scheduler.

    ``completed`` and ``failed`` are grouped together as the completion
    states; the other members describe a job that has not finished.
    """

    new = "new"  # Job object created, not yet accepted by the scheduler
    scheduled = "scheduled"  # accepted by the scheduler, waiting to run
    running = "running"  # handler is executing
    failed = "failed"  # handler raised or reported failure
    completed = "completed"  # handler reported successful completion
 | |
| 
 | |
| 
 | |
| JobID: TypeAlias = str
 | |
| JobType: TypeAlias = str
 | |
| 
 | |
| 
 | |
class JobArtifact(BaseModel):
    """A named output published by a job through its artifact callback."""

    type: JobType
    name: str
    # TODO: uri should be a reference to /files API; revisit when /files is implemented
    uri: str | None = None
    # Free-form extra information describing the artifact.
    metadata: dict[str, Any]
 | |
| 
 | |
| 
 | |
# Type of a job body: an async callable that is passed, in this order, a
# callback for emitting a log line, one for reporting a status change and one
# for publishing an artifact. The backend awaits the coroutine it returns.
JobHandler = Callable[
    [
        Callable[[str], None],  # log message
        Callable[[JobStatus], None],  # status change
        Callable[[JobArtifact], None],  # artifact collected
    ],
    Coroutine[Any, Any, None],
]
 | |
| 
 | |
| 
 | |
| LogMessage: TypeAlias = Tuple[datetime, str]
 | |
| 
 | |
| 
 | |
# Statuses that mark a job as finished; Job.status refuses to move from one of
# these into another. A frozenset keeps this shared module-level constant from
# being mutated by accident.
_COMPLETED_STATUSES = frozenset({JobStatus.completed, JobStatus.failed})
 | |
| 
 | |
| 
 | |
class Job:
    """In-memory record of one job: its handler, status history, logs and artifacts."""

    def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
        super().__init__()
        self.id = job_id
        self._type = job_type
        self._handler = handler
        self._artifacts: list[JobArtifact] = []
        self._logs: list[LogMessage] = []
        # Chronological (UTC timestamp, status) history; the last entry is the
        # current status. Every job starts out as ``new``.
        self._state_transitions: list[Tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]

    @property
    def handler(self) -> JobHandler:
        """Async callable that performs the job's work."""
        return self._handler

    @property
    def status(self) -> JobStatus:
        """Most recently recorded status."""
        return self._state_transitions[-1][1]

    @status.setter
    def status(self, status: JobStatus):
        """Record a transition to *status* (a no-op if it is already current).

        Raises:
            ValueError: if both the current and the requested status are
                completion states (``completed`` or ``failed``).
        """
        # Checked before the no-op shortcut, so reporting completion twice -
        # even with the same status - is treated as an error.
        # NOTE(review): moving from a completion state to a non-completion one
        # (e.g. completed -> running) is currently allowed; confirm intent.
        if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
            raise ValueError(f"Job is already in a completed state ({self.status})")
        if self.status == status:
            return
        self._state_transitions.append((datetime.now(timezone.utc), status))

    @property
    def artifacts(self) -> list[JobArtifact]:
        """Snapshot of the artifacts registered so far.

        Returns a copy, like ``logs``, so callers cannot mutate the job's
        internal list; use ``register_artifact`` to add one.
        """
        return self._artifacts[:]

    def register_artifact(self, artifact: JobArtifact) -> None:
        """Append *artifact* to the job's artifact list."""
        self._artifacts.append(artifact)

    def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
        """Return when the job most recently entered any of *status*, or None if it never did."""
        for date, s in reversed(self._state_transitions):
            if s in status:
                return date
        return None

    @property
    def scheduled_at(self) -> datetime | None:
        """Time of the latest transition to ``scheduled``, if any."""
        return self._find_state_transition_date([JobStatus.scheduled])

    @property
    def started_at(self) -> datetime | None:
        """Time of the latest transition to ``running``, if any."""
        return self._find_state_transition_date([JobStatus.running])

    @property
    def completed_at(self) -> datetime | None:
        """Time of the latest transition to ``completed`` or ``failed``, if any."""
        return self._find_state_transition_date(_COMPLETED_STATUSES)

    @property
    def logs(self) -> list[LogMessage]:
        """Snapshot (copy) of the log records collected so far."""
        return self._logs[:]

    def append_log(self, message: LogMessage) -> None:
        """Append a (timestamp, text) log record to the job."""
        self._logs.append(message)

    # TODO: implement
    def cancel(self) -> None:
        """Cancel the job. Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError
 | |
| 
 | |
| 
 | |
class _SchedulerBackend(abc.ABC):
    """Interface an execution backend implements to run jobs for a Scheduler."""

    @abc.abstractmethod
    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        """Arrange for *job* to run.

        The three callbacks are meant to be passed through to the job's
        handler so it can report logs, status changes and artifacts.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def shutdown(self) -> None:
        """Stop the backend."""
        raise NotImplementedError

    # Notification hooks: the Scheduler calls these after it has recorded the
    # corresponding event on the Job itself.

    @abc.abstractmethod
    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        """Called after *message* has been appended to *job*'s logs."""
        raise NotImplementedError

    @abc.abstractmethod
    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        """Called after *job*'s status has been set to *status*."""
        raise NotImplementedError

    @abc.abstractmethod
    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        """Called after *artifact* has been registered on *job*."""
        raise NotImplementedError
 | |
| 
 | |
| 
 | |
class _NaiveSchedulerBackend(_SchedulerBackend):
    """Backend that runs every job on a single asyncio loop in a daemon thread.

    All jobs share that one loop, so they only make progress concurrently at
    points where their handlers ``await``.
    """

    def __init__(self, timeout: int = 5):
        # NOTE(review): ``timeout`` is stored but not used by this backend yet.
        self._timeout = timeout
        self._loop = asyncio.new_event_loop()
        # There may be performance implications of using threads due to Python
        # GIL; may need to measure if it's a real problem though
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

    def _run_loop(self) -> None:
        """Thread target: run the loop until stopped, drain pending tasks, then close it."""
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

        # When stopping the loop, give tasks a chance to finish
        # TODO: should we explicitly inform jobs of pending stoppage?
        # return_exceptions=True keeps one failing task from aborting the
        # drain (and from skipping loop.close()).
        pending = asyncio.all_tasks(self._loop)
        if pending:
            self._loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
        self._loop.close()

    async def shutdown(self) -> None:
        """Stop the loop and wait until the worker thread has exited.

        Pending jobs are allowed to finish first (see ``_run_loop``), which may
        take a long time. The join therefore runs in a helper thread instead
        of blocking the caller's event loop.
        """
        self._loop.call_soon_threadsafe(self._loop.stop)
        await asyncio.to_thread(self._thread.join)

    # TODO: decouple scheduling and running the job
    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        """Submit *job* to the backend loop and return without waiting for it."""

        async def do():
            try:
                # Report through the callback rather than assigning job.status
                # directly, so the transition takes the same path as the ones
                # the handler reports.
                on_status_change_cb(JobStatus.running)
                await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
            except Exception as e:
                on_log_message_cb(str(e))
                logger.exception(f"Job {job.id} failed.")
                # The handler may already have reported a final status before
                # raising; setting ``failed`` on top of it would make the
                # status setter raise ValueError and hide the original error.
                if job.status not in _COMPLETED_STATUSES:
                    on_status_change_cb(JobStatus.failed)

        asyncio.run_coroutine_threadsafe(do(), self._loop)

    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        """No-op: this backend keeps no extra state about log messages."""
        pass

    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        """No-op: this backend keeps no extra state about status changes."""
        pass

    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        """No-op: this backend keeps no extra state about artifacts."""
        pass
 | |
| 
 | |
| 
 | |
# Scheduler backend implementations, keyed by the name accepted by
# Scheduler(backend=...).
_BACKENDS = dict(naive=_NaiveSchedulerBackend)
 | |
| 
 | |
| 
 | |
def _get_backend_impl(backend: str) -> _SchedulerBackend:
    """Instantiate the scheduler backend registered under *backend*.

    Raises:
        ValueError: if no backend is registered under that name.
    """
    try:
        backend_cls = _BACKENDS[backend]
    except KeyError as e:
        raise ValueError(f"Unknown backend {backend}") from e
    # Instantiate outside the try block, so a KeyError raised by the backend's
    # own constructor is not misreported as an unknown backend name.
    return backend_cls()
 | |
| 
 | |
| 
 | |
class Scheduler:
    """In-memory registry of jobs that hands each job to an execution backend."""

    def __init__(self, backend: str = "naive"):
        # TODO: if server crashes, job states are lost; we need to persist jobs on disk
        self._jobs: dict[JobID, Job] = {}
        self._backend = _get_backend_impl(backend)

    def _on_log_message_cb(self, job: Job, message: str) -> None:
        """Timestamp *message* (UTC), store it on *job* and notify the backend."""
        msg = (datetime.now(timezone.utc), message)
        # At least for the time being, until there's a better way to expose
        # logs to users, log messages on console
        logger.info(f"Job {job.id}: {message}")
        job.append_log(msg)
        self._backend.on_log_message_cb(job, msg)

    def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        """Record *status* on *job* and notify the backend."""
        job.status = status
        self._backend.on_status_change_cb(job, status)

    def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        """Register *artifact* on *job* and notify the backend."""
        job.register_artifact(artifact)
        self._backend.on_artifact_collected_cb(job, artifact)

    def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
        """Register a new job, mark it ``scheduled`` and submit it to the backend.

        Returns the job ID once the backend has accepted the job. The naive
        backend does not wait for the job to run.

        Raises:
            ValueError: if a job with *job_id* is already registered.
        """
        # Reject duplicates before building a Job object that would be discarded.
        if job_id in self._jobs:
            raise ValueError(f"Job {job_id} already exists")

        job = Job(type_, job_id, handler)
        self._jobs[job.id] = job
        job.status = JobStatus.scheduled
        # Bind each callback to this job so the handler only passes the payload.
        self._backend.schedule(
            job,
            functools.partial(self._on_log_message_cb, job),
            functools.partial(self._on_status_change_cb, job),
            functools.partial(self._on_artifact_collected_cb, job),
        )

        return job.id

    def cancel(self, job_id: JobID) -> None:
        """Cancel the job with *job_id* (delegates to Job.cancel)."""
        self.get_job(job_id).cancel()

    def get_job(self, job_id: JobID) -> Job:
        """Return the job registered under *job_id*.

        Raises:
            ValueError: if no such job exists.
        """
        try:
            return self._jobs[job_id]
        except KeyError as e:
            raise ValueError(f"Job {job_id} not found") from e

    def get_jobs(self, type_: JobType | None = None) -> list[Job]:
        """Return all jobs, or only those of type *type_* when it is given.

        Only ``None`` means "no filter"; any other value, including an empty
        string, is matched exactly.
        """
        if type_ is None:
            return list(self._jobs.values())
        return [job for job in self._jobs.values() if job._type == type_]

    async def shutdown(self):
        """Shut down the backend."""
        # TODO: also cancel jobs once implemented
        await self._backend.shutdown()
 |