forked from phoenix-oss/llama-stack-mirror
[2/n][torchtune integration] implement job management and return training artifacts (#593)
### Context In this PR, we - Implement the post training job management and get training artifacts apis - get_training_jobs - get_training_job_status - get_training_job_artifacts - get_training_job_logstream is deleted since the trace can be directly accessed by UI with Jaeger https://llama-stack.readthedocs.io/en/latest/building_applications/telemetry.html#jaeger-to-visualize-traces - Refactor the post training and training types definition to make them more intuitive. - Rewrite the checkpointer to make it compatible with llama-stack file system and can be recognized during inference ### Test Unit test `pytest llama_stack/providers/tests/post_training/test_post_training.py -m "torchtune_post_training_huggingface_datasetio" -v -s --tb=short --disable-warnings` <img width="1506" alt="Screenshot 2024-12-10 at 4 06 17 PM" src="https://github.com/user-attachments/assets/16225029-bdb7-48c4-9d13-e580cc769c0a"> e2e test with client side call <img width="888" alt="Screenshot 2024-12-10 at 4 09 44 PM" src="https://github.com/user-attachments/assets/de375e4c-ef67-4dcc-a045-4037d9489191">
This commit is contained in:
parent
5764a95912
commit
c294a01c4b
8 changed files with 331 additions and 67 deletions
|
@ -6,6 +6,7 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
@ -14,6 +15,7 @@ from pydantic import BaseModel, Field
|
|||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.common.training_types import * # noqa: F403
|
||||
|
||||
|
@ -64,6 +66,7 @@ class TrainingConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class LoraFinetuningConfig(BaseModel):
|
||||
type: Literal["LoRA"] = "LoRA"
|
||||
lora_attn_modules: List[str]
|
||||
apply_lora_to_mlp: bool
|
||||
apply_lora_to_output: bool
|
||||
|
@ -75,12 +78,13 @@ class LoraFinetuningConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class QATFinetuningConfig(BaseModel):
|
||||
type: Literal["QAT"] = "QAT"
|
||||
quantizer_name: str
|
||||
group_size: int
|
||||
|
||||
|
||||
AlgorithmConfig = Annotated[
|
||||
Union[LoraFinetuningConfig, LoraFinetuningConfig], Field(discriminator="type")
|
||||
Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
|
||||
]
|
||||
|
||||
|
||||
|
@ -92,14 +96,6 @@ class PostTrainingJobLogStream(BaseModel):
|
|||
log_lines: List[str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingJobStatus(Enum):
|
||||
running = "running"
|
||||
completed = "completed"
|
||||
failed = "failed"
|
||||
scheduled = "scheduled"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RLHFAlgorithm(Enum):
|
||||
dpo = "dpo"
|
||||
|
@ -144,7 +140,7 @@ class PostTrainingJobStatusResponse(BaseModel):
|
|||
"""Status of a finetuning job."""
|
||||
|
||||
job_uuid: str
|
||||
status: PostTrainingJobStatus
|
||||
status: JobStatus
|
||||
|
||||
scheduled_at: Optional[datetime] = None
|
||||
started_at: Optional[datetime] = None
|
||||
|
@ -166,7 +162,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
|
|||
|
||||
|
||||
class PostTraining(Protocol):
|
||||
@webmethod(route="/post-training/supervised-fine-tune")
|
||||
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
@ -181,7 +177,7 @@ class PostTraining(Protocol):
|
|||
algorithm_config: Optional[AlgorithmConfig] = None,
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize")
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||
async def preference_optimize(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
@ -192,24 +188,18 @@ class PostTraining(Protocol):
|
|||
logger_config: Dict[str, Any],
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/jobs")
|
||||
@webmethod(route="/post-training/jobs", method="GET")
|
||||
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
|
||||
|
||||
# sends SSE stream of logs
|
||||
@webmethod(route="/post-training/job/logs")
|
||||
async def get_training_job_logstream(
|
||||
self, job_uuid: str
|
||||
) -> PostTrainingJobLogStream: ...
|
||||
|
||||
@webmethod(route="/post-training/job/status")
|
||||
@webmethod(route="/post-training/job/status", method="GET")
|
||||
async def get_training_job_status(
|
||||
self, job_uuid: str
|
||||
) -> PostTrainingJobStatusResponse: ...
|
||||
) -> Optional[PostTrainingJobStatusResponse]: ...
|
||||
|
||||
@webmethod(route="/post-training/job/cancel")
|
||||
@webmethod(route="/post-training/job/cancel", method="POST")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts")
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET")
|
||||
async def get_training_job_artifacts(
|
||||
self, job_uuid: str
|
||||
) -> PostTrainingJobArtifactsResponse: ...
|
||||
) -> Optional[PostTrainingJobArtifactsResponse]: ...
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue