llama-stack-mirror/llama_stack/providers/utils/scheduler.py
Charlie Doern d12f195f56
feat: drop python 3.10 support (#2469)
# What does this PR do?

dropped python3.10, updated pyproject and dependencies, and also removed
some blocks of code with special handling for enum.StrEnum

Closes #2458

Signed-off-by: Charlie Doern <cdoern@redhat.com>
2025-06-19 12:07:14 +05:30

266 lines
8.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import abc
import asyncio
import functools
import threading
from collections.abc import Callable, Coroutine, Iterable
from datetime import UTC, datetime
from enum import Enum
from typing import Any, TypeAlias
from pydantic import BaseModel
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="scheduler")
# TODO: revisit the list of possible statuses when defining a more coherent
# Jobs API for all API flows; e.g. do we need new vs scheduled?
class JobStatus(Enum):
new = "new"
scheduled = "scheduled"
running = "running"
failed = "failed"
completed = "completed"
JobID: TypeAlias = str
JobType: TypeAlias = str
class JobArtifact(BaseModel):
type: JobType
name: str
# TODO: uri should be a reference to /files API; revisit when /files is implemented
uri: str | None = None
metadata: dict[str, Any]
JobHandler = Callable[
[Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
]
LogMessage: TypeAlias = tuple[datetime, str]
_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
class Job:
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
super().__init__()
self.id = job_id
self._type = job_type
self._handler = handler
self._artifacts: list[JobArtifact] = []
self._logs: list[LogMessage] = []
self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(UTC), JobStatus.new)]
@property
def handler(self) -> JobHandler:
return self._handler
@property
def status(self) -> JobStatus:
return self._state_transitions[-1][1]
@status.setter
def status(self, status: JobStatus):
if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
raise ValueError(f"Job is already in a completed state ({self.status})")
if self.status == status:
return
self._state_transitions.append((datetime.now(UTC), status))
@property
def artifacts(self) -> list[JobArtifact]:
return self._artifacts
def register_artifact(self, artifact: JobArtifact) -> None:
self._artifacts.append(artifact)
def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
for date, s in reversed(self._state_transitions):
if s in status:
return date
return None
@property
def scheduled_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.scheduled])
@property
def started_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.running])
@property
def completed_at(self) -> datetime | None:
return self._find_state_transition_date(_COMPLETED_STATUSES)
@property
def logs(self) -> list[LogMessage]:
return self._logs[:]
def append_log(self, message: LogMessage) -> None:
self._logs.append(message)
# TODO: implement
def cancel(self) -> None:
raise NotImplementedError
class _SchedulerBackend(abc.ABC):
@abc.abstractmethod
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
raise NotImplementedError
@abc.abstractmethod
async def shutdown(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
raise NotImplementedError
class _NaiveSchedulerBackend(_SchedulerBackend):
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
# There may be performance implications of using threads due to Python
# GIL; may need to measure if it's a real problem though
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
# TODO: should we explicitly inform jobs of pending stoppage?
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
self._loop.close()
async def shutdown(self) -> None:
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
# TODO: decouple scheduling and running the job
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
async def do():
try:
job.status = JobStatus.running
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
logger.exception(f"Job {job.id} failed.")
asyncio.run_coroutine_threadsafe(do(), self._loop)
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
pass
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
pass
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
pass
_BACKENDS = {
"naive": _NaiveSchedulerBackend,
}
def _get_backend_impl(backend: str) -> _SchedulerBackend:
try:
return _BACKENDS[backend]()
except KeyError as e:
raise ValueError(f"Unknown backend {backend}") from e
class Scheduler:
def __init__(self, backend: str = "naive"):
# TODO: if server crashes, job states are lost; we need to persist jobs on disc
self._jobs: dict[JobID, Job] = {}
self._backend = _get_backend_impl(backend)
def _on_log_message_cb(self, job: Job, message: str) -> None:
msg = (datetime.now(UTC), message)
# At least for the time being, until there's a better way to expose
# logs to users, log messages on console
logger.info(f"Job {job.id}: {message}")
job.append_log(msg)
self._backend.on_log_message_cb(job, msg)
def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
job.status = status
self._backend.on_status_change_cb(job, status)
def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
job.register_artifact(artifact)
self._backend.on_artifact_collected_cb(job, artifact)
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
job = Job(type_, job_id, handler)
if job.id in self._jobs:
raise ValueError(f"Job {job.id} already exists")
self._jobs[job.id] = job
job.status = JobStatus.scheduled
self._backend.schedule(
job,
functools.partial(self._on_log_message_cb, job),
functools.partial(self._on_status_change_cb, job),
functools.partial(self._on_artifact_collected_cb, job),
)
return job.id
def cancel(self, job_id: JobID) -> None:
self.get_job(job_id).cancel()
def get_job(self, job_id: JobID) -> Job:
try:
return self._jobs[job_id]
except KeyError as e:
raise ValueError(f"Job {job_id} not found") from e
def get_jobs(self, type_: JobType | None = None) -> list[Job]:
jobs = list(self._jobs.values())
if type_:
jobs = [job for job in jobs if job._type == type_]
return jobs
async def shutdown(self):
# TODO: also cancel jobs once implemented
await self._backend.shutdown()