llama-stack/llama_stack/apis/post_training/post_training.py
Botao Chen aeb76390fc
[1/n] torchtune <> llama-stack integration skeleton (#540)
### Context 
This is the 1st of series PRs that integrate torchtune with llama-stack
as meta reference post-training implementation. For MVP, we will focus
on single device LoRA SFT.

Though this PR is still WIP, we want to get early feedback on the high
level design of this skeleton while still working on several details

### Scope
To limit the scope of this PR, we focus on the skeleton of the
implementation.

**What are included?**
- refine the post-training SFT apis
- skeleton of supervised_fine_tune implementation. We verified that we
can call the supervised_fine_tune API successfully from llama stack
client SDK (client side PR:
https://github.com/meta-llama/llama-stack-client-python/pull/51)
- a very basic single device LoRA training recipe based on torchtune
core components
- parity check with torchtune library and post training api unit test

**What are not includes?**
- implementation of other job management, get training artifacts apis
(separate PR)
- refactor the meta reference inference logic to support eval on
finetuned model (separate PR)
- several necessary functionality in the training recipe such as
logging, validation etc (separate PR)
- interop with telemetry for tracing and metrics logging, currently
temporarily log to local disk (separate PR)

### Testing
**e2e test**
Although we haven't added detailed testing and numerical parity check
with torchtune yet, we did a simple E2E test from client to server
1. setup server with` llama stack build --template
experimental-post-training --image-type conda` and `llama stack run
experimental-post-training `
2. On client, run `llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 post_training
supervised_fine_tune`
3. Training finishes successfully. On server side, get the finetune
checkpoints under output dir. On client side, get the job uuid

server 
<img width="1110" alt="Screenshot 2024-12-02 at 5 52 32 PM"
src="https://github.com/user-attachments/assets/b548eb90-7a9b-4edc-a858-ee237cc4361d">

client 
<img width="807" alt="Screenshot 2024-12-02 at 5 52 37 PM"
src="https://github.com/user-attachments/assets/1138ffa8-4698-40fa-b190-3d7b99646838">

**parity check**
torchtune dataloader output and llama-stack post training dataloader
output are same
<img width="1116" alt="Screenshot 2024-12-04 at 8 18 46 PM"
src="https://github.com/user-attachments/assets/5e295cdc-4c24-4ea6-82c0-ca96ef1bd6ee">

torchtune LoRA SFT and llama-stack post training LoRA SFT on alpaca
dataset with llama3.2 3B instruct model are numerical match

<img width="860" alt="Screenshot 2024-12-04 at 8 17 01 PM"
src="https://github.com/user-attachments/assets/c05cf0a8-c674-4d2e-9f0a-c5d01b2dca99">

<img width="1049" alt="Screenshot 2024-12-04 at 8 17 06 PM"
src="https://github.com/user-attachments/assets/b911d4e2-e7b1-41a9-b62c-d75529b6d443">

**unit test ** 
![Uploading Screenshot 2024-12-09 at 1.35.10 PM.png…]()
2024-12-13 11:05:35 -08:00

215 lines
5.4 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
@json_schema_type
class OptimizerType(Enum):
adam = "adam"
adamw = "adamw"
sgd = "sgd"
@json_schema_type
class DataConfig(BaseModel):
dataset_id: str
batch_size: int
shuffle: bool
validation_dataset_id: Optional[str] = None
packed: Optional[bool] = False
train_on_input: Optional[bool] = False
@json_schema_type
class OptimizerConfig(BaseModel):
optimizer_type: OptimizerType
lr: float
weight_decay: float
num_warmup_steps: int
@json_schema_type
class EfficiencyConfig(BaseModel):
enable_activation_checkpointing: Optional[bool] = False
enable_activation_offloading: Optional[bool] = False
memory_efficient_fsdp_wrap: Optional[bool] = False
fsdp_cpu_offload: Optional[bool] = False
@json_schema_type
class TrainingConfig(BaseModel):
n_epochs: int
max_steps_per_epoch: int
gradient_accumulation_steps: int
data_config: DataConfig
optimizer_config: OptimizerConfig
efficiency_config: Optional[EfficiencyConfig] = None
dtype: Optional[str] = "bf16"
@json_schema_type
class LoraFinetuningConfig(BaseModel):
lora_attn_modules: List[str]
apply_lora_to_mlp: bool
apply_lora_to_output: bool
rank: int
alpha: int
use_dora: Optional[bool] = False
quantize_base: Optional[bool] = False
@json_schema_type
class QATFinetuningConfig(BaseModel):
quantizer_name: str
group_size: int
AlgorithmConfig = Annotated[
Union[LoraFinetuningConfig, LoraFinetuningConfig], Field(discriminator="type")
]
@json_schema_type
class PostTrainingJobLogStream(BaseModel):
"""Stream of logs from a finetuning job."""
job_uuid: str
log_lines: List[str]
@json_schema_type
class PostTrainingJobStatus(Enum):
running = "running"
completed = "completed"
failed = "failed"
scheduled = "scheduled"
@json_schema_type
class RLHFAlgorithm(Enum):
dpo = "dpo"
@json_schema_type
class DPOAlignmentConfig(BaseModel):
reward_scale: float
reward_clip: float
epsilon: float
gamma: float
@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model."""
job_uuid: str
finetuned_model: URL
dataset_id: str
validation_dataset_id: str
algorithm: RLHFAlgorithm
algorithm_config: DPOAlignmentConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
class PostTrainingJob(BaseModel):
job_uuid: str
@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job."""
job_uuid: str
status: PostTrainingJobStatus
scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
resources_allocated: Optional[Dict[str, Any]] = None
checkpoints: List[Checkpoint] = Field(default_factory=list)
@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job."""
job_uuid: str
checkpoints: List[Checkpoint] = Field(default_factory=list)
# TODO(ashwin): metrics, evals
class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune")
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize")
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs")
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post-training/job/logs")
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...
@webmethod(route="/post-training/job/status")
async def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...