llama-stack-mirror/src/llama_stack/apis/post_training/models.py
Sébastien Han 8df9340dd3
wiprouters
Signed-off-by: Sébastien Han <seb@redhat.com>
2025-11-04 18:09:11 +01:00

222 lines
6.3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type
class OptimizerType(Enum):
"""Available optimizer algorithms for training."""
adam = "adam"
adamw = "adamw"
sgd = "sgd"
@json_schema_type
class DatasetFormat(Enum):
"""Format of the training dataset."""
instruct = "instruct"
dialog = "dialog"
@json_schema_type
class DataConfig(BaseModel):
"""Configuration for training data and data loading."""
dataset_id: str
batch_size: int
shuffle: bool
data_format: DatasetFormat
validation_dataset_id: str | None = None
packed: bool | None = False
train_on_input: bool | None = False
@json_schema_type
class OptimizerConfig(BaseModel):
"""Configuration parameters for the optimization algorithm."""
optimizer_type: OptimizerType
lr: float
weight_decay: float
num_warmup_steps: int
@json_schema_type
class EfficiencyConfig(BaseModel):
"""Configuration for memory and compute efficiency optimizations."""
enable_activation_checkpointing: bool | None = False
enable_activation_offloading: bool | None = False
memory_efficient_fsdp_wrap: bool | None = False
fsdp_cpu_offload: bool | None = False
@json_schema_type
class TrainingConfig(BaseModel):
"""Comprehensive configuration for the training process."""
n_epochs: int
max_steps_per_epoch: int = 1
gradient_accumulation_steps: int = 1
max_validation_steps: int | None = 1
data_config: DataConfig | None = None
optimizer_config: OptimizerConfig | None = None
efficiency_config: EfficiencyConfig | None = None
dtype: str | None = "bf16"
@json_schema_type
class LoraFinetuningConfig(BaseModel):
"""Configuration for Low-Rank Adaptation (LoRA) fine-tuning."""
type: Literal["LoRA"] = "LoRA"
lora_attn_modules: list[str]
apply_lora_to_mlp: bool
apply_lora_to_output: bool
rank: int
alpha: int
use_dora: bool | None = False
quantize_base: bool | None = False
@json_schema_type
class QATFinetuningConfig(BaseModel):
"""Configuration for Quantization-Aware Training (QAT) fine-tuning."""
type: Literal["QAT"] = "QAT"
quantizer_name: str
group_size: int
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@json_schema_type
class PostTrainingJobLogStream(BaseModel):
"""Stream of logs from a finetuning job."""
job_uuid: str
log_lines: list[str]
@json_schema_type
class RLHFAlgorithm(Enum):
"""Available reinforcement learning from human feedback algorithms."""
dpo = "dpo"
@json_schema_type
class DPOLossType(Enum):
sigmoid = "sigmoid"
hinge = "hinge"
ipo = "ipo"
kto_pair = "kto_pair"
@json_schema_type
class DPOAlignmentConfig(BaseModel):
"""Configuration for Direct Preference Optimization (DPO) alignment."""
beta: float
loss_type: DPOLossType = DPOLossType.sigmoid
@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model using reinforcement learning from human feedback."""
job_uuid: str
finetuned_model: URL
dataset_id: str
validation_dataset_id: str
algorithm: RLHFAlgorithm
algorithm_config: DPOAlignmentConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: dict[str, Any]
logger_config: dict[str, Any]
class PostTrainingJob(BaseModel):
job_uuid: str = Field(..., description="The UUID of the job")
@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job."""
job_uuid: str
status: JobStatus
scheduled_at: datetime | None = None
started_at: datetime | None = None
completed_at: datetime | None = None
resources_allocated: dict[str, Any] | None = None
checkpoints: list[Checkpoint] = Field(default_factory=list)
class ListPostTrainingJobsResponse(BaseModel):
data: list[PostTrainingJob] = Field(..., description="The list of training jobs")
@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job."""
job_uuid: str = Field(..., description="The UUID of the job")
checkpoints: list[Checkpoint] = Field(default_factory=list)
# TODO(ashwin): metrics, evals
@json_schema_type
class SupervisedFineTuneRequest(BaseModel):
"""Request to run supervised fine-tuning of a model."""
job_uuid: str = Field(..., description="The UUID of the job to create")
training_config: TrainingConfig = Field(..., description="The training configuration")
hyperparam_search_config: dict[str, Any] = Field(..., description="The hyperparam search configuration")
logger_config: dict[str, Any] = Field(..., description="The logger configuration")
model: str | None = Field(
default=None,
description="Model descriptor for training if not in provider config`",
)
checkpoint_dir: str | None = Field(default=None, description="The directory to save checkpoint(s) to")
algorithm_config: AlgorithmConfig | None = Field(default=None, description="The algorithm configuration")
@json_schema_type
class PreferenceOptimizeRequest(BaseModel):
"""Request to run preference optimization of a model."""
job_uuid: str = Field(..., description="The UUID of the job to create")
finetuned_model: str = Field(..., description="The model to fine-tune")
algorithm_config: DPOAlignmentConfig = Field(..., description="The algorithm configuration")
training_config: TrainingConfig = Field(..., description="The training configuration")
hyperparam_search_config: dict[str, Any] = Field(..., description="The hyperparam search configuration")
logger_config: dict[str, Any] = Field(..., description="The logger configuration")