temp commit

This commit is contained in:
Botao Chen 2024-12-02 17:24:25 -08:00
parent 6c709abc4d
commit 79c525be94
5 changed files with 115 additions and 163 deletions

View file

@ -41,7 +41,7 @@ class TrainingConfig(BaseModel):
gradient_accumulation_steps: int
batch_size: int
shuffle: bool
# n_iters: int
optimizer_config: OptimizerConfig
enable_activation_checkpointing: bool
memory_efficient_fsdp_wrap: Optional[bool]
@ -63,6 +63,7 @@ class LoraFinetuningConfig(BaseModel):
apply_lora_to_output: bool
rank: int
alpha: int
use_dora: bool
@json_schema_type
@ -116,7 +117,6 @@ class PostTrainingSFTRequest(BaseModel):
algorithm: FinetuningAlgorithm
algorithm_config: LoraFinetuningConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
@ -178,7 +178,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune")
def supervised_fine_tune(
async def supervised_fine_tune(
self,
job_uuid: str,
model: str,
@ -186,14 +186,14 @@ class PostTraining(Protocol):
validation_dataset_id: str,
algorithm: FinetuningAlgorithm,
algorithm_config: LoraFinetuningConfig,
optimizer_config: OptimizerConfig,
# optimizer_config: OptimizerConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize")
def preference_optimize(
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: URL,
@ -208,21 +208,23 @@ class PostTraining(Protocol):
) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs")
def get_training_jobs(self) -> List[PostTrainingJob]: ...
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
# Sends a Server-Sent Events (SSE) stream of logs.
@webmethod(route="/post-training/job/logs")
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...
@webmethod(route="/post-training/job/status")
def get_training_job_status(
async def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None: ...
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts")
def get_training_job_artifacts(
async def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...