mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 15:52:39 +00:00
refine api
This commit is contained in:
parent
5838b7211d
commit
41cf2bb0a7
3 changed files with 74 additions and 40 deletions
|
|
@ -7,7 +7,7 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -18,42 +18,55 @@ from llama_stack.apis.datasets import * # noqa: F403
|
|||
from llama_stack.apis.common.training_types import * # noqa: F403
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OptimizerType(Enum):
|
||||
adam = "adam"
|
||||
adamw = "adamw"
|
||||
sgd = "sgd"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DataConfig(BaseModel):
|
||||
dataset_id: str
|
||||
batch_size: int
|
||||
shuffle: bool
|
||||
validation_dataset_id: Optional[str] = None
|
||||
packed: Optional[bool] = False
|
||||
train_on_input: Optional[bool] = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OptimizerConfig(BaseModel):
|
||||
optimizer_type: OptimizerType
|
||||
lr: float
|
||||
lr_min: float
|
||||
weight_decay: float
|
||||
num_warmup_steps: int
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EfficiencyConfig(BaseModel):
|
||||
enable_activation_checkpointing: Optional[bool] = False
|
||||
enable_activation_offloading: Optional[bool] = False
|
||||
memory_efficient_fsdp_wrap: Optional[bool] = False
|
||||
fsdp_cpu_offload: Optional[bool] = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TrainingConfig(BaseModel):
|
||||
dtype: str
|
||||
n_epochs: int
|
||||
max_steps_per_epoch: int
|
||||
gradient_accumulation_steps: int
|
||||
batch_size: int
|
||||
shuffle: bool
|
||||
data_config: DataConfig
|
||||
optimizer_config: OptimizerConfig
|
||||
|
||||
enable_activation_checkpointing: bool
|
||||
memory_efficient_fsdp_wrap: Optional[bool]
|
||||
fsdp_cpu_offload: Optional[bool]
|
||||
efficiency_config: Optional[EfficiencyConfig] = None
|
||||
dtype: Optional[str] = "bf16"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FinetuningAlgorithm(Enum):
|
||||
full = "full"
|
||||
lora = "lora"
|
||||
qlora = "qlora"
|
||||
dora = "dora"
|
||||
qat = "qat"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -63,17 +76,14 @@ class LoraFinetuningConfig(BaseModel):
|
|||
apply_lora_to_output: bool
|
||||
rank: int
|
||||
alpha: int
|
||||
use_dora: bool
|
||||
use_dora: Optional[bool] = False
|
||||
quantize_base: Optional[bool] = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QLoraFinetuningConfig(LoraFinetuningConfig):
|
||||
pass
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DoraFinetuningConfig(LoraFinetuningConfig):
|
||||
pass
|
||||
class QATFinetuningConfig(BaseModel):
|
||||
quantizer_name: str
|
||||
group_size: int
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -110,13 +120,9 @@ class PostTrainingSFTRequest(BaseModel):
|
|||
"""Request to finetune a model."""
|
||||
|
||||
job_uuid: str
|
||||
|
||||
model: str
|
||||
dataset_id: str
|
||||
validation_dataset_id: str
|
||||
|
||||
algorithm: FinetuningAlgorithm
|
||||
algorithm_config: LoraFinetuningConfig
|
||||
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
|
||||
training_config: TrainingConfig
|
||||
|
||||
# TODO: define these
|
||||
|
|
@ -182,13 +188,11 @@ class PostTraining(Protocol):
|
|||
self,
|
||||
job_uuid: str,
|
||||
model: str,
|
||||
dataset_id: str,
|
||||
validation_dataset_id: str,
|
||||
algorithm: FinetuningAlgorithm,
|
||||
algorithm_config: LoraFinetuningConfig,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: Dict[str, Any],
|
||||
logger_config: Dict[str, Any],
|
||||
algorithm_config: Optional[LoraFinetuningConfig] = None,
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue