refine api

Author: Botao Chen
Date: 2024-12-03 20:01:27 -08:00
Parent: 5838b7211d
Commit: 41cf2bb0a7
3 changed files with 74 additions and 40 deletions


@@ -7,7 +7,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol
from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
@@ -18,42 +18,55 @@ from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403


@json_schema_type
class OptimizerType(Enum):
    adam = "adam"
    adamw = "adamw"
    sgd = "sgd"


@json_schema_type
class DataConfig(BaseModel):
    dataset_id: str
    batch_size: int
    shuffle: bool
    validation_dataset_id: Optional[str] = None
    packed: Optional[bool] = False
    train_on_input: Optional[bool] = False


@json_schema_type
class OptimizerConfig(BaseModel):
    optimizer_type: OptimizerType
    lr: float
    lr_min: float
    weight_decay: float
    num_warmup_steps: int


@json_schema_type
class EfficiencyConfig(BaseModel):
    enable_activation_checkpointing: Optional[bool] = False
    enable_activation_offloading: Optional[bool] = False
    memory_efficient_fsdp_wrap: Optional[bool] = False
    fsdp_cpu_offload: Optional[bool] = False


@json_schema_type
class TrainingConfig(BaseModel):
    dtype: str
    n_epochs: int
    max_steps_per_epoch: int
    gradient_accumulation_steps: int
    batch_size: int
    shuffle: bool
    data_config: DataConfig
    optimizer_config: OptimizerConfig
    enable_activation_checkpointing: bool
    memory_efficient_fsdp_wrap: Optional[bool]
    fsdp_cpu_offload: Optional[bool]
    efficiency_config: Optional[EfficiencyConfig] = None
    dtype: Optional[str] = "bf16"


@json_schema_type
class FinetuningAlgorithm(Enum):
    full = "full"
    lora = "lora"
    qlora = "qlora"
    dora = "dora"
    qat = "qat"


@json_schema_type
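For orientation, a minimal usage sketch of the refactored config classes above. It assumes batch_size and shuffle now belong to DataConfig rather than TrainingConfig, which is the direction this change appears to take; every value is an illustrative placeholder, not an API default.

# Hypothetical usage sketch of the refactored configs; values are placeholders.
data_config = DataConfig(
    dataset_id="my-sft-dataset",  # placeholder dataset id
    batch_size=8,
    shuffle=True,
)
optimizer_config = OptimizerConfig(
    optimizer_type=OptimizerType.adamw,
    lr=3e-4,
    lr_min=3e-5,
    weight_decay=0.01,
    num_warmup_steps=100,
)
training_config = TrainingConfig(
    n_epochs=1,
    max_steps_per_epoch=100,
    gradient_accumulation_steps=1,
    data_config=data_config,
    optimizer_config=optimizer_config,
    efficiency_config=EfficiencyConfig(enable_activation_checkpointing=True),
    dtype="bf16",
)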
@@ -63,17 +76,14 @@ class LoraFinetuningConfig(BaseModel):
    apply_lora_to_output: bool
    rank: int
    alpha: int
    use_dora: bool
    use_dora: Optional[bool] = False
    quantize_base: Optional[bool] = False


@json_schema_type
class QLoraFinetuningConfig(LoraFinetuningConfig):
    pass


@json_schema_type
class DoraFinetuningConfig(LoraFinetuningConfig):
    pass


class QATFinetuningConfig(BaseModel):
    quantizer_name: str
    group_size: int


@json_schema_type
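A correspondingly small sketch of the algorithm-specific configs: QATFinetuningConfig can be built on its own and passed wherever the Union with LoraFinetuningConfig is accepted. The quantizer name is a made-up placeholder, not a value defined by this API.

# Hypothetical sketch; quantizer_name is a placeholder string.
qat_config = QATFinetuningConfig(
    quantizer_name="int4_weight_only",
    group_size=32,
)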
@@ -110,13 +120,9 @@ class PostTrainingSFTRequest(BaseModel):
    """Request to finetune a model."""

    job_uuid: str
    model: str
    dataset_id: str
    validation_dataset_id: str
    algorithm: FinetuningAlgorithm
    algorithm_config: LoraFinetuningConfig
    algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
    training_config: TrainingConfig

    # TODO: define these
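To make the new algorithm_config typing concrete, here is a standalone sketch of just that field on a throwaway pydantic model (a hypothetical stand-in, not the real PostTrainingSFTRequest, whose remaining fields are elided above).

from typing import Optional, Union

from pydantic import BaseModel


# Hypothetical stand-in model that mirrors only the changed field.
class AlgorithmConfigSketch(BaseModel):
    algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None


# Either algorithm config is accepted, and the field may be omitted entirely.
AlgorithmConfigSketch(algorithm_config=qat_config)
AlgorithmConfigSketch()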
@@ -182,13 +188,11 @@ class PostTraining(Protocol):
        self,
        job_uuid: str,
        model: str,
        dataset_id: str,
        validation_dataset_id: str,
        algorithm: FinetuningAlgorithm,
        algorithm_config: LoraFinetuningConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
        algorithm_config: Optional[LoraFinetuningConfig] = None,
    ) -> PostTrainingJob: ...
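Putting the pieces together, a hedged sketch of a call against the refactored supervised_fine_tune signature, reusing the training_config built earlier. The post_training_impl handle and the model name are assumptions for illustration; dataset selection appears to move into training_config.data_config rather than being passed as top-level arguments.

# Hypothetical call; post_training_impl is an assumed handle to a PostTraining
# implementation, and the model name is a placeholder.
job = post_training_impl.supervised_fine_tune(
    job_uuid="1234",
    model="Llama3.2-3B-Instruct",
    training_config=training_config,
    hyperparam_search_config={},
    logger_config={},
    algorithm_config=None,  # omitted/None presumably selects full finetuning
)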
    @webmethod(route="/post-training/preference-optimize")