mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-05 02:17:31 +00:00
wiprouters
Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
357be98279
commit
8df9340dd3
155 changed files with 61817 additions and 95863 deletions
222
src/llama_stack/apis/post_training/models.py
Normal file
222
src/llama_stack/apis/post_training/models.py
Normal file
|
|
@ -0,0 +1,222 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
from llama_stack.apis.common.training_types import Checkpoint
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OptimizerType(Enum):
|
||||
"""Available optimizer algorithms for training."""
|
||||
|
||||
adam = "adam"
|
||||
adamw = "adamw"
|
||||
sgd = "sgd"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatasetFormat(Enum):
|
||||
"""Format of the training dataset."""
|
||||
|
||||
instruct = "instruct"
|
||||
dialog = "dialog"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DataConfig(BaseModel):
|
||||
"""Configuration for training data and data loading."""
|
||||
|
||||
dataset_id: str
|
||||
batch_size: int
|
||||
shuffle: bool
|
||||
data_format: DatasetFormat
|
||||
validation_dataset_id: str | None = None
|
||||
packed: bool | None = False
|
||||
train_on_input: bool | None = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OptimizerConfig(BaseModel):
|
||||
"""Configuration parameters for the optimization algorithm."""
|
||||
|
||||
optimizer_type: OptimizerType
|
||||
lr: float
|
||||
weight_decay: float
|
||||
num_warmup_steps: int
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EfficiencyConfig(BaseModel):
|
||||
"""Configuration for memory and compute efficiency optimizations."""
|
||||
|
||||
enable_activation_checkpointing: bool | None = False
|
||||
enable_activation_offloading: bool | None = False
|
||||
memory_efficient_fsdp_wrap: bool | None = False
|
||||
fsdp_cpu_offload: bool | None = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TrainingConfig(BaseModel):
|
||||
"""Comprehensive configuration for the training process."""
|
||||
|
||||
n_epochs: int
|
||||
max_steps_per_epoch: int = 1
|
||||
gradient_accumulation_steps: int = 1
|
||||
max_validation_steps: int | None = 1
|
||||
data_config: DataConfig | None = None
|
||||
optimizer_config: OptimizerConfig | None = None
|
||||
efficiency_config: EfficiencyConfig | None = None
|
||||
dtype: str | None = "bf16"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LoraFinetuningConfig(BaseModel):
|
||||
"""Configuration for Low-Rank Adaptation (LoRA) fine-tuning."""
|
||||
|
||||
type: Literal["LoRA"] = "LoRA"
|
||||
lora_attn_modules: list[str]
|
||||
apply_lora_to_mlp: bool
|
||||
apply_lora_to_output: bool
|
||||
rank: int
|
||||
alpha: int
|
||||
use_dora: bool | None = False
|
||||
quantize_base: bool | None = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QATFinetuningConfig(BaseModel):
|
||||
"""Configuration for Quantization-Aware Training (QAT) fine-tuning."""
|
||||
|
||||
type: Literal["QAT"] = "QAT"
|
||||
quantizer_name: str
|
||||
group_size: int
|
||||
|
||||
|
||||
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
|
||||
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingJobLogStream(BaseModel):
|
||||
"""Stream of logs from a finetuning job."""
|
||||
|
||||
job_uuid: str
|
||||
log_lines: list[str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RLHFAlgorithm(Enum):
|
||||
"""Available reinforcement learning from human feedback algorithms."""
|
||||
|
||||
dpo = "dpo"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DPOLossType(Enum):
|
||||
sigmoid = "sigmoid"
|
||||
hinge = "hinge"
|
||||
ipo = "ipo"
|
||||
kto_pair = "kto_pair"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DPOAlignmentConfig(BaseModel):
|
||||
"""Configuration for Direct Preference Optimization (DPO) alignment."""
|
||||
|
||||
beta: float
|
||||
loss_type: DPOLossType = DPOLossType.sigmoid
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingRLHFRequest(BaseModel):
|
||||
"""Request to finetune a model using reinforcement learning from human feedback."""
|
||||
|
||||
job_uuid: str
|
||||
|
||||
finetuned_model: URL
|
||||
|
||||
dataset_id: str
|
||||
validation_dataset_id: str
|
||||
|
||||
algorithm: RLHFAlgorithm
|
||||
algorithm_config: DPOAlignmentConfig
|
||||
|
||||
optimizer_config: OptimizerConfig
|
||||
training_config: TrainingConfig
|
||||
|
||||
# TODO: define these
|
||||
hyperparam_search_config: dict[str, Any]
|
||||
logger_config: dict[str, Any]
|
||||
|
||||
|
||||
class PostTrainingJob(BaseModel):
|
||||
job_uuid: str = Field(..., description="The UUID of the job")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingJobStatusResponse(BaseModel):
|
||||
"""Status of a finetuning job."""
|
||||
|
||||
job_uuid: str
|
||||
status: JobStatus
|
||||
|
||||
scheduled_at: datetime | None = None
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
|
||||
resources_allocated: dict[str, Any] | None = None
|
||||
|
||||
checkpoints: list[Checkpoint] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ListPostTrainingJobsResponse(BaseModel):
|
||||
data: list[PostTrainingJob] = Field(..., description="The list of training jobs")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingJobArtifactsResponse(BaseModel):
|
||||
"""Artifacts of a finetuning job."""
|
||||
|
||||
job_uuid: str = Field(..., description="The UUID of the job")
|
||||
checkpoints: list[Checkpoint] = Field(default_factory=list)
|
||||
|
||||
# TODO(ashwin): metrics, evals
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SupervisedFineTuneRequest(BaseModel):
|
||||
"""Request to run supervised fine-tuning of a model."""
|
||||
|
||||
job_uuid: str = Field(..., description="The UUID of the job to create")
|
||||
training_config: TrainingConfig = Field(..., description="The training configuration")
|
||||
hyperparam_search_config: dict[str, Any] = Field(..., description="The hyperparam search configuration")
|
||||
logger_config: dict[str, Any] = Field(..., description="The logger configuration")
|
||||
model: str | None = Field(
|
||||
default=None,
|
||||
description="Model descriptor for training if not in provider config`",
|
||||
)
|
||||
checkpoint_dir: str | None = Field(default=None, description="The directory to save checkpoint(s) to")
|
||||
algorithm_config: AlgorithmConfig | None = Field(default=None, description="The algorithm configuration")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PreferenceOptimizeRequest(BaseModel):
|
||||
"""Request to run preference optimization of a model."""
|
||||
|
||||
job_uuid: str = Field(..., description="The UUID of the job to create")
|
||||
finetuned_model: str = Field(..., description="The model to fine-tune")
|
||||
algorithm_config: DPOAlignmentConfig = Field(..., description="The algorithm configuration")
|
||||
training_config: TrainingConfig = Field(..., description="The training configuration")
|
||||
hyperparam_search_config: dict[str, Any] = Field(..., description="The hyperparam search configuration")
|
||||
logger_config: dict[str, Any] = Field(..., description="The logger configuration")
|
||||
Loading…
Add table
Add a link
Reference in a new issue