forked from phoenix-oss/llama-stack-mirror
# What does this PR do?

Fixes a bunch of violations.

Note: this patch touches all files except post_training.py, which will be significantly changed by #1437 and is therefore left out of the picture for now.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

Testing with https://github.com/meta-llama/llama-stack/pull/1543

Also checked that GPU training works with the change:

```
INFO: ::1:53316 - "POST /v1/post-training/supervised-fine-tune HTTP/1.1" 200 OK
INFO: ::1:53316 - "GET /v1/post-training/job/status?job_uuid=test-jobb5ca2d84-d541-42f8-883b-762828b4c0e7 HTTP/1.1" 200 OK
INFO: ::1:53316 - "GET /v1/post-training/job/artifacts?job_uuid=test-jobb5ca2d84-d541-42f8-883b-762828b4c0e7 HTTP/1.1" 200 OK
21:24:01.161 [END] /v1/post-training/supervised-fine-tune [StatusCode.OK] (32526.75ms)
21:23:28.769 [DEBUG] Setting manual seed to local seed 3918872849. Local seed is seed + rank = 3918872849 + 0
21:23:28.996 [INFO] Identified model_type = Llama3_2. Ignoring output.weight in checkpoint in favor of the tok_embedding.weight tied weights.
21:23:29.933 [INFO] Memory stats after model init:
    GPU peak memory allocation: 6.05 GiB
    GPU peak memory reserved: 6.10 GiB
    GPU peak memory active: 6.05 GiB
21:23:29.934 [INFO] Model is initialized with precision torch.bfloat16.
21:23:30.115 [INFO] Tokenizer is initialized.
21:23:30.118 [INFO] Optimizer is initialized.
21:23:30.119 [INFO] Loss is initialized.
21:23:30.896 [INFO] Dataset and Sampler are initialized.
21:23:30.898 [INFO] Learning rate scheduler is initialized.
21:23:31.618 [INFO] Memory stats after model init:
    GPU peak memory allocation: 6.24 GiB
    GPU peak memory reserved: 6.30 GiB
    GPU peak memory active: 6.24 GiB
21:23:31.620 [INFO] Starting checkpoint save...
21:23:59.428 [INFO] Model checkpoint of size 6.43 GB saved to /home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/consolidated.00.pth
21:23:59.445 [INFO] Adapter checkpoint of size 0.00 GB saved to /home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/adapter/adapter.pth
```

[//]: # (## Documentation)

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
211 lines
5.5 KiB
Python
211 lines
5.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Literal, Optional, Protocol
|
|
|
|
from pydantic import BaseModel, Field
|
|
from typing_extensions import Annotated
|
|
|
|
from llama_stack.apis.common.content_types import URL
|
|
from llama_stack.apis.common.job_types import JobStatus
|
|
from llama_stack.apis.common.training_types import Checkpoint
|
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
|
|
|
|
|
# Closed set of optimizer identifiers. Each member's value is the same
# lowercase string as its name. Used as the type of
# OptimizerConfig.optimizer_type.
@json_schema_type
class OptimizerType(Enum):
    adam = "adam"
    adamw = "adamw"
    sgd = "sgd"
|
|
|
|
|
|
# Closed set of dataset layout identifiers; member values mirror member
# names. Used as the type of DataConfig.data_format.
@json_schema_type
class DatasetFormat(Enum):
    instruct = "instruct"
    dialog = "dialog"
|
|
|
|
|
|
# Data-side settings of a training run; embedded in TrainingConfig as
# `data_config`.
@json_schema_type
class DataConfig(BaseModel):
    # Required fields: no default, so callers must supply each one.
    dataset_id: str
    batch_size: int
    shuffle: bool
    data_format: DatasetFormat

    # Optional fields: may be omitted, in which case the default applies.
    validation_dataset_id: Optional[str] = None
    packed: Optional[bool] = False
    train_on_input: Optional[bool] = False
|
|
|
|
|
|
# Optimizer settings; embedded in TrainingConfig and PostTrainingRLHFRequest
# as `optimizer_config`. Every field is required (none has a default).
@json_schema_type
class OptimizerConfig(BaseModel):
    optimizer_type: OptimizerType
    lr: float
    weight_decay: float
    num_warmup_steps: int
|
|
|
|
|
|
# Opt-in switches embedded in TrainingConfig as `efficiency_config`.
# All four flags are optional and default to False.
@json_schema_type
class EfficiencyConfig(BaseModel):
    enable_activation_checkpointing: Optional[bool] = False
    enable_activation_offloading: Optional[bool] = False
    memory_efficient_fsdp_wrap: Optional[bool] = False
    fsdp_cpu_offload: Optional[bool] = False
|
|
|
|
|
|
# Top-level training settings accepted by both PostTraining submission
# methods (`supervised_fine_tune` and `preference_optimize`).
@json_schema_type
class TrainingConfig(BaseModel):
    # Required integer knobs.
    n_epochs: int
    max_steps_per_epoch: int
    gradient_accumulation_steps: int
    max_validation_steps: int

    # Required nested configuration models.
    data_config: DataConfig
    optimizer_config: OptimizerConfig

    # Optional settings; `efficiency_config` defaults to None and `dtype`
    # defaults to the string "bf16".
    efficiency_config: Optional[EfficiencyConfig] = None
    dtype: Optional[str] = "bf16"
|
|
|
|
|
|
# One arm of the AlgorithmConfig union; selected when `type` is "LoRA".
@json_schema_type
class LoraFinetuningConfig(BaseModel):
    # Discriminator tag: fixed to the literal "LoRA" and pre-filled by default.
    type: Literal["LoRA"] = "LoRA"

    # Required fields.
    lora_attn_modules: List[str]
    apply_lora_to_mlp: bool
    apply_lora_to_output: bool
    rank: int
    alpha: int

    # Optional flags, both defaulting to False.
    use_dora: Optional[bool] = False
    quantize_base: Optional[bool] = False
|
|
|
|
|
|
# One arm of the AlgorithmConfig union; selected when `type` is "QAT".
@json_schema_type
class QATFinetuningConfig(BaseModel):
    # Discriminator tag: fixed to the literal "QAT" and pre-filled by default.
    type: Literal["QAT"] = "QAT"

    # Required fields.
    quantizer_name: str
    group_size: int
|
|
|
|
|
|
# Tagged union of the fine-tuning algorithm configs. The `type` field of each
# member ("LoRA" or "QAT") is declared as the discriminator, and the union is
# registered under the schema name "AlgorithmConfig".
AlgorithmConfig = register_schema(
    Annotated[
        LoraFinetuningConfig | QATFinetuningConfig,
        Field(discriminator="type"),
    ],
    name="AlgorithmConfig",
)
|
|
|
|
|
|
@json_schema_type
class PostTrainingJobLogStream(BaseModel):
    """Stream of logs from a finetuning job."""

    # Both fields are required: the job identifier and its log lines.
    job_uuid: str
    log_lines: List[str]
|
|
|
|
|
|
# Closed set of RLHF algorithm identifiers; "dpo" is the only member defined.
# Used as the type of PostTrainingRLHFRequest.algorithm.
@json_schema_type
class RLHFAlgorithm(Enum):
    dpo = "dpo"
|
|
|
|
|
|
# Algorithm settings passed as `algorithm_config` to
# PostTraining.preference_optimize and carried on PostTrainingRLHFRequest.
# All four fields are required floats.
@json_schema_type
class DPOAlignmentConfig(BaseModel):
    reward_scale: float
    reward_clip: float
    epsilon: float
    gamma: float
|
|
|
|
|
|
@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
    """Request to finetune a model."""

    # Job identity and the model the request refers to.
    job_uuid: str
    finetuned_model: URL

    # Datasets, referenced by id.
    dataset_id: str
    validation_dataset_id: str

    # Algorithm choice plus its settings.
    algorithm: RLHFAlgorithm
    algorithm_config: DPOAlignmentConfig

    # Optimizer and training settings.
    optimizer_config: OptimizerConfig
    training_config: TrainingConfig

    # TODO: define these
    hyperparam_search_config: Dict[str, Any]
    logger_config: Dict[str, Any]
|
|
|
|
|
|
# Minimal job handle: the return type of both PostTraining submission methods
# and the element type of ListPostTrainingJobsResponse.data.
class PostTrainingJob(BaseModel):
    job_uuid: str
|
|
|
|
|
|
@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
    """Status of a finetuning job."""

    # Required: which job this is and its current status.
    job_uuid: str
    status: JobStatus

    # Optional timestamps; each defaults to None when not provided.
    scheduled_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Optional free-form mapping; defaults to None.
    resources_allocated: Optional[Dict[str, Any]] = None

    # default_factory gives every instance its own fresh empty list.
    checkpoints: List[Checkpoint] = Field(default_factory=list)
|
|
|
|
|
|
# Envelope returned by PostTraining.get_training_jobs; the jobs are carried
# in the single required `data` field.
class ListPostTrainingJobsResponse(BaseModel):
    data: List[PostTrainingJob]
|
|
|
|
|
|
@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
    """Artifacts of a finetuning job."""

    job_uuid: str
    # default_factory gives every instance its own fresh empty list.
    checkpoints: List[Checkpoint] = Field(default_factory=list)

    # TODO(ashwin): metrics, evals
|
|
|
|
|
|
# Structural (typing.Protocol) interface for the post-training API. Every
# method is an async stub whose body is `...`; no implementation is visible in
# this file. Each stub is tagged with a route and HTTP method via @webmethod.
class PostTraining(Protocol):
    # ---- job submission -------------------------------------------------

    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
    async def supervised_fine_tune(
        self,
        job_uuid: str,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
        # NOTE(review): the default is a pydantic Field(...) object rather than
        # a plain str; presumably the route machinery unwraps it — confirm
        # before calling an implementation directly without `model`.
        model: str = Field(
            default="Llama3.2-3B-Instruct",
            description="Model descriptor from `llama model list`",
        ),
        checkpoint_dir: Optional[str] = None,
        algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
    ) -> PostTrainingJob:
        ...

    @webmethod(route="/post-training/preference-optimize", method="POST")
    async def preference_optimize(
        self,
        job_uuid: str,
        finetuned_model: str,
        algorithm_config: DPOAlignmentConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
    ) -> PostTrainingJob:
        ...

    # ---- job inspection and control --------------------------------------

    @webmethod(route="/post-training/jobs", method="GET")
    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
        ...

    @webmethod(route="/post-training/job/status", method="GET")
    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
        ...

    @webmethod(route="/post-training/job/cancel", method="POST")
    async def cancel_training_job(self, job_uuid: str) -> None:
        ...

    @webmethod(route="/post-training/job/artifacts", method="GET")
    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
        ...
|