[2/n][torchtune integration] implement job management and return training artifacts (#593)

### Context
In this PR, we:
- Implement the post-training job management and training-artifact retrieval APIs (a usage sketch follows this list):
  - `get_training_jobs`
  - `get_training_job_status`
  - `get_training_job_artifacts`
- Delete `get_training_job_logstream`, since traces can be accessed directly from the UI with Jaeger:
https://llama-stack.readthedocs.io/en/latest/building_applications/telemetry.html#jaeger-to-visualize-traces
- Refactor the post-training and training type definitions to make them more intuitive.
- Rewrite the checkpointer so that its output is compatible with the llama-stack file system and its checkpoints can be recognized during inference.
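
As a sketch of how these APIs compose from a consumer's point of view, the loop below polls a job until it reaches a terminal state and then fetches its checkpoints. The `post_training` handle and the `JobStatus` member names used in the terminal-state check are assumptions; only the three method names and their `Optional` return types come from this PR.

```python
import asyncio

from llama_stack.apis.common.job_types import JobStatus


async def wait_for_artifacts(post_training, job_uuid: str):
    # List every job the provider knows about.
    jobs = await post_training.get_training_jobs()
    print("known jobs:", [job.job_uuid for job in jobs])

    while True:
        status = await post_training.get_training_job_status(job_uuid)
        if status is None:  # the new Optional return type means this can be None
            raise ValueError(f"no such job: {job_uuid}")
        # Terminal-state check; the member names are an assumption.
        if status.status in (JobStatus.completed, JobStatus.failed):
            break
        await asyncio.sleep(10)

    # Also Optional: None means no artifacts are available for this job.
    return await post_training.get_training_job_artifacts(job_uuid)
```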


### Test
Unit test:
```
pytest llama_stack/providers/tests/post_training/test_post_training.py -m "torchtune_post_training_huggingface_datasetio" -v -s --tb=short --disable-warnings
```

<img width="1506" alt="Screenshot 2024-12-10 at 4 06 17 PM"
src="https://github.com/user-attachments/assets/16225029-bdb7-48c4-9d13-e580cc769c0a">


E2E test with a client-side call:

<img width="888" alt="Screenshot 2024-12-10 at 4 09 44 PM"
src="https://github.com/user-attachments/assets/de375e4c-ef67-4dcc-a045-4037d9489191">

### Key changes

```diff
@@ -6,6 +6,7 @@
 from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Optional, Protocol, Union
+
 from llama_models.schema_utils import json_schema_type, webmethod
@@ -14,6 +15,7 @@ from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
```
```diff
@@ -64,6 +66,7 @@ class TrainingConfig(BaseModel):
 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
+    type: Literal["LoRA"] = "LoRA"
     lora_attn_modules: List[str]
     apply_lora_to_mlp: bool
     apply_lora_to_output: bool
@@ -75,12 +78,13 @@ class LoraFinetuningConfig(BaseModel):
 @json_schema_type
 class QATFinetuningConfig(BaseModel):
+    type: Literal["QAT"] = "QAT"
     quantizer_name: str
     group_size: int
 
 
 AlgorithmConfig = Annotated[
-    Union[LoraFinetuningConfig, LoraFinetuningConfig], Field(discriminator="type")
+    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
 ]
```
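
Note that the old annotation listed `LoraFinetuningConfig` twice, so `QATFinetuningConfig` could never be selected; with the `type` discriminators added above, Pydantic can dispatch on the `"type"` key. A minimal sketch of how the discriminated union resolves, assuming Pydantic v2 and abbreviating the models to the fields visible in this diff:

```python
from typing import List, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated


class LoraFinetuningConfig(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    lora_attn_modules: List[str]
    apply_lora_to_mlp: bool
    apply_lora_to_output: bool


class QATFinetuningConfig(BaseModel):
    type: Literal["QAT"] = "QAT"
    quantizer_name: str
    group_size: int


AlgorithmConfig = Annotated[
    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
]

# The "type" value selects which model validates the payload (values illustrative).
adapter = TypeAdapter(AlgorithmConfig)
cfg = adapter.validate_python({"type": "QAT", "quantizer_name": "int4", "group_size": 32})
assert isinstance(cfg, QATFinetuningConfig)
```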
```diff
@@ -92,14 +96,6 @@ class PostTrainingJobLogStream(BaseModel):
     log_lines: List[str]
 
 
-@json_schema_type
-class PostTrainingJobStatus(Enum):
-    running = "running"
-    completed = "completed"
-    failed = "failed"
-    scheduled = "scheduled"
-
-
 @json_schema_type
 class RLHFAlgorithm(Enum):
     dpo = "dpo"
@@ -144,7 +140,7 @@ class PostTrainingJobStatusResponse(BaseModel):
     """Status of a finetuning job."""
 
     job_uuid: str
-    status: PostTrainingJobStatus
+    status: JobStatus
 
     scheduled_at: Optional[datetime] = None
     started_at: Optional[datetime] = None
```
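
The provider-local status enum is gone; job state is now reported with the shared `JobStatus` from `llama_stack.apis.common.job_types`, and lookups return `Optional` values. A minimal sketch of the provider-side bookkeeping this enables (the class below is hypothetical, not the torchtune provider's actual code, and the `JobStatus.scheduled` member name is an assumption):

```python
from datetime import datetime
from typing import Dict, Optional

from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.post_training import PostTrainingJobStatusResponse


class InMemoryJobRegistry:
    """Hypothetical bookkeeping for a post-training provider."""

    def __init__(self) -> None:
        self._statuses: Dict[str, PostTrainingJobStatusResponse] = {}

    def schedule(self, job_uuid: str) -> None:
        self._statuses[job_uuid] = PostTrainingJobStatusResponse(
            job_uuid=job_uuid,
            status=JobStatus.scheduled,  # member name is an assumption
            scheduled_at=datetime.now(),
        )

    def get(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
        # Unknown jobs yield None, matching the new Optional[...] return type.
        return self._statuses.get(job_uuid)
```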
```diff
@@ -166,7 +162,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
 class PostTraining(Protocol):
-    @webmethod(route="/post-training/supervised-fine-tune")
+    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
         self,
         job_uuid: str,
@@ -181,7 +177,7 @@ class PostTraining(Protocol):
         algorithm_config: Optional[AlgorithmConfig] = None,
     ) -> PostTrainingJob: ...
 
-    @webmethod(route="/post-training/preference-optimize")
+    @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
         self,
         job_uuid: str,
@@ -192,24 +188,18 @@ class PostTraining(Protocol):
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
-    @webmethod(route="/post-training/jobs")
+    @webmethod(route="/post-training/jobs", method="GET")
     async def get_training_jobs(self) -> List[PostTrainingJob]: ...
 
-    # sends SSE stream of logs
-    @webmethod(route="/post-training/job/logs")
-    async def get_training_job_logstream(
-        self, job_uuid: str
-    ) -> PostTrainingJobLogStream: ...
-
-    @webmethod(route="/post-training/job/status")
+    @webmethod(route="/post-training/job/status", method="GET")
     async def get_training_job_status(
         self, job_uuid: str
-    ) -> PostTrainingJobStatusResponse: ...
+    ) -> Optional[PostTrainingJobStatusResponse]: ...
 
-    @webmethod(route="/post-training/job/cancel")
+    @webmethod(route="/post-training/job/cancel", method="POST")
     async def cancel_training_job(self, job_uuid: str) -> None: ...
 
-    @webmethod(route="/post-training/job/artifacts")
+    @webmethod(route="/post-training/job/artifacts", method="GET")
     async def get_training_job_artifacts(
         self, job_uuid: str
-    ) -> PostTrainingJobArtifactsResponse: ...
+    ) -> Optional[PostTrainingJobArtifactsResponse]: ...
```
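
Since the routes now declare their HTTP verbs explicitly, the job-management endpoints can be exercised as ordinary GET/POST requests. A hedged sketch with `httpx` (the server address and the query/body encoding of `job_uuid` are assumptions; only the paths and verbs come from this diff):

```python
import httpx

BASE_URL = "http://localhost:5000"  # assumed llama-stack server address

# List jobs (GET, per the new explicit method annotation).
jobs = httpx.get(f"{BASE_URL}/post-training/jobs").json()

# Fetch status and artifacts for one job; passing job_uuid as a query
# parameter is an assumption about the server's encoding.
params = {"job_uuid": "job-1234"}
status = httpx.get(f"{BASE_URL}/post-training/job/status", params=params).json()
artifacts = httpx.get(f"{BASE_URL}/post-training/job/artifacts", params=params).json()

# Cancellation is a POST.
httpx.post(f"{BASE_URL}/post-training/job/cancel", json={"job_uuid": "job-1234"})
```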