Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 18:02:40 +00:00)

commit bfc782c054 (parent 15e21cb8bd)

    temp commit

7 changed files with 100 additions and 30 deletions
@@ -16,7 +16,6 @@ from pydantic import BaseModel, Field
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
-import torch


 class OptimizerType(Enum):
@@ -36,7 +35,7 @@ class OptimizerConfig(BaseModel):

 @json_schema_type
 class TrainingConfig(BaseModel):
-    dtype: torch.dtype
+    dtype: str
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
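Note on this hunk: a torch.dtype value cannot be represented in a JSON schema, so exposing it on a @json_schema_type Pydantic model is a problem for API schema generation; switching the field to a plain string (and dropping the now-unused `import torch` above) keeps the API surface framework-agnostic. A minimal sketch of how a training provider might map the string back to a real dtype — the DTYPE_MAP name and the accepted strings ("bf16", "fp16", "fp32") are assumptions, not part of this diff:

    # Illustrative only: resolve the string-valued TrainingConfig.dtype back
    # to a torch dtype inside the training provider.
    import torch

    DTYPE_MAP = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }

    def resolve_dtype(name: str) -> torch.dtype:
        if name not in DTYPE_MAP:
            raise ValueError(f"unsupported dtype {name!r}")
        return DTYPE_MAP[name]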
@@ -116,10 +115,7 @@ class PostTrainingSFTRequest(BaseModel):
     validation_dataset_id: str

-    algorithm: FinetuningAlgorithm
-    algorithm_config: Union[
-        LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
-    ]
+    algorithm_config: LoraFinetuningConfig
     optimizer_config: OptimizerConfig
     training_config: TrainingConfig


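Note on this hunk: dropping the separate `algorithm` field and narrowing `algorithm_config` from a three-way Union to the one variant this request currently supports also removes a validation ambiguity. A self-contained sketch with stand-in classes (not the real llama-stack types) of the pitfall a bare Union invites when variants share fields:

    # Stand-ins only: pydantic tries Union variants and, when several match
    # equally well, typically resolves to the leftmost one.
    from typing import Union
    from pydantic import BaseModel

    class LoraCfg(BaseModel):
        rank: int

    class QLoraCfg(BaseModel):
        rank: int

    class SFTRequest(BaseModel):
        algorithm_config: Union[LoraCfg, QLoraCfg]

    req = SFTRequest(algorithm_config={"rank": 8})
    # Prints "LoraCfg": QLoRA-shaped data silently becomes the LoRA variant.
    print(type(req.algorithm_config).__name__)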
@@ -140,7 +136,7 @@ class PostTrainingRLHFRequest(BaseModel):
     validation_dataset_id: str

     algorithm: RLHFAlgorithm
-    algorithm_config: Union[DPOAlignmentConfig]
+    algorithm_config: DPOAlignmentConfig

     optimizer_config: OptimizerConfig
     training_config: TrainingConfig
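Note on this hunk: `Union[DPOAlignmentConfig]` with a single member is flattened by the typing module to the member itself, so this is a pure readability fix with no runtime or schema change — easy to confirm:

    # Union with one member collapses to that member at runtime.
    from typing import Union

    class DPOAlignmentConfig:  # stand-in for the real config class
        beta: float = 0.1

    assert Union[DPOAlignmentConfig] is DPOAlignmentConfig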
@@ -184,18 +180,16 @@ class PostTraining(Protocol):
     @webmethod(route="/post-training/supervised-fine-tune")
     def supervised_fine_tune(
         self,
-        job_uuid: str,
-        model: str,
-        dataset_id: str,
-        validation_dataset_id: str,
-        algorithm: FinetuningAlgorithm,
-        algorithm_config: Union[
-            LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
-        ],
-        optimizer_config: OptimizerConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        job_uuid: Optional[str],
+        model: Optional[str],
+        dataset_id: Optional[str],
+        validation_dataset_id: Optional[str],
+        algorithm: Optional[FinetuningAlgorithm],
+        algorithm_config: Optional[LoraFinetuningConfig],
+        optimizer_config: Optional[OptimizerConfig],
+        training_config: Optional[TrainingConfig],
+        hyperparam_search_config: Optional[Dict[str, Any]],
+        logger_config: Optional[Dict[str, Any]],
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize")
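Note on this hunk: wrapping every parameter in Optional[...] widens the accepted type to include None, but without `= None` defaults the arguments remain required at the call site. A self-contained sketch of that subtlety (stand-in function, unrelated to llama-stack's own types):

    # Optional[...] changes the type, not the arity: callers must still pass
    # the argument, it may just be None now.
    from typing import Optional

    def fine_tune(model: Optional[str]) -> str:
        return model or "<default-model>"

    fine_tune(None)    # OK: None is a valid value
    # fine_tune()      # TypeError: missing required positional argument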
@@ -206,7 +200,7 @@ class PostTraining(Protocol):
         dataset_id: str,
         validation_dataset_id: str,
         algorithm: RLHFAlgorithm,
-        algorithm_config: Union[DPOAlignmentConfig],
+        algorithm_config: DPOAlignmentConfig,
         optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],