forked from phoenix-oss/llama-stack-mirror
Compare commits: `kvant...simplify_p` (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 357141f6de | |
| | 41f86b5ce6 | |
| | 51282456b9 | |
| | 0b92fef3ba | |
| | 840fd353f7 | |
1 changed file with 44 additions and 92 deletions
```diff
@@ -16,81 +16,49 @@ from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.common.training_types import Checkpoint
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


 @json_schema_type
-class OptimizerType(Enum):
-    adam = "adam"
-    adamw = "adamw"
-    sgd = "sgd"
-
-
-@json_schema_type
-class DatasetFormat(Enum):
-    instruct = "instruct"
-    dialog = "dialog"
-
-
-@json_schema_type
-class DataConfig(BaseModel):
-    dataset_id: str
-    batch_size: int
-    shuffle: bool
-    data_format: DatasetFormat
-    validation_dataset_id: Optional[str] = None
-    packed: Optional[bool] = False
-    train_on_input: Optional[bool] = False
-
-
-@json_schema_type
-class OptimizerConfig(BaseModel):
-    optimizer_type: OptimizerType
-    lr: float
-    weight_decay: float
-    num_warmup_steps: int
-
-
-@json_schema_type
-class EfficiencyConfig(BaseModel):
-    enable_activation_checkpointing: Optional[bool] = False
-    enable_activation_offloading: Optional[bool] = False
-    memory_efficient_fsdp_wrap: Optional[bool] = False
-    fsdp_cpu_offload: Optional[bool] = False
-
-
-@json_schema_type
-class TrainingConfig(BaseModel):
-    n_epochs: int
-    max_steps_per_epoch: int
-    gradient_accumulation_steps: int
-    max_validation_steps: int
-    data_config: DataConfig
-    optimizer_config: OptimizerConfig
-    efficiency_config: Optional[EfficiencyConfig] = None
+class TrainingStrategy(BaseModel):
+    # params that control the optimizer
+    learning_rate: Optional[Union[float, Literal["auto"]]] = "auto"
+    weight_decay: Optional[float] = 0.1
+    num_warmup_steps: Optional[Union[int, Literal["auto"]]] = "auto"
+
+    # params that control how data is fed for training
+    batch_size: Optional[Union[int, Literal["auto"]]] = "auto"
+    shuffle: Optional[bool] = True
+    n_epochs: Optional[int] = 3
+
+    # training loop control params
+    max_training_steps: Optional[int] = None
+    max_validation_steps: Optional[int] = None
+    gradient_accumulation_steps: Optional[Union[int, Literal["auto"]]] = "auto"
+
+    # precision for training
     dtype: Optional[str] = "bf16"


 @json_schema_type
-class LoraFinetuningConfig(BaseModel):
+class LoraFinetuningStrategy(BaseModel):
     type: Literal["LoRA"] = "LoRA"
-    lora_attn_modules: List[str]
-    apply_lora_to_mlp: bool
-    apply_lora_to_output: bool
-    rank: int
-    alpha: int
+    lora_attn_modules: Optional[List[str]] = ["q_proj", "v_proj", "output_proj"]
+    apply_lora_to_mlp: Optional[bool] = True
+    apply_lora_to_output: Optional[bool] = False
+    rank: Optional[int] = 8
+    alpha: Optional[int] = 16
     use_dora: Optional[bool] = False
     quantize_base: Optional[bool] = False


 @json_schema_type
-class QATFinetuningConfig(BaseModel):
+class QATFinetuningStrategy(BaseModel):
     type: Literal["QAT"] = "QAT"
     quantizer_name: str
     group_size: int


-AlgorithmConfig = register_schema(
-    Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
-    name="AlgorithmConfig",
+AlgorithmStrategy = register_schema(
+    Annotated[Union[LoraFinetuningStrategy, QATFinetuningStrategy], Field(discriminator="type")],
+    name="AlgorithmStrategy",
 )
```
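The net effect of this hunk is to collapse the nested `DataConfig` / `OptimizerConfig` / `TrainingConfig` stack into a single flat `TrainingStrategy` whose fields all carry usable (often `"auto"`) defaults, and to give `LoraFinetuningStrategy` defaults as well. Below is a minimal standalone sketch of how the new models behave, assuming pydantic v2 and omitting the `@json_schema_type` / `register_schema` wiring from `llama_stack`; field names and defaults come from the diff, while the example values at the bottom are illustrative:

```python
# Standalone sketch of the simplified schema from this hunk (pydantic v2 assumed;
# the llama_stack decorators and register_schema wiring are intentionally omitted).
from typing import Annotated, List, Literal, Optional, Union

from pydantic import BaseModel, Field, TypeAdapter


class TrainingStrategy(BaseModel):
    # optimizer params; "auto" defers the choice to the provider
    learning_rate: Optional[Union[float, Literal["auto"]]] = "auto"
    weight_decay: Optional[float] = 0.1
    num_warmup_steps: Optional[Union[int, Literal["auto"]]] = "auto"
    # data-feeding params
    batch_size: Optional[Union[int, Literal["auto"]]] = "auto"
    shuffle: Optional[bool] = True
    n_epochs: Optional[int] = 3
    # training-loop control
    max_training_steps: Optional[int] = None
    max_validation_steps: Optional[int] = None
    gradient_accumulation_steps: Optional[Union[int, Literal["auto"]]] = "auto"
    # training precision
    dtype: Optional[str] = "bf16"


class LoraFinetuningStrategy(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    lora_attn_modules: Optional[List[str]] = ["q_proj", "v_proj", "output_proj"]
    apply_lora_to_mlp: Optional[bool] = True
    apply_lora_to_output: Optional[bool] = False
    rank: Optional[int] = 8
    alpha: Optional[int] = 16
    use_dora: Optional[bool] = False
    quantize_base: Optional[bool] = False


class QATFinetuningStrategy(BaseModel):
    type: Literal["QAT"] = "QAT"
    quantizer_name: str
    group_size: int


# Same discriminated union the diff registers via register_schema: the "type"
# field selects the concrete model during deserialization.
AlgorithmStrategy = Annotated[
    Union[LoraFinetuningStrategy, QATFinetuningStrategy],
    Field(discriminator="type"),
]

# Zero-config case: every field has a default.
print(TrainingStrategy())
# Callers override only what they care about.
print(TrainingStrategy(learning_rate=1e-4, batch_size=8))
# Discriminator dispatch: {"type": "QAT"} parses to QATFinetuningStrategy.
# (The quantizer name here is a made-up illustrative value.)
print(TypeAdapter(AlgorithmStrategy).validate_python(
    {"type": "QAT", "quantizer_name": "int4_weight_only", "group_size": 32}
))
```

Because `TrainingStrategy()` with no arguments is now a valid, fully populated config, the protocol methods below can default both strategy parameters.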
```diff
@@ -108,35 +76,13 @@ class RLHFAlgorithm(Enum):


 @json_schema_type
-class DPOAlignmentConfig(BaseModel):
+class DPOAlignmentStrategy(BaseModel):
     reward_scale: float
     reward_clip: float
     epsilon: float
     gamma: float
-
-
-@json_schema_type
-class PostTrainingRLHFRequest(BaseModel):
-    """Request to finetune a model."""
-
-    job_uuid: str
-
-    finetuned_model: URL
-
-    dataset_id: str
-    validation_dataset_id: str
-
-    algorithm: RLHFAlgorithm
-    algorithm_config: DPOAlignmentConfig
-
-    optimizer_config: OptimizerConfig
-    training_config: TrainingConfig
-
-    # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]


 class PostTrainingJob(BaseModel):
     job_uuid: str
```
```diff
@@ -176,26 +122,32 @@ class PostTraining(Protocol):
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        training_dataset_id: str,
         model: str = Field(
             default="Llama3.2-3B-Instruct",
             description="Model descriptor from `llama model list`",
         ),
-        checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[AlgorithmConfig] = None,
+        # Optional
+        validation_dataset_id: Optional[str] = None,
+        training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
+        algorithm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        training_dataset_id: str,
+        model: str = Field(
+            default="Llama3.2-3B-Instruct",
+            description="Model descriptor from `llama model list`",
+        ),
+        # Optional
+        validation_dataset_id: Optional[str] = None,
+        training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
+        algorithm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/jobs", method="GET")
```
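After this hunk, a fine-tuning request needs only a job ID, a training dataset ID, and a (defaulted) model; everything else is optional. A hypothetical end-to-end sketch of the simplified flow, reusing the `TrainingStrategy` and `LoraFinetuningStrategy` models from the sketch above (the free-standing function, stub body, and example IDs are illustrative assumptions, not code from the diff):

```python
# Hypothetical usage sketch for the simplified signature. Only the parameter
# list mirrors the diff; the stub body and example IDs are assumptions.
# Reuses TrainingStrategy / LoraFinetuningStrategy from the sketch above.
import asyncio
from typing import Optional

from pydantic import BaseModel


class PostTrainingJob(BaseModel):
    job_uuid: str


async def supervised_fine_tune(
    job_uuid: str,
    training_dataset_id: str,
    model: str = "Llama3.2-3B-Instruct",
    # Optional
    validation_dataset_id: Optional[str] = None,
    training_strategy: Optional[TrainingStrategy] = None,
    algorithm: Optional[LoraFinetuningStrategy] = None,
) -> PostTrainingJob:
    # A real provider would resolve the "auto" fields and schedule the run;
    # this stub just echoes the effective configuration.
    strategy = training_strategy or TrainingStrategy()
    lora = algorithm or LoraFinetuningStrategy()
    print(f"tuning {model} on {training_dataset_id}: "
          f"{strategy.n_epochs} epochs, LoRA rank {lora.rank}")
    return PostTrainingJob(job_uuid=job_uuid)


# Minimal call: only the job and dataset IDs are required.
job = asyncio.run(supervised_fine_tune("job-123", "my-instruct-dataset"))
print(job.job_uuid)
```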