feat(api): define a more coherent jobs api across different flows

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-03-24 20:54:04 -04:00
parent 71ed47ea76
commit 0f50cfa561
15 changed files with 1864 additions and 1670 deletions

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -4,9 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol, runtime_checkable
from typing import List, Literal, Optional, Protocol, runtime_checkable
from llama_stack.apis.common.job_types import Job
from pydantic import BaseModel
from llama_stack.apis.common.job_types import BaseJob
from llama_stack.apis.inference import (
InterleavedContent,
LogProbConfig,
@@ -20,6 +22,14 @@ from llama_stack.apis.inference import (
from llama_stack.schema_utils import webmethod
class BatchInferenceJob(BaseJob, BaseModel):
type: Literal["batch_inference"] = "batch_inference"
class ListBatchInferenceJobsResponse(BaseModel):
data: list[BatchInferenceJob]
@runtime_checkable
class BatchInference(Protocol):
"""Batch inference API for generating completions and chat completions.
@@ -38,7 +48,7 @@ class BatchInference(Protocol):
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> Job: ...
) -> BatchInferenceJob: ...
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def chat_completion(
@@ -52,4 +62,4 @@ class BatchInference(Protocol):
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> Job: ...
) -> BatchInferenceJob: ...
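
Illustrative sketch, not part of the diff: the caller-visible change in this file is only the return type. A provider-side helper might now wrap its result like this (finish_batch and batch_id are assumptions made for the example; BatchInferenceJob is the class defined above):

from llama_stack.apis.common.job_types import JobStatus

async def finish_batch(batch_id: str) -> BatchInferenceJob:
    # hypothetical helper showing the shape of the new return value
    job = BatchInferenceJob(id=batch_id)
    job.update_status(JobStatus.running)
    # ... run the actual batch of completions here ...
    job.update_status(JobStatus.completed)
    return job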


@@ -3,22 +3,68 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from datetime import datetime, timezone
from enum import Enum, unique
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field, computed_field
from llama_stack.schema_utils import json_schema_type
@unique
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"
failed = "failed"
unknown = "unknown"
new = "new"
scheduled = "scheduled"
running = "running"
paused = "paused"
resuming = "resuming"
cancelled = "cancelled"
failed = "failed"
completed = "completed"
@json_schema_type
class Job(BaseModel):
job_id: str
class JobStatusDetails(BaseModel):
status: JobStatus
message: str | None = None
timestamp: datetime
@json_schema_type
class JobArtifact(BaseModel):
name: str
# TODO: should it be a Literal / Enum?
type: str
# Any additional metadata the artifact may have
# TODO: is Any the right type here? What happens when the caller passes a value without a __repr__?
metadata: dict[str, Any] | None = None
# TODO: enforce type to be a URI
uri: str | None = None # points to /files
def _get_job_status_details(status: JobStatus) -> JobStatusDetails:
return JobStatusDetails(status=status, timestamp=datetime.now(timezone.utc))
class BaseJob(BaseModel):
id: str # TODO: make it a UUID?
artifacts: list[JobArtifact] = Field(default_factory=list)
events: list[JobStatusDetails] = Field(default_factory=lambda: [_get_job_status_details(JobStatus.new)])
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if "type" not in cls.__annotations__:
raise ValueError(f"Class {cls.__name__} must have a type field")
@computed_field
def status(self) -> JobStatus:
return self.events[-1].status
def update_status(self, value: JobStatus):
self.events.append(_get_job_status_details(value))
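
Illustrative sketch, not part of the diff: how the shared job model above behaves in practice, assuming the classes are imported from llama_stack.apis.common.job_types (ExampleJob is invented for the example):

from typing import Literal

from llama_stack.apis.common.job_types import BaseJob, JobArtifact, JobStatus

class ExampleJob(BaseJob):
    # __init_subclass__ above rejects subclasses that do not annotate a type field
    type: Literal["example"] = "example"

job = ExampleJob(id="job-42")
assert job.status == JobStatus.new          # computed from the initial "new" event
job.update_status(JobStatus.scheduled)
job.update_status(JobStatus.running)
job.artifacts.append(JobArtifact(name="log", type="log", uri="file:///tmp/job-42.log"))
job.update_status(JobStatus.completed)

assert job.status == JobStatus.completed    # always the status of the most recent event
assert [e.status for e in job.events] == [
    JobStatus.new,
    JobStatus.scheduled,
    JobStatus.running,
    JobStatus.completed,
]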


@@ -4,15 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from typing import Dict, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.common.job_types import BaseJob
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@@ -47,6 +46,14 @@ EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discrimin
register_schema(EvalCandidate, name="EvalCandidate")
class EvaluateJob(BaseJob, BaseModel):
type: Literal["eval"] = "eval"
class ListEvaluateJobsResponse(BaseModel):
data: list[EvaluateJob]
@json_schema_type
class BenchmarkConfig(BaseModel):
"""A benchmark configuration for evaluation.
@@ -68,76 +75,30 @@ class BenchmarkConfig(BaseModel):
# we could optionally add any specific dataset config here
@json_schema_type
class EvaluateResponse(BaseModel):
"""The response from an evaluation.
:param generations: The generations from the evaluation.
:param scores: The scores from the evaluation.
"""
generations: List[Dict[str, Any]]
# each key in the dict is a scoring function name
scores: Dict[str, ScoringResult]
class Eval(Protocol):
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
async def run_eval(
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluate", method="POST")
async def evaluate(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation on a benchmark.
) -> EvaluateJob: ...
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation.
"""
# CRUD operations on running jobs
@webmethod(route="/evaluate/jobs/{job_id:path}", method="GET")
async def get_evaluate_job(self, job_id: str) -> EvaluateJob: ...
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark.
@webmethod(route="/evaluate/jobs", method="GET")
async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse: ...
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores
"""
@webmethod(route="/evaluate/jobs/{job_id:path}", method="POST")
async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob: ...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
@webmethod(route="/evaluate/job/{job_id:path}", method="DELETE")
async def delete_evaluate_job(self, job_id: str) -> None: ...
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of.
:return: The status of the evaluation job.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to cancel.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of.
:return: The result of the job.
"""
# Note: pause/resume/cancel are achieved as follows:
# - POST with status=paused
# - POST with status=resuming
# - POST with status=cancelled
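
Illustrative sketch, not part of the diff: the intended lifecycle under the new routes, with eval_api standing in for any object implementing the Eval protocol. Cancellation follows the status-update convention described in the note above:

from llama_stack.apis.common.job_types import JobStatus

async def run_and_cancel(eval_api, benchmark_id, benchmark_config):
    # POST /eval/benchmarks/{benchmark_id}/evaluate -> EvaluateJob
    job = await eval_api.evaluate(benchmark_id=benchmark_id, benchmark_config=benchmark_config)

    # GET /evaluate/jobs/{job_id} -> refreshed job record
    job = await eval_api.get_evaluate_job(job.id)

    if job.status not in (JobStatus.completed, JobStatus.failed):
        # POST /evaluate/jobs/{job_id} with status=cancelled
        job.update_status(JobStatus.cancelled)
        job = await eval_api.update_evaluate_job(job)

    # GET /evaluate/jobs -> ListEvaluateJobsResponse whose data field lists all jobs
    all_jobs = await eval_api.list_evaluate_jobs()
    return job, all_jobs.data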


@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
@@ -12,8 +11,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.apis.common.job_types import BaseJob
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@@ -92,14 +90,6 @@ AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Fi
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@json_schema_type
class PostTrainingJobLogStream(BaseModel):
"""Stream of logs from a finetuning job."""
job_uuid: str
log_lines: List[str]
@json_schema_type
class RLHFAlgorithm(Enum):
dpo = "dpo"
@@ -135,41 +125,17 @@ class PostTrainingRLHFRequest(BaseModel):
logger_config: Dict[str, Any]
class PostTrainingJob(BaseModel):
job_uuid: str
@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job."""
job_uuid: str
status: JobStatus
scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
resources_allocated: Optional[Dict[str, Any]] = None
checkpoints: List[Checkpoint] = Field(default_factory=list)
class PostTrainingJob(BaseJob, BaseModel):
type: Literal["post-training"] = "post-training"
class ListPostTrainingJobsResponse(BaseModel):
data: List[PostTrainingJob]
@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job."""
job_uuid: str
checkpoints: List[Checkpoint] = Field(default_factory=list)
# TODO(ashwin): metrics, evals
data: list[PostTrainingJob]
class PostTraining(Protocol):
# This is how you create a new job - POST against the root endpoint
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
async def supervised_fine_tune(
self,
@@ -196,14 +162,20 @@ class PostTraining(Protocol):
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
# CRUD operations on running jobs
@webmethod(route="/post-training/jobs/{job_id:path}", method="GET")
async def get_post_training_job(self, job_id: str) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse: ...
@webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/jobs/{job_id:path}", method="POST")
async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob: ...
@webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/{job_id:path}", method="DELETE")
async def delete_post_training_job(self, job_id: str) -> None: ...
@webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
# Note: pause/resume/cancel are achieved as follows:
# - POST with status=paused
# - POST with status=resuming
# - POST with status=cancelled
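
Illustrative sketch, not part of the diff: the same lifecycle applied to post-training jobs (the synthetic-data-generation flow below follows the same shape). post_training stands in for an implementation of the protocol above:

from llama_stack.apis.common.job_types import JobStatus

async def monitor_and_cancel(post_training, job_id: str):
    # GET /post-training/jobs -> ListPostTrainingJobsResponse
    jobs = await post_training.list_post_training_jobs()
    known_ids = [j.id for j in jobs.data]

    # GET /post-training/jobs/{job_id}
    job = await post_training.get_post_training_job(job_id)

    if job.status in (JobStatus.scheduled, JobStatus.running):
        # POST /post-training/jobs/{job_id} with status=cancelled, per the note above
        job.update_status(JobStatus.cancelled)
        job = await post_training.update_post_training_job(job)

    # DELETE /post-training/job/{job_id} once the record is no longer needed
    await post_training.delete_post_training_job(job_id)
    return job, known_ids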


@@ -5,10 +5,11 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from typing import List, Literal, Optional, Protocol
from pydantic import BaseModel
from llama_stack.apis.common.job_types import BaseJob
from llama_stack.apis.inference import Message
from llama_stack.schema_utils import json_schema_type, webmethod
@@ -34,11 +35,13 @@ class SyntheticDataGenerationRequest(BaseModel):
@json_schema_type
class SyntheticDataGenerationResponse(BaseModel):
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
class SyntheticDataGenerationJob(BaseJob, BaseModel):
type: Literal["synthetic-data-generation"] = "synthetic-data-generation"
synthetic_data: List[Dict[str, Any]]
statistics: Optional[Dict[str, Any]] = None
@json_schema_type
class ListSyntheticDataGenerationJobsResponse(BaseModel):
items: list[SyntheticDataGenerationJob]
class SyntheticDataGeneration(Protocol):
@@ -48,4 +51,24 @@ class SyntheticDataGeneration(Protocol):
dialogs: List[Message],
filtering_function: FilteringFunction = FilteringFunction.none,
model: Optional[str] = None,
) -> Union[SyntheticDataGenerationResponse]: ...
) -> SyntheticDataGenerationJob: ...
# CRUD operations on running jobs
@webmethod(route="/synthetic-data-generation/jobs/{job_id:path}", method="GET")
async def get_synthetic_data_generation_job(self, job_id: str) -> SyntheticDataGenerationJob: ...
@webmethod(route="/synthetic-data-generation/jobs", method="GET")
async def list_synthetic_data_generation_jobs(self) -> ListSyntheticDataGenerationJobsResponse: ...
@webmethod(route="/synthetic-data-generation/jobs/{job_id:path}", method="POST")
async def update_synthetic_data_generation_job(
self, job: SyntheticDataGenerationJob
) -> SyntheticDataGenerationJob: ...
@webmethod(route="/synthetic-data-generation/job/{job_id:path}", method="DELETE")
async def delete_synthetic_data_generation_job(self, job_id: str) -> None: ...
# Note: pause/resume/cancel are achieved as follows:
# - POST with status=paused
# - POST with status=resuming
# - POST with status=cancelled
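
A side note, not part of the commit: because every job model pins a distinct type Literal, payloads from the different flows can be told apart on the wire. A minimal sketch under that assumption (the commit itself defines no such union; the imports of the four job classes from their API modules are assumed):

from typing import Annotated, Union

from pydantic import Field, TypeAdapter

# hypothetical discriminated union over the job models introduced in this commit
AnyJob = Annotated[
    Union[BatchInferenceJob, EvaluateJob, PostTrainingJob, SyntheticDataGenerationJob],
    Field(discriminator="type"),
]

payload = {"id": "job-7", "type": "synthetic-data-generation"}
job = TypeAdapter(AnyJob).validate_python(payload)
assert isinstance(job, SyntheticDataGenerationJob)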


@@ -16,7 +16,7 @@ from llama_stack.apis.common.content_types import (
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateJob, ListEvaluateJobsResponse
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
@@ -779,61 +779,32 @@ class EvalRouter(Eval):
logger.debug("EvalRouter.shutdown")
pass
async def run_eval(
async def evaluate(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
) -> EvaluateJob:
logger.debug(f"EvalRouter.evaluate: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
async def get_evaluate_job(self, job_id: str) -> EvaluateJob:
logger.debug(f"EvalRouter.get_evaluate_job: {job_id}")
return await self.routing_table.get_provider_impl("eval").get_evaluate_job(job_id)
async def job_status(
self,
benchmark_id: str,
job_id: str,
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse:
logger.debug("EvalRouter.list_evaluate_jobs")
return await self.routing_table.get_provider_impl("eval").list_evaluate_jobs()
async def job_cancel(
self,
benchmark_id: str,
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
)
async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob:
logger.debug(f"EvalRouter.update_evaluate_job: {job.id}")
return await self.routing_table.get_provider_impl("eval").update_evaluate_job(job)
async def job_result(
self,
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
)
async def delete_evaluate_job(self, job_id: str) -> None:
logger.debug(f"EvalRouter.delete_evaluate_job: {job_id}")
return await self.routing_table.get_provider_impl("eval").delete_evaluate_job(job_id)
class ToolRuntimeRouter(ToolRuntime):


@@ -20,9 +20,10 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
)
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.schema_utils import webmethod
from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .....apis.common.job_types import JobArtifact, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateJob, ListEvaluateJobsResponse
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "benchmarks:"
@@ -75,11 +76,11 @@
)
self.benchmarks[task_def.identifier] = task_def
async def run_eval(
async def evaluate(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
) -> EvaluateJob:
task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id
scoring_functions = task_def.scoring_functions
@@ -91,18 +92,35 @@
dataset_id=dataset_id,
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
)
res = await self.evaluate_rows(
generations, scoring_results = await self._evaluate_rows(
benchmark_id=benchmark_id,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
artifacts = [
JobArtifact(
type="generation",
name=f"generation-{i}",
metadata=generation,
)
for i, generation in enumerate(generations)
] + [
JobArtifact(
type="scoring_results",
name="scoring_results",
metadata=scoring_results,
)
]
# TODO: currently needs to wait for generation before returning
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs))
self.jobs[job_id] = res
return Job(job_id=job_id, status=JobStatus.completed)
job = EvaluateJob(id=job_id, artifacts=artifacts)
job.update_status(JobStatus.completed)
self.jobs[job_id] = job
return job
async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@@ -182,13 +200,13 @@
return generations
async def evaluate_rows(
async def _evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
candidate = benchmark_config.eval_candidate
if candidate.type == "agent":
generations = await self._run_agent_generation(input_rows, benchmark_config)
@@ -214,21 +232,26 @@
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
)
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
if job_id in self.jobs:
return Job(job_id=job_id, status=JobStatus.completed)
raise ValueError(f"Job {job_id} not found")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")
return generations, score_response.results
# CRUD operations on running jobs
@webmethod(route="/evaluate/jobs/{job_id:path}", method="GET")
async def get_evaluate_job(self, job_id: str) -> EvaluateJob:
return self.jobs[job_id]
@webmethod(route="/evaluate/jobs", method="GET")
async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse:
return ListEvaluateJobsResponse(data=list(self.jobs.values()))
@webmethod(route="/evaluate/jobs/{job_id:path}", method="POST")
async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob:
raise NotImplementedError
@webmethod(route="/evaluate/job/{job_id:path}", method="DELETE")
async def delete_evaluate_job(self, job_id: str) -> None:
raise NotImplementedError
# Note: pause/resume/cancel are achieved as follows:
# - POST with status=paused
# - POST with status=resuming
# - POST with status=cancelled
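
For context, not part of the diff: with EvaluateResponse gone, per-row generations and scores travel as job artifacts. A sketch of how a caller might reassemble them, relying only on the artifact names and types emitted by evaluate() above:

from typing import Any

def extract_eval_results(job) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    # generations are stored one artifact per row, named "generation-<i>"
    generation_artifacts = sorted(
        (a for a in job.artifacts if a.type == "generation"),
        key=lambda a: int(a.name.split("-")[-1]),
    )
    generations = [a.metadata for a in generation_artifacts]

    # a single "scoring_results" artifact carries the scores keyed by scoring function
    scoring = next((a for a in job.artifacts if a.type == "scoring_results"), None)
    scores = scoring.metadata if scoring else {}
    return generations, scores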


@@ -10,14 +10,10 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
LoraFinetuningConfig,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.config import (
@@ -54,15 +50,6 @@ class TorchtunePostTrainingImpl:
async def shutdown(self) -> None:
await self._scheduler.shutdown()
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
return JobArtifact(
@@ -98,14 +85,14 @@
self.datasetio_api,
self.datasets_api,
)
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_artifact_collected_cb(checkpoint)
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("Lora finetuning completed")
@@ -113,6 +100,8 @@
raise NotImplementedError()
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
# TODO: initialize with more data from scheduler
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
@@ -125,56 +114,31 @@
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
# TODO: should these be under post-training/supervised-fine-tune/?
# CRUD operations on running jobs
@webmethod(route="/post-training/jobs/{job_id:path}", method="GET")
async def get_post_training_job(self, job_id: str) -> PostTrainingJob:
# TODO: implement
raise NotImplementedError
@webmethod(route="/post-training/jobs", method="GET")
async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse:
# TODO: populate other data
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@webmethod(route="/post-training/jobs/{job_id:path}", method="POST")
async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob:
# TODO: implement
raise NotImplementedError
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@webmethod(route="/post-training/job/{job_id:path}", method="DELETE")
async def delete_post_training_job(self, job_id: str) -> None:
# TODO: implement
raise NotImplementedError
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
# Note: pause/resume/cancel are achieved as follows:
# - POST with status=paused
# - POST with status=resuming
# - POST with status=cancelled
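
For context, not part of the diff: the helpers kept above recover structured results from a job purely by artifact type. A sketch of the consumer side, assuming TrainingArtifactType is in scope as in this module:

def summarize_training_job(job):
    # checkpoints were emitted by the recipe as artifacts carrying a URI
    checkpoints = [a for a in job.artifacts if a.type == TrainingArtifactType.CHECKPOINT.value]
    # a single resources-stats artifact may have been collected at the end of training
    stats = [a.metadata for a in job.artifacts if a.type == TrainingArtifactType.RESOURCES_STATS.value]
    return {
        "checkpoint_uris": [c.uri for c in checkpoints],
        "resources_allocated": stats[0] if stats else None,
    }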


@@ -33,11 +33,11 @@ from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
from llama_stack.apis.common.job_types import JobArtifact
from llama_stack.apis.common.training_types import PostTrainingMetric
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
EfficiencyConfig,
LoraFinetuningConfig,
@@ -457,7 +457,7 @@ class LoraFinetuningSingleDevice:
return loss
async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
async def train(self) -> Tuple[Dict[str, Any], List[JobArtifact]]:
"""
The core training loop.
"""
@@ -543,13 +543,18 @@
self.epochs_run += 1
log.info("Starting checkpoint save...")
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(timezone.utc),
epoch=curr_epoch,
post_training_job_id=self.job_uuid,
path=checkpoint_path,
checkpoint = JobArtifact(
name=f"{self.model_id}-sft-{curr_epoch}",
type="checkpoint",
# TODO: this should be exposed via /files instead
uri=checkpoint_path,
)
metadata = {
"created_at": datetime.now(timezone.utc),
"epoch": curr_epoch,
}
if self.training_config.data_config.validation_dataset_id:
validation_loss, perplexity = await self.validation()
training_metrics = PostTrainingMetric(
@@ -558,7 +563,9 @@
validation_loss=validation_loss,
perplexity=perplexity,
)
checkpoint.training_metrics = training_metrics
metadata["training_metrics"] = training_metrics
checkpoint.metadata = metadata
checkpoints.append(checkpoint)
# clean up the memory after training finishes


@@ -4,19 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, Optional
import aiohttp
from pydantic import BaseModel, ConfigDict
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.post_training import (
AlgorithmConfig,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
@@ -25,36 +22,6 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
from .models import _MODEL_ENTRIES
# Map API status to JobStatus enum
STATUS_MAPPING = {
"running": "in_progress",
"completed": "completed",
"failed": "failed",
"cancelled": "cancelled",
"pending": "scheduled",
}
class NvidiaPostTrainingJob(PostTrainingJob):
"""Parse the response from the Customizer API.
Inherits job_uuid from PostTrainingJob.
Adds status, created_at, updated_at parameters.
Passes through all other parameters from the data field in the response.
"""
model_config = ConfigDict(extra="allow")
status: JobStatus
created_at: datetime
updated_at: datetime
class ListNvidiaPostTrainingJobs(BaseModel):
data: List[NvidiaPostTrainingJob]
class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
model_config = ConfigDict(extra="allow")
class NvidiaPostTrainingAdapter(ModelRegistryHelper):
def __init__(self, config: NvidiaPostTrainingConfig):
@@ -100,102 +67,54 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
raise Exception(f"API request failed: {error_data}")
return await response.json()
async def get_training_jobs(
self,
page: Optional[int] = 1,
page_size: Optional[int] = 10,
sort: Optional[Literal["created_at", "-created_at"]] = "created_at",
) -> ListNvidiaPostTrainingJobs:
"""Get all customization jobs.
Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
raise Exception(f"API request failed after {self.config.max_retries} retries")
Returns a ListNvidiaPostTrainingJobs object with the following fields:
- data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects
@staticmethod
def _get_job_status(job: Dict[str, Any]) -> JobStatus:
job_status = job.get("status", "").lower()
# the remote API reports "pending" for queued jobs; map it to the closest member
if job_status == "pending":
return JobStatus.scheduled
try:
return JobStatus(job_status)
except ValueError:
# JobStatus no longer defines "unknown", so report unrecognized statuses as failed
return JobStatus.failed
# TODO: fetch just the necessary job from remote
async def get_post_training_job(self, job_id: str) -> PostTrainingJob:
jobs = await self.list_post_training_jobs()
for job in jobs.data:
if job.id == job_id:
return job
raise ValueError(f"Job with ID {job_id} not found")
async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse:
"""Get all customization jobs.
ToDo: Support for schema input for filtering.
"""
params = {"page": page, "page_size": page_size, "sort": sort}
# TODO: don't hardcode pagination params
params = {"page": 1, "page_size": 10, "sort": "created_at"}
response = await self._make_request("GET", "/v1/customization/jobs", params=params)
jobs = []
for job in response.get("data", []):
job_id = job.pop("id")
job_status = job.pop("status", "unknown").lower()
mapped_status = STATUS_MAPPING.get(job_status, "unknown")
for job_dict in response.get("data", []):
# TODO: expose artifacts
job = PostTrainingJob(**job_dict)
job.update_status(self._get_job_status(job_dict))
jobs.append(job)
# Convert string timestamps to datetime objects
created_at = (
datetime.fromisoformat(job.pop("created_at"))
if "created_at" in job
else datetime.now(tz=datetime.timezone.utc)
)
updated_at = (
datetime.fromisoformat(job.pop("updated_at"))
if "updated_at" in job
else datetime.now(tz=datetime.timezone.utc)
)
return ListPostTrainingJobsResponse(data=jobs)
# Create NvidiaPostTrainingJob instance
jobs.append(
NvidiaPostTrainingJob(
job_uuid=job_id,
status=JobStatus(mapped_status),
created_at=created_at,
updated_at=updated_at,
**job,
)
)
return ListNvidiaPostTrainingJobs(data=jobs)
async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
"""Get the status of a customization job.
Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
Returns a NvidiaPostTrainingJob object with the following fields:
- job_uuid: str - Unique identifier for the job
- status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
- created_at: datetime - The time when the job was created
- updated_at: datetime - The last time the job status was updated
Additional fields that may be included:
- steps_completed: Optional[int] - Number of training steps completed
- epochs_completed: Optional[int] - Number of epochs completed
- percentage_done: Optional[float] - Percentage of training completed (0-100)
- best_epoch: Optional[int] - The epoch with the best performance
- train_loss: Optional[float] - Training loss of the best checkpoint
- val_loss: Optional[float] - Validation loss of the best checkpoint
- metrics: Optional[Dict] - Additional training metrics
- status_logs: Optional[List] - Detailed logs of status changes
"""
response = await self._make_request(
"GET",
f"/v1/customization/jobs/{job_uuid}/status",
params={"job_id": job_uuid},
)
api_status = response.pop("status").lower()
mapped_status = STATUS_MAPPING.get(api_status, "unknown")
return NvidiaPostTrainingJobStatusResponse(
status=JobStatus(mapped_status),
job_uuid=job_uuid,
started_at=datetime.fromisoformat(response.pop("created_at")),
updated_at=datetime.fromisoformat(response.pop("updated_at")),
**response,
)
async def cancel_training_job(self, job_uuid: str) -> None:
async def update_post_training_job(self, job_id: str, status: JobStatus | None = None) -> PostTrainingJob:
if status is None:
raise ValueError("Status must be provided")
if status not in {JobStatus.cancelled}:
raise ValueError(f"Unsupported status: {status}")
await self._make_request(
method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
method="POST", path=f"/v1/customization/jobs/{job_id}/cancel", params={"job_id": job_id}
)
return await self.get_post_training_job(job_id)
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def delete_post_training_job(self, job_id: str) -> None:
raise NotImplementedError("Delete job is not implemented yet")
async def supervised_fine_tune(
self,
@@ -206,7 +125,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig] = None,
) -> NvidiaPostTrainingJob:
) -> PostTrainingJob:
"""
Fine-tunes a model on a dataset.
Currently only supports LoRA finetuning for a standalone docker container.
@@ -409,15 +328,12 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
headers={"Accept": "application/json"},
json=job_config,
)
job_uuid = response["id"]
response.pop("status")
created_at = datetime.fromisoformat(response.pop("created_at"))
updated_at = datetime.fromisoformat(response.pop("updated_at"))
return NvidiaPostTrainingJob(
job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response
)
# TODO: expose artifacts
job = PostTrainingJob(**response)
job.update_status(JobStatus.running)
return job
async def preference_optimize(
self,
@@ -430,6 +346,3 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
) -> PostTrainingJob:
"""Optimize a model based on preference data."""
raise NotImplementedError("Preference optimization is not implemented yet")
async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse:
raise NotImplementedError("Job logs are not implemented yet")


@@ -562,6 +562,15 @@ else:
return typing.get_type_hints(typ)
def get_computed_fields(typ: type) -> dict[str, type]:
"Returns all computed fields of a class."
pydantic_decorators = getattr(typ, "__pydantic_decorators__", None)
if not pydantic_decorators:
return {}
computed_fields = pydantic_decorators.computed_fields
return {field_name: decorator.info.return_type for field_name, decorator in computed_fields.items()}
def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
"Returns all properties of a class."
@@ -569,7 +578,8 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
return ((field.name, field.type) for field in dataclasses.fields(typ))
else:
resolved_hints = get_resolved_hints(typ)
return resolved_hints.items()
computed_fields = get_computed_fields(typ)
return (resolved_hints | computed_fields).items()
def get_class_property(typ: type, name: str) -> Optional[type | str]:
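
Background for the change above, not part of the diff: pydantic v2 records computed_field properties under __pydantic_decorators__.computed_fields, which is exactly what get_computed_fields reads, so properties such as BaseJob.status now show up in generated schemas. A standalone sketch of that introspection:

from pydantic import BaseModel, computed_field

class Person(BaseModel):
    first: str
    last: str

    @computed_field
    @property
    def full_name(self) -> str:
        return f"{self.first} {self.last}"

computed = Person.__pydantic_decorators__.computed_fields
assert set(computed) == {"full_name"}
print(computed["full_name"].info.return_type)  # <class 'str'>, merged into get_class_properties()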


@@ -8,14 +8,12 @@ from typing import List
import pytest
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
LoraFinetuningConfig,
OptimizerConfig,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
@@ -84,7 +82,6 @@ class TestPostTraining:
async def test_get_training_job_status(self, post_training_stack):
post_training_impl = post_training_stack
job_status = await post_training_impl.get_training_job_status("1234")
assert isinstance(job_status, PostTrainingJobStatusResponse)
assert job_status.job_uuid == "1234"
assert job_status.status == JobStatus.completed
assert isinstance(job_status.checkpoints[0], Checkpoint)
@@ -93,7 +90,6 @@
async def test_get_training_job_artifacts(self, post_training_stack):
post_training_impl = post_training_stack
job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
assert job_artifacts.job_uuid == "1234"
assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"


@@ -17,12 +17,14 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfigOptimizerConfig,
)
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.post_training import (
ListPostTrainingJobsResponse,
PostTrainingJob,
)
from llama_stack.providers.remote.post_training.nvidia.post_training import (
ListNvidiaPostTrainingJobs,
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
NvidiaPostTrainingJob,
NvidiaPostTrainingJobStatusResponse,
)
@@ -49,21 +51,25 @@ class TestNvidiaPostTraining(unittest.TestCase):
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
call_args = mock_call.call_args
found = False
for call_args in mock_call.call_args_list:
if expected_method and expected_path:
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
if call_args[0] == (expected_method, expected_path):
found = True
else:
if call_args[1]["method"] == expected_method and call_args[1]["path"] == expected_path:
found = True
if expected_method and expected_path:
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
assert call_args[0] == (expected_method, expected_path)
else:
assert call_args[1]["method"] == expected_method
assert call_args[1]["path"] == expected_path
if expected_params:
if call_args[1]["params"] == expected_params:
found = True
if expected_params:
assert call_args[1]["params"] == expected_params
if expected_json:
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
if expected_json:
for key, value in expected_json.items():
if call_args[1]["json"][key] == value:
found = True
assert found
def test_supervised_fine_tune(self):
"""Test the supervised fine-tuning API call."""
@@ -151,9 +157,8 @@
)
)
# check the output is a PostTrainingJob
assert isinstance(training_job, NvidiaPostTrainingJob)
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
assert isinstance(training_job, PostTrainingJob)
assert training_job.id == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.assert_called_once()
self._assert_request(
@@ -199,38 +204,7 @@
)
)
def test_get_training_job_status(self):
self.mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": "completed",
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
}
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert status.status.value == "completed"
assert status.steps_completed == 1210
assert status.epochs_completed == 2
assert status.percentage_done == 100.0
assert status.best_epoch == 2
assert status.train_loss == 1.718016266822815
assert status.val_loss == 1.8661999702453613
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
)
def test_get_training_jobs(self):
def test_list_post_training_jobs(self):
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.return_value = {
"data": [
@@ -258,12 +232,12 @@
]
}
jobs = self.run_async(self.adapter.get_training_jobs())
jobs = self.run_async(self.adapter.list_post_training_jobs())
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
assert isinstance(jobs, ListPostTrainingJobsResponse)
assert len(jobs.data) == 1
job = jobs.data[0]
assert job.job_uuid == job_id
assert job.id == job_id
assert job.status.value == "completed"
self.mock_make_request.assert_called_once()
@@ -275,14 +249,36 @@
)
def test_cancel_training_job(self):
self.mock_make_request.return_value = {} # Empty response for successful cancellation
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.return_value = {
"data": [
{
"id": job_id,
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:21:19.852832",
"config": {
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
},
"dataset": {"name": "default/sample-basic-test"},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "completed",
"project": "default",
}
]
}
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
result = self.run_async(self.adapter.update_post_training_job(job_id=job_id, status=JobStatus.cancelled))
assert result.id == job_id
assert result is None
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",