Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
# What does this PR do?

Scheduler: cancel tasks on shutdown. Otherwise the currently running tasks never exit before they actually complete, which means the process can't be shut down properly (only with SIGKILL).

Ideally, we would let tasks know that they are about to be shut down and give them some time to do so; but in the absence of such a mechanism, it's better to cancel than to linger forever.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

Start a long-running task (e.g. torchtune or external kfp-provider training). Ctrl-C the process in the TTY. Confirm it exits in reasonable time.

```
^CINFO:     Shutting down
INFO:     Waiting for application shutdown.
13:32:26.187 - INFO - Shutting down
13:32:26.187 - INFO - Shutting down DatasetsRoutingTable
13:32:26.187 - INFO - Shutting down DatasetIORouter
13:32:26.187 - INFO - Shutting down TorchtuneKFPPostTrainingImpl
Traceback (most recent call last):
  File "/opt/homebrew/Cellar/python@3.12/3.12.4/Frameworks/Python.framework/Versions/3.12/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.12/3.12.4/Frameworks/Python.framework/Versions/3.12/lib/python3.12/asyncio/base_events.py", line 687, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
asyncio.exceptions.CancelledError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/dsl/executor_main.py", line 109, in <module>
    executor_main()
  File "/Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/dsl/executor_main.py", line 101, in executor_main
    output_file = executor.execute()
                  ^^^^^^^^^^^^^^^^^^
  File "/Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/dsl/executor.py", line 361, in execute
    result = self.func(**func_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/45/1q1rx6cn7jbcn2ty852w0g_r0000gn/T/tmp.RKpPrvTWDD/ephemeral_component.py", line 118, in component
    asyncio.run(recipe.setup())
  File "/opt/homebrew/Cellar/python@3.12/3.12.4/Frameworks/Python.framework/Versions/3.12/lib/python3.12/asyncio/runners.py", line 194, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.12/3.12.4/Frameworks/Python.framework/Versions/3.12/lib/python3.12/asyncio/runners.py", line 123, in run
    raise KeyboardInterrupt()
KeyboardInterrupt
13:32:31.219 - ERROR - Task 'component' finished with status FAILURE
--------------------------------------------------------------------------------
INFO     2025-05-09 13:32:31,221 llama_stack.providers.utils.scheduler:221 scheduler: Job test-jobc3c2e1e4-859c-4852-a41d-ef29e55e3efa: Pipeline 'test-jobc3c2e1e4-859c-4852-a41d-ef29e55e3efa' finished with status FAILURE. Inner task failed: 'component'.
ERROR    2025-05-09 13:32:31,223 llama_stack_provider_kfp_trainer.scheduler:54 scheduler: Job test-jobc3c2e1e4-859c-4852-a41d-ef29e55e3efa failed.
Traceback (most recent call last):
  /Users/ihrachys/src/llama-stack-provider-kfp-trainer/src/llama_stack_provider_kfp_trainer/scheduler.py:45 in do
    artifacts = self._to_artifacts(job.handler().output)
  /Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/dsl/base_component.py:101 in __call__
    return pipeline_task.PipelineTask(
  /Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/dsl/pipeline_task.py:187 in __init__
    self._execute_locally(args=args)
  /Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/dsl/pipeline_task.py:197 in _execute_locally
    self._outputs = pipeline_orchestrator.run_local_pipeline(
  /Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/local/pipeline_orchestrator.py:43 in run_local_pipeline
    return _run_local_pipeline_implementation(
  /Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/local/pipeline_orchestrator.py:108 in _run_local_pipeline_implementation
    log_and_maybe_raise_for_failure(
  /Users/ihrachys/src/llama-stack-provider-kfp-trainer/.venv/lib/python3.12/site-packages/kfp/local/pipeline_orchestrator.py:137 in log_and_maybe_raise_for_failure
    raise RuntimeError(msg)
RuntimeError: Pipeline 'test-jobc3c2e1e4-859c-4852-a41d-ef29e55e3efa' finished with status FAILURE. Inner task failed: 'component'.
INFO     2025-05-09 13:32:31,266 llama_stack.distribution.server.server:136 server: Shutting down DistributionInspectImpl
INFO     2025-05-09 13:32:31,266 llama_stack.distribution.server.server:136 server: Shutting down ProviderImpl
INFO:     Application shutdown complete.
INFO:     Finished server process [26648]
```

[//]: # (## Documentation)

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
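For illustration only, here is a minimal, self-contained sketch of the shutdown pattern this change applies to the naive scheduler backend (whose full source follows below). The names `LoopRunner` and `long_job` are made up for this sketch and are not llama-stack APIs.

```python
import asyncio
import threading


class LoopRunner:
    """Runs an asyncio event loop on a background thread (illustrative only)."""

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self) -> None:
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()
        # run_forever() returns once stop() was requested; cancel anything
        # still pending instead of leaving it to run to completion. Note the
        # cancelled coroutines never get a chance to handle CancelledError
        # here, since the loop does not run again before being closed.
        for task in asyncio.all_tasks(self._loop):
            if not task.done():
                task.cancel()
        self._loop.close()

    def submit(self, coro) -> None:
        asyncio.run_coroutine_threadsafe(coro, self._loop)

    def shutdown(self) -> None:
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()


async def long_job() -> None:
    await asyncio.sleep(3600)  # stands in for a long training run


if __name__ == "__main__":
    runner = LoopRunner()
    runner.submit(long_job())
    runner.shutdown()  # returns promptly; the pending sleep is cancelled, not awaited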
llama_stack/providers/utils/scheduler.py · 270 lines · 8.3 KiB · Python
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import abc
import asyncio
import functools
import threading
from collections.abc import Callable, Coroutine, Iterable
from datetime import UTC, datetime
from enum import Enum
from typing import Any, TypeAlias

from pydantic import BaseModel

from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="scheduler")


# TODO: revisit the list of possible statuses when defining a more coherent
# Jobs API for all API flows; e.g. do we need new vs scheduled?
class JobStatus(Enum):
    new = "new"
    scheduled = "scheduled"
    running = "running"
    failed = "failed"
    completed = "completed"


JobID: TypeAlias = str
JobType: TypeAlias = str


class JobArtifact(BaseModel):
    type: JobType
    name: str
    # TODO: uri should be a reference to /files API; revisit when /files is implemented
    uri: str | None = None
    metadata: dict[str, Any]


JobHandler = Callable[
    [Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
]


LogMessage: TypeAlias = tuple[datetime, str]


_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}


class Job:
    def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
        super().__init__()
        self.id = job_id
        self._type = job_type
        self._handler = handler
        self._artifacts: list[JobArtifact] = []
        self._logs: list[LogMessage] = []
        self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(UTC), JobStatus.new)]

    @property
    def handler(self) -> JobHandler:
        return self._handler

    @property
    def status(self) -> JobStatus:
        return self._state_transitions[-1][1]

    @status.setter
    def status(self, status: JobStatus):
        if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
            raise ValueError(f"Job is already in a completed state ({self.status})")
        if self.status == status:
            return
        self._state_transitions.append((datetime.now(UTC), status))

    @property
    def artifacts(self) -> list[JobArtifact]:
        return self._artifacts

    def register_artifact(self, artifact: JobArtifact) -> None:
        self._artifacts.append(artifact)

    def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
        for date, s in reversed(self._state_transitions):
            if s in status:
                return date
        return None

    @property
    def scheduled_at(self) -> datetime | None:
        return self._find_state_transition_date([JobStatus.scheduled])

    @property
    def started_at(self) -> datetime | None:
        return self._find_state_transition_date([JobStatus.running])

    @property
    def completed_at(self) -> datetime | None:
        return self._find_state_transition_date(_COMPLETED_STATUSES)

    @property
    def logs(self) -> list[LogMessage]:
        return self._logs[:]

    def append_log(self, message: LogMessage) -> None:
        self._logs.append(message)

    # TODO: implement
    def cancel(self) -> None:
        raise NotImplementedError


class _SchedulerBackend(abc.ABC):
    @abc.abstractmethod
    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    async def shutdown(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        raise NotImplementedError


class _NaiveSchedulerBackend(_SchedulerBackend):
    def __init__(self, timeout: int = 5):
        self._timeout = timeout
        self._loop = asyncio.new_event_loop()
        # There may be performance implications of using threads due to Python
        # GIL; may need to measure if it's a real problem though
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

    def _run_loop(self) -> None:
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

        # TODO: When stopping the loop, give tasks a chance to finish
        # TODO: should we explicitly inform jobs of pending stoppage?

        # cancel all tasks
        for task in asyncio.all_tasks(self._loop):
            if not task.done():
                task.cancel()

        self._loop.close()

    async def shutdown(self) -> None:
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()

    # TODO: decouple scheduling and running the job
    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        async def do():
            try:
                job.status = JobStatus.running
                await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
            except Exception as e:
                on_log_message_cb(str(e))
                job.status = JobStatus.failed
                logger.exception(f"Job {job.id} failed.")

        asyncio.run_coroutine_threadsafe(do(), self._loop)

    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        pass

    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        pass

    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        pass


_BACKENDS = {
    "naive": _NaiveSchedulerBackend,
}


def _get_backend_impl(backend: str) -> _SchedulerBackend:
    try:
        return _BACKENDS[backend]()
    except KeyError as e:
        raise ValueError(f"Unknown backend {backend}") from e


class Scheduler:
    def __init__(self, backend: str = "naive"):
        # TODO: if server crashes, job states are lost; we need to persist jobs on disc
        self._jobs: dict[JobID, Job] = {}
        self._backend = _get_backend_impl(backend)

    def _on_log_message_cb(self, job: Job, message: str) -> None:
        msg = (datetime.now(UTC), message)
        # At least for the time being, until there's a better way to expose
        # logs to users, log messages on console
        logger.info(f"Job {job.id}: {message}")
        job.append_log(msg)
        self._backend.on_log_message_cb(job, msg)

    def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        job.status = status
        self._backend.on_status_change_cb(job, status)

    def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        job.register_artifact(artifact)
        self._backend.on_artifact_collected_cb(job, artifact)

    def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
        job = Job(type_, job_id, handler)
        if job.id in self._jobs:
            raise ValueError(f"Job {job.id} already exists")

        self._jobs[job.id] = job
        job.status = JobStatus.scheduled
        self._backend.schedule(
            job,
            functools.partial(self._on_log_message_cb, job),
            functools.partial(self._on_status_change_cb, job),
            functools.partial(self._on_artifact_collected_cb, job),
        )

        return job.id

    def cancel(self, job_id: JobID) -> None:
        self.get_job(job_id).cancel()

    def get_job(self, job_id: JobID) -> Job:
        try:
            return self._jobs[job_id]
        except KeyError as e:
            raise ValueError(f"Job {job_id} not found") from e

    def get_jobs(self, type_: JobType | None = None) -> list[Job]:
        jobs = list(self._jobs.values())
        if type_:
            jobs = [job for job in jobs if job._type == type_]
        return jobs

    async def shutdown(self):
        # TODO: also cancel jobs once implemented
        await self._backend.shutdown()
```
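For reference, a hypothetical usage sketch of this module: the handler name, job type, and job id below are invented for illustration, and real providers (e.g. the torchtune post-training implementation) wire the callbacks differently.

```python
# Hypothetical usage of Scheduler; `example_job`, "training", and "job-1" are
# made-up names for this sketch.
import asyncio

from llama_stack.providers.utils.scheduler import JobArtifact, JobStatus, Scheduler


async def example_job(on_log, on_status_change, on_artifact_collected) -> None:
    # A JobHandler receives three callbacks: log message, status change, artifact.
    on_log("starting work")
    await asyncio.sleep(1)  # stands in for real work, e.g. a training loop
    on_artifact_collected(
        JobArtifact(type="checkpoint", name="final", uri=None, metadata={"epoch": 1})
    )
    on_status_change(JobStatus.completed)


async def main() -> None:
    scheduler = Scheduler()  # defaults to the "naive" in-process backend
    job_id = scheduler.schedule("training", "job-1", example_job)

    # Poll until the job reaches a terminal state.
    while scheduler.get_job(job_id).status not in (JobStatus.completed, JobStatus.failed):
        await asyncio.sleep(0.1)

    job = scheduler.get_job(job_id)
    print(job.status, job.artifacts)
    await scheduler.shutdown()  # stops the backend loop; pending tasks are cancelled


asyncio.run(main())
```

On `shutdown()`, the naive backend stops its event loop and the loop thread cancels whatever tasks are still pending before closing, which is exactly the behavior this PR adds so the server process can exit without SIGKILL.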