temp commit

This commit is contained in:
Botao Chen 2024-11-27 15:22:55 -08:00
parent 15e21cb8bd
commit bfc782c054
7 changed files with 100 additions and 30 deletions

View file

@@ -16,7 +16,6 @@ from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
import torch
class OptimizerType(Enum):
@@ -36,7 +35,7 @@ class OptimizerConfig(BaseModel):
@json_schema_type
class TrainingConfig(BaseModel):
dtype: torch.dtype
dtype: str
n_epochs: int
max_steps_per_epoch: int
gradient_accumulation_steps: int
@@ -116,10 +115,7 @@ class PostTrainingSFTRequest(BaseModel):
validation_dataset_id: str
algorithm: FinetuningAlgorithm
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
]
algorithm_config: LoraFinetuningConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
@@ -140,7 +136,7 @@ class PostTrainingRLHFRequest(BaseModel):
validation_dataset_id: str
algorithm: RLHFAlgorithm
algorithm_config: Union[DPOAlignmentConfig]
algorithm_config: DPOAlignmentConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
@@ -184,18 +180,16 @@ class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune")
def supervised_fine_tune(
self,
job_uuid: str,
model: str,
dataset_id: str,
validation_dataset_id: str,
algorithm: FinetuningAlgorithm,
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
],
optimizer_config: OptimizerConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
job_uuid: Optional[str],
model: Optional[str],
dataset_id: Optional[str],
validation_dataset_id: Optional[str],
algorithm: Optional[FinetuningAlgorithm],
algorithm_config: Optional[LoraFinetuningConfig],
optimizer_config: Optional[OptimizerConfig],
training_config: Optional[TrainingConfig],
hyperparam_search_config: Optional[Dict[str, Any]],
logger_config: Optional[Dict[str, Any]],
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize")
@@ -206,7 +200,7 @@ class PostTraining(Protocol):
dataset_id: str,
validation_dataset_id: str,
algorithm: RLHFAlgorithm,
algorithm_config: Union[DPOAlignmentConfig],
algorithm_config: DPOAlignmentConfig,
optimizer_config: OptimizerConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],