feat(api): define a more coherent jobs api across different flows

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-03-24 20:54:04 -04:00
parent 71ed47ea76
commit 0f50cfa561
15 changed files with 1864 additions and 1670 deletions


@@ -4,19 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, Optional
import aiohttp
from pydantic import BaseModel, ConfigDict
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.post_training import (
AlgorithmConfig,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
@@ -25,36 +22,6 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .models import _MODEL_ENTRIES
# Map API status to JobStatus enum
STATUS_MAPPING = {
"running": "in_progress",
"completed": "completed",
"failed": "failed",
"cancelled": "cancelled",
"pending": "scheduled",
}
class NvidiaPostTrainingJob(PostTrainingJob):
"""Parse the response from the Customizer API.
Inherits job_uuid from PostTrainingJob.
Adds status, created_at and updated_at fields.
Passes through all other fields from the data field in the response.
"""
model_config = ConfigDict(extra="allow")
status: JobStatus
created_at: datetime
updated_at: datetime
class ListNvidiaPostTrainingJobs(BaseModel):
data: List[NvidiaPostTrainingJob]
class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
model_config = ConfigDict(extra="allow")
class NvidiaPostTrainingAdapter(ModelRegistryHelper):
def __init__(self, config: NvidiaPostTrainingConfig):
@@ -100,102 +67,54 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
raise Exception(f"API request failed: {error_data}")
return await response.json()
async def get_training_jobs(
self,
page: Optional[int] = 1,
page_size: Optional[int] = 10,
sort: Optional[Literal["created_at", "-created_at"]] = "created_at",
) -> ListNvidiaPostTrainingJobs:
"""Get all customization jobs.
Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
raise Exception(f"API request failed after {self.config.max_retries} retries")
Returns a ListNvidiaPostTrainingJobs object with the following fields:
- data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects
@staticmethod
def _get_job_status(job: Dict[str, Any]) -> JobStatus:
job_status = job.get("status", "unknown").lower()
try:
return JobStatus(job_status)
except ValueError:
return JobStatus.unknown
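# Illustrative sketch (not part of this commit): how the _get_job_status fallback
# behaves, using a locally defined stand-in enum. The member list below is an
# assumption; the real one lives in llama_stack.apis.common.job_types.JobStatus.
from enum import Enum

class _SketchJobStatus(str, Enum):
    scheduled = "scheduled"
    running = "running"
    completed = "completed"
    failed = "failed"
    cancelled = "cancelled"
    unknown = "unknown"

def _normalize_status(raw: str) -> "_SketchJobStatus":
    # Lower-case the remote status and fall back to `unknown` for anything the
    # enum does not recognize, mirroring the adapter's try/except above.
    try:
        return _SketchJobStatus(raw.lower())
    except ValueError:
        return _SketchJobStatus.unknown

assert _normalize_status("RUNNING") is _SketchJobStatus.running
assert _normalize_status("created") is _SketchJobStatus.unknown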
# TODO: fetch just the necessary job from remote
async def get_post_training_job(self, job_id: str) -> PostTrainingJob:
jobs = await self.list_post_training_jobs()
for job in jobs.data:
if job.id == job_id:
return job
raise ValueError(f"Job with ID {job_id} not found")
async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse:
"""Get all customization jobs.
TODO: Support for schema input for filtering.
"""
params = {"page": page, "page_size": page_size, "sort": sort}
# TODO: don't hardcode pagination params
params = {"page": 1, "page_size": 10, "sort": "created_at"}
response = await self._make_request("GET", "/v1/customization/jobs", params=params)
jobs = []
for job in response.get("data", []):
job_id = job.pop("id")
job_status = job.pop("status", "unknown").lower()
mapped_status = STATUS_MAPPING.get(job_status, "unknown")
for job_dict in response.get("data", []):
# TODO: expose artifacts
job = PostTrainingJob(**job_dict)
job.update_status(self._get_job_status(job_dict))
jobs.append(job)
# Convert string timestamps to datetime objects
created_at = (
datetime.fromisoformat(job.pop("created_at"))
if "created_at" in job
else datetime.now(tz=datetime.timezone.utc)
)
updated_at = (
datetime.fromisoformat(job.pop("updated_at"))
if "updated_at" in job
else datetime.now(tz=datetime.timezone.utc)
)
return ListPostTrainingJobsResponse(data=jobs)
# Create NvidiaPostTrainingJob instance
jobs.append(
NvidiaPostTrainingJob(
job_uuid=job_id,
status=JobStatus(mapped_status),
created_at=created_at,
updated_at=updated_at,
**job,
)
)
return ListNvidiaPostTrainingJobs(data=jobs)
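# Hedged sketch for the pagination TODO above (not part of this commit): walking every
# page of GET /v1/customization/jobs instead of hardcoding page=1. The stop condition
# (an empty "data" list) is an assumption about the Customizer API; the query
# parameters mirror the ones used in list_post_training_jobs.
async def _iter_all_jobs(adapter: "NvidiaPostTrainingAdapter", page_size: int = 10):
    page = 1
    while True:
        response = await adapter._make_request(
            "GET",
            "/v1/customization/jobs",
            params={"page": page, "page_size": page_size, "sort": "created_at"},
        )
        batch = response.get("data", [])
        if not batch:
            break
        for job_dict in batch:
            yield job_dict
        page += 1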
async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
"""Get the status of a customization job.
Updated the base class return type from PostTrainingJobStatusResponse to NvidiaPostTrainingJobStatusResponse.
Returns a NvidiaPostTrainingJob object with the following fields:
- job_uuid: str - Unique identifier for the job
- status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
- created_at: datetime - The time when the job was created
- updated_at: datetime - The last time the job status was updated
Additional fields that may be included:
- steps_completed: Optional[int] - Number of training steps completed
- epochs_completed: Optional[int] - Number of epochs completed
- percentage_done: Optional[float] - Percentage of training completed (0-100)
- best_epoch: Optional[int] - The epoch with the best performance
- train_loss: Optional[float] - Training loss of the best checkpoint
- val_loss: Optional[float] - Validation loss of the best checkpoint
- metrics: Optional[Dict] - Additional training metrics
- status_logs: Optional[List] - Detailed logs of status changes
"""
response = await self._make_request(
"GET",
f"/v1/customization/jobs/{job_uuid}/status",
params={"job_id": job_uuid},
)
api_status = response.pop("status").lower()
mapped_status = STATUS_MAPPING.get(api_status, "unknown")
return NvidiaPostTrainingJobStatusResponse(
status=JobStatus(mapped_status),
job_uuid=job_uuid,
started_at=datetime.fromisoformat(response.pop("created_at")),
updated_at=datetime.fromisoformat(response.pop("updated_at")),
**response,
)
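# Reading aid (not part of this commit): the optional progress fields documented in the
# docstring above, captured as a local TypedDict stand-in. The real response model
# simply allows arbitrary extras via ConfigDict(extra="allow"); this sketch only names
# the fields the docstring lists.
from typing import Any, Dict, List, TypedDict

class _SketchStatusExtras(TypedDict, total=False):
    steps_completed: int       # number of training steps completed
    epochs_completed: int      # number of epochs completed
    percentage_done: float     # 0-100
    best_epoch: int            # epoch with the best performance
    train_loss: float          # training loss of the best checkpoint
    val_loss: float            # validation loss of the best checkpoint
    metrics: Dict[str, Any]    # additional training metrics
    status_logs: List[Any]     # detailed logs of status changes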
async def cancel_training_job(self, job_uuid: str) -> None:
async def update_post_training_job(self, job_id: str, status: JobStatus | None = None) -> PostTrainingJob:
if status is None:
raise ValueError("Status must be provided")
if status not in {JobStatus.cancelled}:
raise ValueError(f"Unsupported status: {status}")
await self._make_request(
method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
method="POST", path=f"/v1/customization/jobs/{job_id}/cancel", params={"job_id": job_id}
)
return await self.get_post_training_job(job_id)
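# Usage sketch (not part of this commit): cancelling a job through the new
# update_post_training_job entry point. The adapter instance and job id are
# placeholders; the method name and JobStatus.cancelled come from this file.
async def _cancel_job(adapter: "NvidiaPostTrainingAdapter", job_id: str) -> PostTrainingJob:
    # Any status other than JobStatus.cancelled (or a missing status) raises ValueError.
    return await adapter.update_post_training_job(job_id, status=JobStatus.cancelled)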
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def delete_post_training_job(self, job_id: str) -> None:
raise NotImplementedError("Delete job is not implemented yet")
async def supervised_fine_tune(
self,
@@ -206,7 +125,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig] = None,
) -> NvidiaPostTrainingJob:
) -> PostTrainingJob:
"""
Fine-tunes a model on a dataset.
Currently only supports LoRA fine-tuning for the standalone Docker container.
@@ -409,15 +328,12 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
headers={"Accept": "application/json"},
json=job_config,
)
job_uuid = response["id"]
response.pop("status")
created_at = datetime.fromisoformat(response.pop("created_at"))
updated_at = datetime.fromisoformat(response.pop("updated_at"))
return NvidiaPostTrainingJob(
job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response
)
# TODO: expose artifacts
job = PostTrainingJob(**response)
job.update_status(JobStatus.running)
return job
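# Usage sketch (not part of this commit): polling a submitted fine-tuning job until it
# reaches a terminal state. Assumes the job model exposes a `status` attribute that
# update_status keeps current, and that JobStatus defines completed/failed/cancelled
# members; both are assumptions about the common job_types module, not confirmed here.
import asyncio

async def _wait_for_job(adapter: "NvidiaPostTrainingAdapter", job_id: str, poll_seconds: float = 30.0) -> PostTrainingJob:
    terminal = {JobStatus.completed, JobStatus.failed, JobStatus.cancelled}
    while True:
        job = await adapter.get_post_training_job(job_id)
        if job.status in terminal:
            return job
        await asyncio.sleep(poll_seconds)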
async def preference_optimize(
self,
@@ -430,6 +346,3 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
) -> PostTrainingJob:
"""Optimize a model based on preference data."""
raise NotImplementedError("Preference optimization is not implemented yet")
async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse:
raise NotImplementedError("Job logs are not implemented yet")