From 961528eae168ba875e1e7ffb173a3c5db6702a7c Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 12 Mar 2025 01:24:00 -0700 Subject: [PATCH] post training job api --- .../apis/post_training/post_training.py | 47 +++++-------------- 1 file changed, 11 insertions(+), 36 deletions(-) diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index ed15c6de4..079093a5d 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -12,7 +12,7 @@ from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_stack.apis.common.content_types import URL -from llama_stack.apis.common.job_types import JobStatus +from llama_stack.apis.common.job_types import JobCommonFields, JobStatus from llama_stack.apis.common.training_types import Checkpoint from llama_stack.schema_utils import json_schema_type, register_schema, webmethod @@ -89,7 +89,9 @@ class QATFinetuningConfig(BaseModel): AlgorithmConfig = register_schema( - Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")], + Annotated[ + Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type") + ], name="AlgorithmConfig", ) @@ -137,23 +139,9 @@ class PostTrainingRLHFRequest(BaseModel): logger_config: Dict[str, Any] -class PostTrainingJob(BaseModel): - job_uuid: str - - @json_schema_type -class PostTrainingJobStatusResponse(BaseModel): - """Status of a finetuning job.""" - - job_uuid: str - status: JobStatus - - scheduled_at: Optional[datetime] = None - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - +class PostTrainingJob(JobCommonFields): resources_allocated: Optional[Dict[str, Any]] = None - checkpoints: List[Checkpoint] = Field(default_factory=list) @@ -161,18 +149,8 @@ class ListPostTrainingJobsResponse(BaseModel): data: List[PostTrainingJob] -@json_schema_type -class PostTrainingJobArtifactsResponse(BaseModel): - """Artifacts of a finetuning job.""" - - job_uuid: str - checkpoints: List[Checkpoint] = Field(default_factory=list) - - # TODO(ashwin): metrics, evals - - class PostTraining(Protocol): - @webmethod(route="/post-training/supervised-fine-tune", method="POST") + @webmethod(route="/post-training/supervised-fine-tune/jobs", method="POST") async def supervised_fine_tune( self, job_uuid: str, @@ -187,7 +165,7 @@ class PostTraining(Protocol): algorithm_config: Optional[AlgorithmConfig] = None, ) -> PostTrainingJob: ... - @webmethod(route="/post-training/preference-optimize", method="POST") + @webmethod(route="/post-training/preference-optimize/jobs", method="POST") async def preference_optimize( self, job_uuid: str, @@ -201,11 +179,8 @@ class PostTraining(Protocol): @webmethod(route="/post-training/jobs", method="GET") async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ... - @webmethod(route="/post-training/job/status", method="GET") - async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ... + @webmethod(route="/post-training/jobs/{job_id}", method="GET") + async def get_training_job(self, job_id: str) -> Optional[PostTrainingJob]: ... - @webmethod(route="/post-training/job/cancel", method="POST") - async def cancel_training_job(self, job_uuid: str) -> None: ... - - @webmethod(route="/post-training/job/artifacts", method="GET") - async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ... + @webmethod(route="/post-training/jobs/{job_id}", method="DELETE") + async def cancel_training_job(self, job_id: str) -> None: ...