mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 11:54:30 +00:00
feat(api): define a more coherent jobs api across different flows
Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
71ed47ea76
commit
0f50cfa561
15 changed files with 1864 additions and 1670 deletions
|
|
@ -20,9 +20,10 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
|||
)
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
from .....apis.common.job_types import Job, JobStatus
|
||||
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
|
||||
from .....apis.common.job_types import JobArtifact, JobStatus
|
||||
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateJob, ListEvaluateJobsResponse
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
EVAL_TASKS_PREFIX = "benchmarks:"
|
||||
|
|
@ -75,11 +76,11 @@ class MetaReferenceEvalImpl(
|
|||
)
|
||||
self.benchmarks[task_def.identifier] = task_def
|
||||
|
||||
async def run_eval(
|
||||
async def evaluate(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
) -> EvaluateJob:
|
||||
task_def = self.benchmarks[benchmark_id]
|
||||
dataset_id = task_def.dataset_id
|
||||
scoring_functions = task_def.scoring_functions
|
||||
|
|
@ -91,18 +92,35 @@ class MetaReferenceEvalImpl(
|
|||
dataset_id=dataset_id,
|
||||
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
|
||||
)
|
||||
res = await self.evaluate_rows(
|
||||
|
||||
generations, scoring_results = await self._evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=all_rows.data,
|
||||
scoring_functions=scoring_functions,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
artifacts = [
|
||||
JobArtifact(
|
||||
type="generation",
|
||||
name=f"generation-{i}",
|
||||
metadata=generation,
|
||||
)
|
||||
for i, generation in enumerate(generations)
|
||||
] + [
|
||||
JobArtifact(
|
||||
type="scoring_results",
|
||||
name="scoring_results",
|
||||
metadata=scoring_results,
|
||||
)
|
||||
]
|
||||
|
||||
# TODO: currently needs to wait for generation before returning
|
||||
# need job scheduler queue (ray/celery) w/ jobs api
|
||||
job_id = str(len(self.jobs))
|
||||
self.jobs[job_id] = res
|
||||
return Job(job_id=job_id, status=JobStatus.completed)
|
||||
job = EvaluateJob(id=job_id, artifacts=artifacts)
|
||||
job.update_status(JobStatus.completed)
|
||||
self.jobs[job_id] = job
|
||||
return job
|
||||
|
||||
async def _run_agent_generation(
|
||||
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
|
|
@ -182,13 +200,13 @@ class MetaReferenceEvalImpl(
|
|||
|
||||
return generations
|
||||
|
||||
async def evaluate_rows(
|
||||
async def _evaluate_rows(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
if candidate.type == "agent":
|
||||
generations = await self._run_agent_generation(input_rows, benchmark_config)
|
||||
|
|
@ -214,21 +232,26 @@ class MetaReferenceEvalImpl(
|
|||
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
|
||||
)
|
||||
|
||||
return EvaluateResponse(generations=generations, scores=score_response.results)
|
||||
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
if job_id in self.jobs:
|
||||
return Job(job_id=job_id, status=JobStatus.completed)
|
||||
|
||||
raise ValueError(f"Job {job_id} not found")
|
||||
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
raise NotImplementedError("Job cancel is not implemented yet")
|
||||
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
job = await self.job_status(benchmark_id, job_id)
|
||||
status = job.status
|
||||
if not status or status != JobStatus.completed:
|
||||
raise ValueError(f"Job is not completed, Status: {status.value}")
|
||||
return generations, score_response.results
|
||||
|
||||
# CRUD operations on running jobs
|
||||
@webmethod(route="/evaluate/jobs/{job_id:path}", method="GET")
|
||||
async def get_evaluate_job(self, job_id: str) -> EvaluateJob:
|
||||
return self.jobs[job_id]
|
||||
|
||||
@webmethod(route="/evaluate/jobs", method="GET")
|
||||
async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse:
|
||||
return ListEvaluateJobsResponse(data=list(self.jobs.values()))
|
||||
|
||||
@webmethod(route="/evaluate/jobs/{job_id:path}", method="POST")
|
||||
async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob:
|
||||
raise NotImplementedError
|
||||
|
||||
@webmethod(route="/evaluate/job/{job_id:path}", method="DELETE")
|
||||
async def delete_evaluate_job(self, job_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
# Note: pause/resume/cancel are achieved as follows:
|
||||
# - POST with status=paused
|
||||
# - POST with status=resuming
|
||||
# - POST with status=cancelled
|
||||
|
|
|
|||
|
|
@ -10,14 +10,10 @@ from llama_stack.apis.datasetio import DatasetIO
|
|||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
ListPostTrainingJobsResponse,
|
||||
LoraFinetuningConfig,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
|
|
@ -54,15 +50,6 @@ class TorchtunePostTrainingImpl:
|
|||
async def shutdown(self) -> None:
|
||||
await self._scheduler.shutdown()
|
||||
|
||||
@staticmethod
|
||||
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.CHECKPOINT.value,
|
||||
name=checkpoint.identifier,
|
||||
uri=checkpoint.path,
|
||||
metadata=dict(checkpoint),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
|
|
@ -98,14 +85,14 @@ class TorchtunePostTrainingImpl:
|
|||
self.datasetio_api,
|
||||
self.datasets_api,
|
||||
)
|
||||
|
||||
await recipe.setup()
|
||||
|
||||
resources_allocated, checkpoints = await recipe.train()
|
||||
|
||||
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
|
||||
for checkpoint in checkpoints:
|
||||
artifact = self._checkpoint_to_artifact(checkpoint)
|
||||
on_artifact_collected_cb(artifact)
|
||||
on_artifact_collected_cb(checkpoint)
|
||||
|
||||
on_status_change_cb(SchedulerJobStatus.completed)
|
||||
on_log_message_cb("Lora finetuning completed")
|
||||
|
|
@ -113,6 +100,8 @@ class TorchtunePostTrainingImpl:
|
|||
raise NotImplementedError()
|
||||
|
||||
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
|
||||
|
||||
# TODO: initialize with more data from scheduler
|
||||
return PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
async def preference_optimize(
|
||||
|
|
@ -125,56 +114,31 @@ class TorchtunePostTrainingImpl:
|
|||
logger_config: Dict[str, Any],
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
# TODO: should these be under post-training/supervised-fine-tune/?
|
||||
# CRUD operations on running jobs
|
||||
@webmethod(route="/post-training/jobs/{job_id:path}", method="GET")
|
||||
async def get_post_training_job(self, job_id: str) -> PostTrainingJob:
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
|
||||
@webmethod(route="/post-training/jobs", method="GET")
|
||||
async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
# TODO: populate other data
|
||||
return ListPostTrainingJobsResponse(
|
||||
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_artifacts_metadata_by_type(job, artifact_type):
|
||||
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
|
||||
@webmethod(route="/post-training/jobs/{job_id:path}", method="POST")
|
||||
async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob:
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _get_checkpoints(cls, job):
|
||||
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
|
||||
@webmethod(route="/post-training/job/{job_id:path}", method="DELETE")
|
||||
async def delete_post_training_job(self, job_id: str) -> None:
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _get_resources_allocated(cls, job):
|
||||
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
|
||||
return data[0] if data else None
|
||||
|
||||
@webmethod(route="/post-training/job/status")
|
||||
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
|
||||
match job.status:
|
||||
# TODO: Add support for other statuses to API
|
||||
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
|
||||
status = JobStatus.scheduled
|
||||
case SchedulerJobStatus.running:
|
||||
status = JobStatus.in_progress
|
||||
case SchedulerJobStatus.completed:
|
||||
status = JobStatus.completed
|
||||
case SchedulerJobStatus.failed:
|
||||
status = JobStatus.failed
|
||||
case _:
|
||||
raise NotImplementedError()
|
||||
|
||||
return PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=status,
|
||||
scheduled_at=job.scheduled_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
checkpoints=self._get_checkpoints(job),
|
||||
resources_allocated=self._get_resources_allocated(job),
|
||||
)
|
||||
|
||||
@webmethod(route="/post-training/job/cancel")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
self._scheduler.cancel(job_uuid)
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
|
||||
# Note: pause/resume/cancel are achieved as follows:
|
||||
# - POST with status=paused
|
||||
# - POST with status=resuming
|
||||
# - POST with status=cancelled
|
||||
|
|
|
|||
|
|
@ -33,11 +33,11 @@ from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
|
|||
from torchtune.training.metric_logging import DiskLogger
|
||||
from tqdm import tqdm
|
||||
|
||||
from llama_stack.apis.common.job_types import JobArtifact
|
||||
from llama_stack.apis.common.training_types import PostTrainingMetric
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
Checkpoint,
|
||||
DataConfig,
|
||||
EfficiencyConfig,
|
||||
LoraFinetuningConfig,
|
||||
|
|
@ -457,7 +457,7 @@ class LoraFinetuningSingleDevice:
|
|||
|
||||
return loss
|
||||
|
||||
async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
|
||||
async def train(self) -> Tuple[Dict[str, Any], List[JobArtifact]]:
|
||||
"""
|
||||
The core training loop.
|
||||
"""
|
||||
|
|
@ -543,13 +543,18 @@ class LoraFinetuningSingleDevice:
|
|||
self.epochs_run += 1
|
||||
log.info("Starting checkpoint save...")
|
||||
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
|
||||
checkpoint = Checkpoint(
|
||||
identifier=f"{self.model_id}-sft-{curr_epoch}",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
epoch=curr_epoch,
|
||||
post_training_job_id=self.job_uuid,
|
||||
path=checkpoint_path,
|
||||
|
||||
checkpoint = JobArtifact(
|
||||
name=f"{self.model_id}-sft-{curr_epoch}",
|
||||
type="checkpoint",
|
||||
# TODO: this should be exposed via /files instead
|
||||
uri=checkpoint_path,
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"epoch": curr_epoch,
|
||||
}
|
||||
if self.training_config.data_config.validation_dataset_id:
|
||||
validation_loss, perplexity = await self.validation()
|
||||
training_metrics = PostTrainingMetric(
|
||||
|
|
@ -558,7 +563,9 @@ class LoraFinetuningSingleDevice:
|
|||
validation_loss=validation_loss,
|
||||
perplexity=perplexity,
|
||||
)
|
||||
checkpoint.training_metrics = training_metrics
|
||||
metadata["training_metrics"] = training_metrics
|
||||
checkpoint.metadata = metadata
|
||||
|
||||
checkpoints.append(checkpoint)
|
||||
|
||||
# clean up the memory after training finishes
|
||||
|
|
|
|||
|
|
@ -4,19 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
ListPostTrainingJobsResponse,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
|
||||
|
|
@ -25,36 +22,6 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
|
|||
|
||||
from .models import _MODEL_ENTRIES
|
||||
|
||||
# Map API status to JobStatus enum
|
||||
STATUS_MAPPING = {
|
||||
"running": "in_progress",
|
||||
"completed": "completed",
|
||||
"failed": "failed",
|
||||
"cancelled": "cancelled",
|
||||
"pending": "scheduled",
|
||||
}
|
||||
|
||||
|
||||
class NvidiaPostTrainingJob(PostTrainingJob):
|
||||
"""Parse the response from the Customizer API.
|
||||
Inherits job_uuid from PostTrainingJob.
|
||||
Adds status, created_at, updated_at parameters.
|
||||
Passes through all other parameters from data field in the response.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
status: JobStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class ListNvidiaPostTrainingJobs(BaseModel):
|
||||
data: List[NvidiaPostTrainingJob]
|
||||
|
||||
|
||||
class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||
def __init__(self, config: NvidiaPostTrainingConfig):
|
||||
|
|
@ -100,102 +67,54 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
raise Exception(f"API request failed: {error_data}")
|
||||
return await response.json()
|
||||
|
||||
async def get_training_jobs(
|
||||
self,
|
||||
page: Optional[int] = 1,
|
||||
page_size: Optional[int] = 10,
|
||||
sort: Optional[Literal["created_at", "-created_at"]] = "created_at",
|
||||
) -> ListNvidiaPostTrainingJobs:
|
||||
"""Get all customization jobs.
|
||||
Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
|
||||
raise Exception(f"API request failed after {self.config.max_retries} retries")
|
||||
|
||||
Returns a ListNvidiaPostTrainingJobs object with the following fields:
|
||||
- data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects
|
||||
@staticmethod
|
||||
def _get_job_status(job: Dict[str, Any]) -> JobStatus:
|
||||
job_status = job.get("status", "unknown").lower()
|
||||
try:
|
||||
return JobStatus(job_status)
|
||||
except ValueError:
|
||||
return JobStatus.unknown
|
||||
|
||||
# TODO: fetch just the necessary job from remote
|
||||
async def get_post_training_job(self, job_id: str) -> PostTrainingJob:
|
||||
jobs = await self.list_post_training_jobs()
|
||||
for job in jobs.data:
|
||||
if job.id == job_id:
|
||||
return job
|
||||
raise ValueError(f"Job with ID {job_id} not found")
|
||||
|
||||
async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
"""Get all customization jobs.
|
||||
|
||||
ToDo: Support for schema input for filtering.
|
||||
"""
|
||||
params = {"page": page, "page_size": page_size, "sort": sort}
|
||||
|
||||
# TODO: don't hardcode pagination params
|
||||
params = {"page": 1, "page_size": 10, "sort": "created_at"}
|
||||
response = await self._make_request("GET", "/v1/customization/jobs", params=params)
|
||||
|
||||
jobs = []
|
||||
for job in response.get("data", []):
|
||||
job_id = job.pop("id")
|
||||
job_status = job.pop("status", "unknown").lower()
|
||||
mapped_status = STATUS_MAPPING.get(job_status, "unknown")
|
||||
for job_dict in response.get("data", []):
|
||||
# TODO: expose artifacts
|
||||
job = PostTrainingJob(**job_dict)
|
||||
job.update_status(self._get_job_status(job_dict))
|
||||
jobs.append(job)
|
||||
|
||||
# Convert string timestamps to datetime objects
|
||||
created_at = (
|
||||
datetime.fromisoformat(job.pop("created_at"))
|
||||
if "created_at" in job
|
||||
else datetime.now(tz=datetime.timezone.utc)
|
||||
)
|
||||
updated_at = (
|
||||
datetime.fromisoformat(job.pop("updated_at"))
|
||||
if "updated_at" in job
|
||||
else datetime.now(tz=datetime.timezone.utc)
|
||||
)
|
||||
return ListPostTrainingJobsResponse(data=jobs)
|
||||
|
||||
# Create NvidiaPostTrainingJob instance
|
||||
jobs.append(
|
||||
NvidiaPostTrainingJob(
|
||||
job_uuid=job_id,
|
||||
status=JobStatus(mapped_status),
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
**job,
|
||||
)
|
||||
)
|
||||
|
||||
return ListNvidiaPostTrainingJobs(data=jobs)
|
||||
|
||||
async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
|
||||
"""Get the status of a customization job.
|
||||
Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
|
||||
|
||||
Returns a NvidiaPostTrainingJob object with the following fields:
|
||||
- job_uuid: str - Unique identifier for the job
|
||||
- status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
|
||||
- created_at: datetime - The time when the job was created
|
||||
- updated_at: datetime - The last time the job status was updated
|
||||
|
||||
Additional fields that may be included:
|
||||
- steps_completed: Optional[int] - Number of training steps completed
|
||||
- epochs_completed: Optional[int] - Number of epochs completed
|
||||
- percentage_done: Optional[float] - Percentage of training completed (0-100)
|
||||
- best_epoch: Optional[int] - The epoch with the best performance
|
||||
- train_loss: Optional[float] - Training loss of the best checkpoint
|
||||
- val_loss: Optional[float] - Validation loss of the best checkpoint
|
||||
- metrics: Optional[Dict] - Additional training metrics
|
||||
- status_logs: Optional[List] - Detailed logs of status changes
|
||||
"""
|
||||
response = await self._make_request(
|
||||
"GET",
|
||||
f"/v1/customization/jobs/{job_uuid}/status",
|
||||
params={"job_id": job_uuid},
|
||||
)
|
||||
|
||||
api_status = response.pop("status").lower()
|
||||
mapped_status = STATUS_MAPPING.get(api_status, "unknown")
|
||||
|
||||
return NvidiaPostTrainingJobStatusResponse(
|
||||
status=JobStatus(mapped_status),
|
||||
job_uuid=job_uuid,
|
||||
started_at=datetime.fromisoformat(response.pop("created_at")),
|
||||
updated_at=datetime.fromisoformat(response.pop("updated_at")),
|
||||
**response,
|
||||
)
|
||||
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
async def update_post_training_job(self, job_id: str, status: JobStatus | None = None) -> PostTrainingJob:
|
||||
if status is None:
|
||||
raise ValueError("Status must be provided")
|
||||
if status not in {JobStatus.cancelled}:
|
||||
raise ValueError(f"Unsupported status: {status}")
|
||||
await self._make_request(
|
||||
method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
|
||||
method="POST", path=f"/v1/customization/jobs/{job_id}/cancel", params={"job_id": job_id}
|
||||
)
|
||||
return await self.get_post_training_job(job_id)
|
||||
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
|
||||
raise NotImplementedError("Job artifacts are not implemented yet")
|
||||
|
||||
async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
|
||||
raise NotImplementedError("Job artifacts are not implemented yet")
|
||||
async def delete_post_training_job(self, job_id: str) -> None:
|
||||
raise NotImplementedError("Delete job is not implemented yet")
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
|
|
@ -206,7 +125,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
model: str,
|
||||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[AlgorithmConfig] = None,
|
||||
) -> NvidiaPostTrainingJob:
|
||||
) -> PostTrainingJob:
|
||||
"""
|
||||
Fine-tunes a model on a dataset.
|
||||
Currently only supports Lora finetuning for standalone docker container.
|
||||
|
|
@ -409,15 +328,12 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
headers={"Accept": "application/json"},
|
||||
json=job_config,
|
||||
)
|
||||
|
||||
job_uuid = response["id"]
|
||||
response.pop("status")
|
||||
created_at = datetime.fromisoformat(response.pop("created_at"))
|
||||
updated_at = datetime.fromisoformat(response.pop("updated_at"))
|
||||
|
||||
return NvidiaPostTrainingJob(
|
||||
job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response
|
||||
)
|
||||
# TODO: expose artifacts
|
||||
job = PostTrainingJob(**response)
|
||||
job.update_status(JobStatus.running)
|
||||
return job
|
||||
|
||||
async def preference_optimize(
|
||||
self,
|
||||
|
|
@ -430,6 +346,3 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
) -> PostTrainingJob:
|
||||
"""Optimize a model based on preference data."""
|
||||
raise NotImplementedError("Preference optimization is not implemented yet")
|
||||
|
||||
async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse:
|
||||
raise NotImplementedError("Job logs are not implemented yet")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue