feat(api): define a more coherent jobs api across different flows

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-24 20:54:04 -04:00
parent 71ed47ea76
commit 0f50cfa561
15 changed files with 1864 additions and 1670 deletions

View file

@ -20,9 +20,10 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
)
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.schema_utils import webmethod
from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .....apis.common.job_types import JobArtifact, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateJob, ListEvaluateJobsResponse
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "benchmarks:"
@ -75,11 +76,11 @@ class MetaReferenceEvalImpl(
)
self.benchmarks[task_def.identifier] = task_def
async def run_eval(
async def evaluate(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
) -> EvaluateJob:
task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id
scoring_functions = task_def.scoring_functions
@ -91,18 +92,35 @@ class MetaReferenceEvalImpl(
dataset_id=dataset_id,
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
)
res = await self.evaluate_rows(
generations, scoring_results = await self._evaluate_rows(
benchmark_id=benchmark_id,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
artifacts = [
JobArtifact(
type="generation",
name=f"generation-{i}",
metadata=generation,
)
for i, generation in enumerate(generations)
] + [
JobArtifact(
type="scoring_results",
name="scoring_results",
metadata=scoring_results,
)
]
# TODO: currently needs to wait for generation before returning
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs))
self.jobs[job_id] = res
return Job(job_id=job_id, status=JobStatus.completed)
job = EvaluateJob(id=job_id, artifacts=artifacts)
job.update_status(JobStatus.completed)
self.jobs[job_id] = job
return job
async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@ -182,13 +200,13 @@ class MetaReferenceEvalImpl(
return generations
async def evaluate_rows(
async def _evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
candidate = benchmark_config.eval_candidate
if candidate.type == "agent":
generations = await self._run_agent_generation(input_rows, benchmark_config)
@ -214,21 +232,26 @@ class MetaReferenceEvalImpl(
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
)
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
if job_id in self.jobs:
return Job(job_id=job_id, status=JobStatus.completed)
raise ValueError(f"Job {job_id} not found")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")
return generations, score_response.results
# CRUD operations on running jobs
@webmethod(route="/evaluate/jobs/{job_id:path}", method="GET")
async def get_evaluate_job(self, job_id: str) -> EvaluateJob:
return self.jobs[job_id]
@webmethod(route="/evaluate/jobs", method="GET")
async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse:
return ListEvaluateJobsResponse(data=list(self.jobs.values()))
@webmethod(route="/evaluate/jobs/{job_id:path}", method="POST")
async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob:
raise NotImplementedError
@webmethod(route="/evaluate/job/{job_id:path}", method="DELETE")
async def delete_evaluate_job(self, job_id: str) -> None:
raise NotImplementedError
# Note: pause/resume/cancel are achieved as follows:
# - POST with status=paused
# - POST with status=resuming
# - POST with status=cancelled

View file

@ -10,14 +10,10 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
LoraFinetuningConfig,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.config import (
@ -54,15 +50,6 @@ class TorchtunePostTrainingImpl:
async def shutdown(self) -> None:
await self._scheduler.shutdown()
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
return JobArtifact(
@ -98,14 +85,14 @@ class TorchtunePostTrainingImpl:
self.datasetio_api,
self.datasets_api,
)
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_artifact_collected_cb(checkpoint)
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("Lora finetuning completed")
@ -113,6 +100,8 @@ class TorchtunePostTrainingImpl:
raise NotImplementedError()
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
# TODO: initialize with more data from scheduler
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
@ -125,56 +114,31 @@ class TorchtunePostTrainingImpl:
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
# TODO: should these be under post-training/supervised-fine-tune/?
# CRUD operations on running jobs
@webmethod(route="/post-training/jobs/{job_id:path}", method="GET")
async def get_post_training_job(self, job_id: str) -> PostTrainingJob:
# TODO: implement
raise NotImplementedError
@webmethod(route="/post-training/jobs", method="GET")
async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse:
# TODO: populate other data
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@webmethod(route="/post-training/jobs/{job_id:path}", method="POST")
async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob:
# TODO: implement
raise NotImplementedError
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@webmethod(route="/post-training/job/{job_id:path}", method="DELETE")
async def delete_post_training_job(self, job_id: str) -> None:
# TODO: implement
raise NotImplementedError
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
# Note: pause/resume/cancel are achieved as follows:
# - POST with status=paused
# - POST with status=resuming
# - POST with status=cancelled

View file

@ -33,11 +33,11 @@ from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
from llama_stack.apis.common.job_types import JobArtifact
from llama_stack.apis.common.training_types import PostTrainingMetric
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
EfficiencyConfig,
LoraFinetuningConfig,
@ -457,7 +457,7 @@ class LoraFinetuningSingleDevice:
return loss
async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
async def train(self) -> Tuple[Dict[str, Any], List[JobArtifact]]:
"""
The core training loop.
"""
@ -543,13 +543,18 @@ class LoraFinetuningSingleDevice:
self.epochs_run += 1
log.info("Starting checkpoint save...")
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(timezone.utc),
epoch=curr_epoch,
post_training_job_id=self.job_uuid,
path=checkpoint_path,
checkpoint = JobArtifact(
name=f"{self.model_id}-sft-{curr_epoch}",
type="checkpoint",
# TODO: this should be exposed via /files instead
uri=checkpoint_path,
)
metadata = {
"created_at": datetime.now(timezone.utc),
"epoch": curr_epoch,
}
if self.training_config.data_config.validation_dataset_id:
validation_loss, perplexity = await self.validation()
training_metrics = PostTrainingMetric(
@ -558,7 +563,9 @@ class LoraFinetuningSingleDevice:
validation_loss=validation_loss,
perplexity=perplexity,
)
checkpoint.training_metrics = training_metrics
metadata["training_metrics"] = training_metrics
checkpoint.metadata = metadata
checkpoints.append(checkpoint)
# clean up the memory after training finishes

View file

@ -4,19 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, Optional
import aiohttp
from pydantic import BaseModel, ConfigDict
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.post_training import (
AlgorithmConfig,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
@ -25,36 +22,6 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
from .models import _MODEL_ENTRIES
# Map API status to JobStatus enum
STATUS_MAPPING = {
"running": "in_progress",
"completed": "completed",
"failed": "failed",
"cancelled": "cancelled",
"pending": "scheduled",
}
class NvidiaPostTrainingJob(PostTrainingJob):
"""Parse the response from the Customizer API.
Inherits job_uuid from PostTrainingJob.
Adds status, created_at, updated_at parameters.
Passes through all other parameters from data field in the response.
"""
model_config = ConfigDict(extra="allow")
status: JobStatus
created_at: datetime
updated_at: datetime
class ListNvidiaPostTrainingJobs(BaseModel):
data: List[NvidiaPostTrainingJob]
class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
model_config = ConfigDict(extra="allow")
class NvidiaPostTrainingAdapter(ModelRegistryHelper):
def __init__(self, config: NvidiaPostTrainingConfig):
@ -100,102 +67,54 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
raise Exception(f"API request failed: {error_data}")
return await response.json()
async def get_training_jobs(
self,
page: Optional[int] = 1,
page_size: Optional[int] = 10,
sort: Optional[Literal["created_at", "-created_at"]] = "created_at",
) -> ListNvidiaPostTrainingJobs:
"""Get all customization jobs.
Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
raise Exception(f"API request failed after {self.config.max_retries} retries")
Returns a ListNvidiaPostTrainingJobs object with the following fields:
- data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects
@staticmethod
def _get_job_status(job: Dict[str, Any]) -> JobStatus:
job_status = job.get("status", "unknown").lower()
try:
return JobStatus(job_status)
except ValueError:
return JobStatus.unknown
# TODO: fetch just the necessary job from remote
async def get_post_training_job(self, job_id: str) -> PostTrainingJob:
jobs = await self.list_post_training_jobs()
for job in jobs.data:
if job.id == job_id:
return job
raise ValueError(f"Job with ID {job_id} not found")
async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse:
"""Get all customization jobs.
ToDo: Support for schema input for filtering.
"""
params = {"page": page, "page_size": page_size, "sort": sort}
# TODO: don't hardcode pagination params
params = {"page": 1, "page_size": 10, "sort": "created_at"}
response = await self._make_request("GET", "/v1/customization/jobs", params=params)
jobs = []
for job in response.get("data", []):
job_id = job.pop("id")
job_status = job.pop("status", "unknown").lower()
mapped_status = STATUS_MAPPING.get(job_status, "unknown")
for job_dict in response.get("data", []):
# TODO: expose artifacts
job = PostTrainingJob(**job_dict)
job.update_status(self._get_job_status(job_dict))
jobs.append(job)
# Convert string timestamps to datetime objects
created_at = (
datetime.fromisoformat(job.pop("created_at"))
if "created_at" in job
else datetime.now(tz=datetime.timezone.utc)
)
updated_at = (
datetime.fromisoformat(job.pop("updated_at"))
if "updated_at" in job
else datetime.now(tz=datetime.timezone.utc)
)
return ListPostTrainingJobsResponse(data=jobs)
# Create NvidiaPostTrainingJob instance
jobs.append(
NvidiaPostTrainingJob(
job_uuid=job_id,
status=JobStatus(mapped_status),
created_at=created_at,
updated_at=updated_at,
**job,
)
)
return ListNvidiaPostTrainingJobs(data=jobs)
async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
"""Get the status of a customization job.
Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
Returns a NvidiaPostTrainingJob object with the following fields:
- job_uuid: str - Unique identifier for the job
- status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
- created_at: datetime - The time when the job was created
- updated_at: datetime - The last time the job status was updated
Additional fields that may be included:
- steps_completed: Optional[int] - Number of training steps completed
- epochs_completed: Optional[int] - Number of epochs completed
- percentage_done: Optional[float] - Percentage of training completed (0-100)
- best_epoch: Optional[int] - The epoch with the best performance
- train_loss: Optional[float] - Training loss of the best checkpoint
- val_loss: Optional[float] - Validation loss of the best checkpoint
- metrics: Optional[Dict] - Additional training metrics
- status_logs: Optional[List] - Detailed logs of status changes
"""
response = await self._make_request(
"GET",
f"/v1/customization/jobs/{job_uuid}/status",
params={"job_id": job_uuid},
)
api_status = response.pop("status").lower()
mapped_status = STATUS_MAPPING.get(api_status, "unknown")
return NvidiaPostTrainingJobStatusResponse(
status=JobStatus(mapped_status),
job_uuid=job_uuid,
started_at=datetime.fromisoformat(response.pop("created_at")),
updated_at=datetime.fromisoformat(response.pop("updated_at")),
**response,
)
async def cancel_training_job(self, job_uuid: str) -> None:
async def update_post_training_job(self, job_id: str, status: JobStatus | None = None) -> PostTrainingJob:
if status is None:
raise ValueError("Status must be provided")
if status not in {JobStatus.cancelled}:
raise ValueError(f"Unsupported status: {status}")
await self._make_request(
method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
method="POST", path=f"/v1/customization/jobs/{job_id}/cancel", params={"job_id": job_id}
)
return await self.get_post_training_job(job_id)
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def delete_post_training_job(self, job_id: str) -> None:
raise NotImplementedError("Delete job is not implemented yet")
async def supervised_fine_tune(
self,
@ -206,7 +125,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig] = None,
) -> NvidiaPostTrainingJob:
) -> PostTrainingJob:
"""
Fine-tunes a model on a dataset.
Currently only supports Lora finetuning for standlone docker container.
@ -409,15 +328,12 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
headers={"Accept": "application/json"},
json=job_config,
)
job_uuid = response["id"]
response.pop("status")
created_at = datetime.fromisoformat(response.pop("created_at"))
updated_at = datetime.fromisoformat(response.pop("updated_at"))
return NvidiaPostTrainingJob(
job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response
)
# TODO: expose artifacts
job = PostTrainingJob(**response)
job.update_status(JobStatus.running)
return job
async def preference_optimize(
self,
@ -430,6 +346,3 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
) -> PostTrainingJob:
"""Optimize a model based on preference data."""
raise NotImplementedError("Preference optimization is not implemented yet")
async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse:
raise NotImplementedError("Job logs are not implemented yet")