# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


@json_schema_type
class DatasetFormat(Enum):
    """Format of the rows in a post-training dataset."""

    instruct = "instruct"
    dialog = "dialog"


@json_schema_type
class DataConfig(BaseModel):
    """Dataset configuration for a post-training run.

    All fields except ``dataset_id`` have defaults, so a minimal config only
    needs the dataset identifier.
    """

    dataset_id: str
    batch_size: Optional[int] = 1
    shuffle: Optional[bool] = True
    data_format: Optional[DatasetFormat] = DatasetFormat.instruct
    # Optional held-out dataset; when None, no validation split is used.
    validation_dataset_id: Optional[str] = None
    train_on_input: Optional[bool] = False


@json_schema_type
class OptimizerConfig(BaseModel):
    """Optimizer hyperparameters for a training run."""

    lr: Optional[float] = 2e-5
    weight_decay: Optional[float] = 0.1
    num_warmup_steps: Optional[int] = 20


@json_schema_type
class TrainingConfig(BaseModel):
    """Top-level training loop configuration.

    ``data_config`` and ``optimizer_config`` are required; the remaining
    knobs default to a single-epoch run in bf16.
    """

    data_config: DataConfig
    optimizer_config: OptimizerConfig
    n_epochs: Optional[int] = 1
    # None means "run the full epoch"; likewise for max_validation_steps.
    max_steps_per_epoch: Optional[int] = None
    gradient_accumulation_steps: Optional[int] = 1
    max_validation_steps: Optional[int] = None
    dtype: Optional[str] = "bf16"


@json_schema_type
class LoraFinetuningConfig(BaseModel):
    """LoRA fine-tuning parameters.

    ``type`` is the discriminator tag used by :data:`AlgorithmConfig`.
    """

    type: Literal["LoRA"] = "LoRA"
    # default_factory gives each instance its own list, matching the
    # Field(default_factory=...) convention used elsewhere in this file.
    lora_attn_modules: Optional[List[str]] = Field(
        default_factory=lambda: ["q_proj", "v_proj", "output_proj"]
    )
    apply_lora_to_mlp: Optional[bool] = True
    apply_lora_to_output: Optional[bool] = False
    rank: Optional[int] = 8
    alpha: Optional[int] = 16
    use_dora: Optional[bool] = False
    quantize_base: Optional[bool] = False


@json_schema_type
class QATFinetuningConfig(BaseModel):
    """Quantization-aware-training fine-tuning parameters.

    ``type`` is the discriminator tag used by :data:`AlgorithmConfig`.
    """

    type: Literal["QAT"] = "QAT"
    quantizer_name: str
    group_size: int


# Discriminated union over the supported fine-tuning algorithms, keyed on
# each variant's `type` field.
AlgorithmConfig = register_schema(
    Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
    name="AlgorithmConfig",
)


@json_schema_type
class PostTrainingJobLogStream(BaseModel):
    """Stream of logs from a finetuning job."""

    job_uuid: str
    log_lines: List[str]


@json_schema_type
class RLHFAlgorithm(Enum):
    """Supported RLHF alignment algorithms."""

    dpo = "dpo"


@json_schema_type
class DPOAlignmentConfig(BaseModel):
    """Hyperparameters for DPO alignment."""

    reward_scale: float
    reward_clip: float
    epsilon: float
    gamma: float


@json_schema_type
class PostTrainingRLHFRequest(BaseModel):
    """Request to finetune a model."""

    job_uuid: str

    finetuned_model: URL

    dataset_id: str
    validation_dataset_id: str

    algorithm: RLHFAlgorithm
    algorithm_config: DPOAlignmentConfig

    optimizer_config: OptimizerConfig
    training_config: TrainingConfig

    # TODO: define these
    hyperparam_search_config: Dict[str, Any]
    logger_config: Dict[str, Any]


class PostTrainingJob(BaseModel):
    # Handle returned when a training job is scheduled; used to query
    # status, artifacts, and cancellation.
    job_uuid: str


@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
    """Status of a finetuning job."""

    job_uuid: str
    status: JobStatus

    scheduled_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    resources_allocated: Optional[Dict[str, Any]] = None

    checkpoints: List[Checkpoint] = Field(default_factory=list)


class ListPostTrainingJobsResponse(BaseModel):
    # Wrapper for the job-listing endpoint response.
    data: List[PostTrainingJob]


@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
    """Artifacts of a finetuning job."""

    job_uuid: str
    checkpoints: List[Checkpoint] = Field(default_factory=list)

    # TODO(ashwin): metrics, evals


class PostTraining(Protocol):
    """HTTP API surface for scheduling and inspecting post-training jobs."""

    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
    async def supervised_fine_tune(
        self,
        job_uuid: str,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
        model: str = Field(
            default="Llama3.2-3B-Instruct",
            description="Model descriptor from `llama model list`",
        ),
        checkpoint_dir: Optional[str] = None,
        algorithm_config: Optional[AlgorithmConfig] = None,
    ) -> PostTrainingJob:
        """Schedule a supervised fine-tuning job and return its handle."""
        ...

    @webmethod(route="/post-training/preference-optimize", method="POST")
    async def preference_optimize(
        self,
        job_uuid: str,
        finetuned_model: str,
        algorithm_config: DPOAlignmentConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
    ) -> PostTrainingJob:
        """Schedule a preference-optimization (DPO) job and return its handle."""
        ...

    @webmethod(route="/post-training/jobs", method="GET")
    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
        """List all known post-training jobs."""
        ...

    @webmethod(route="/post-training/job/status", method="GET")
    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
        """Return the status of a job, or None if it is unknown."""
        ...

    @webmethod(route="/post-training/job/cancel", method="POST")
    async def cancel_training_job(self, job_uuid: str) -> None:
        """Request cancellation of a job."""
        ...

    @webmethod(route="/post-training/job/artifacts", method="GET")
    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
        """Return the artifacts of a job, or None if it is unknown."""
        ...