feat(api): define a more coherent jobs api across different flows

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-24 20:54:04 -04:00
parent 71ed47ea76
commit 0f50cfa561
15 changed files with 1864 additions and 1670 deletions

View file

@ -20,9 +20,10 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
)
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.schema_utils import webmethod
from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .....apis.common.job_types import JobArtifact, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateJob, ListEvaluateJobsResponse
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "benchmarks:"
@ -75,11 +76,11 @@ class MetaReferenceEvalImpl(
)
self.benchmarks[task_def.identifier] = task_def
async def run_eval(
async def evaluate(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
) -> EvaluateJob:
task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id
scoring_functions = task_def.scoring_functions
@ -91,18 +92,35 @@ class MetaReferenceEvalImpl(
dataset_id=dataset_id,
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
)
res = await self.evaluate_rows(
generations, scoring_results = await self._evaluate_rows(
benchmark_id=benchmark_id,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
artifacts = [
JobArtifact(
type="generation",
name=f"generation-{i}",
metadata=generation,
)
for i, generation in enumerate(generations)
] + [
JobArtifact(
type="scoring_results",
name="scoring_results",
metadata=scoring_results,
)
]
# TODO: currently needs to wait for generation before returning
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs))
self.jobs[job_id] = res
return Job(job_id=job_id, status=JobStatus.completed)
job = EvaluateJob(id=job_id, artifacts=artifacts)
job.update_status(JobStatus.completed)
self.jobs[job_id] = job
return job
async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@ -182,13 +200,13 @@ class MetaReferenceEvalImpl(
return generations
async def evaluate_rows(
async def _evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
candidate = benchmark_config.eval_candidate
if candidate.type == "agent":
generations = await self._run_agent_generation(input_rows, benchmark_config)
@ -214,21 +232,26 @@ class MetaReferenceEvalImpl(
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
)
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
if job_id in self.jobs:
return Job(job_id=job_id, status=JobStatus.completed)
raise ValueError(f"Job {job_id} not found")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")
return generations, score_response.results
# CRUD operations on running jobs
@webmethod(route="/evaluate/jobs/{job_id:path}", method="GET")
async def get_evaluate_job(self, job_id: str) -> EvaluateJob:
return self.jobs[job_id]
@webmethod(route="/evaluate/jobs", method="GET")
async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse:
return ListEvaluateJobsResponse(data=list(self.jobs.values()))
@webmethod(route="/evaluate/jobs/{job_id:path}", method="POST")
async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob:
raise NotImplementedError
@webmethod(route="/evaluate/job/{job_id:path}", method="DELETE")
async def delete_evaluate_job(self, job_id: str) -> None:
raise NotImplementedError
# Note: pause/resume/cancel are achieved as follows:
# - POST with status=paused
# - POST with status=resuming
# - POST with status=cancelled

View file

@ -10,14 +10,10 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
LoraFinetuningConfig,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.config import (
@ -54,15 +50,6 @@ class TorchtunePostTrainingImpl:
async def shutdown(self) -> None:
await self._scheduler.shutdown()
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
return JobArtifact(
@ -98,14 +85,14 @@ class TorchtunePostTrainingImpl:
self.datasetio_api,
self.datasets_api,
)
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_artifact_collected_cb(checkpoint)
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("Lora finetuning completed")
@ -113,6 +100,8 @@ class TorchtunePostTrainingImpl:
raise NotImplementedError()
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
# TODO: initialize with more data from scheduler
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
@ -125,56 +114,31 @@ class TorchtunePostTrainingImpl:
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
# TODO: should these be under post-training/supervised-fine-tune/?
# CRUD operations on running jobs
@webmethod(route="/post-training/jobs/{job_id:path}", method="GET")
async def get_post_training_job(self, job_id: str) -> PostTrainingJob:
# TODO: implement
raise NotImplementedError
@webmethod(route="/post-training/jobs", method="GET")
async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse:
# TODO: populate other data
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@webmethod(route="/post-training/jobs/{job_id:path}", method="POST")
async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob:
# TODO: implement
raise NotImplementedError
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@webmethod(route="/post-training/job/{job_id:path}", method="DELETE")
async def delete_post_training_job(self, job_id: str) -> None:
# TODO: implement
raise NotImplementedError
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
# Note: pause/resume/cancel are achieved as follows:
# - POST with status=paused
# - POST with status=resuming
# - POST with status=cancelled

View file

@ -33,11 +33,11 @@ from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
from llama_stack.apis.common.job_types import JobArtifact
from llama_stack.apis.common.training_types import PostTrainingMetric
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
EfficiencyConfig,
LoraFinetuningConfig,
@ -457,7 +457,7 @@ class LoraFinetuningSingleDevice:
return loss
async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
async def train(self) -> Tuple[Dict[str, Any], List[JobArtifact]]:
"""
The core training loop.
"""
@ -543,13 +543,18 @@ class LoraFinetuningSingleDevice:
self.epochs_run += 1
log.info("Starting checkpoint save...")
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(timezone.utc),
epoch=curr_epoch,
post_training_job_id=self.job_uuid,
path=checkpoint_path,
checkpoint = JobArtifact(
name=f"{self.model_id}-sft-{curr_epoch}",
type="checkpoint",
# TODO: this should be exposed via /files instead
uri=checkpoint_path,
)
metadata = {
"created_at": datetime.now(timezone.utc),
"epoch": curr_epoch,
}
if self.training_config.data_config.validation_dataset_id:
validation_loss, perplexity = await self.validation()
training_metrics = PostTrainingMetric(
@ -558,7 +563,9 @@ class LoraFinetuningSingleDevice:
validation_loss=validation_loss,
perplexity=perplexity,
)
checkpoint.training_metrics = training_metrics
metadata["training_metrics"] = training_metrics
checkpoint.metadata = metadata
checkpoints.append(checkpoint)
# clean up the memory after training finishes