diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 17b5b4449..8e2d131c0 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -16,7 +16,6 @@ from pydantic import BaseModel, Field
 from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_stack.apis.datasets import * # noqa: F403
 from llama_stack.apis.common.training_types import * # noqa: F403
-import torch


 class OptimizerType(Enum):
@@ -36,7 +35,7 @@ class OptimizerConfig(BaseModel):

 @json_schema_type
 class TrainingConfig(BaseModel):
-    dtype: torch.dtype
+    dtype: str
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
@@ -116,10 +115,7 @@ class PostTrainingSFTRequest(BaseModel):
     validation_dataset_id: str

     algorithm: FinetuningAlgorithm
-    algorithm_config: Union[
-        LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
-    ]
-
+    algorithm_config: LoraFinetuningConfig
     optimizer_config: OptimizerConfig
     training_config: TrainingConfig

@@ -140,7 +136,7 @@ class PostTrainingRLHFRequest(BaseModel):
     validation_dataset_id: str

     algorithm: RLHFAlgorithm
-    algorithm_config: Union[DPOAlignmentConfig]
+    algorithm_config: DPOAlignmentConfig

     optimizer_config: OptimizerConfig
     training_config: TrainingConfig
@@ -184,18 +180,16 @@ class PostTraining(Protocol):
     @webmethod(route="/post-training/supervised-fine-tune")
     def supervised_fine_tune(
         self,
-        job_uuid: str,
-        model: str,
-        dataset_id: str,
-        validation_dataset_id: str,
-        algorithm: FinetuningAlgorithm,
-        algorithm_config: Union[
-            LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
-        ],
-        optimizer_config: OptimizerConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        job_uuid: Optional[str],
+        model: Optional[str],
+        dataset_id: Optional[str],
+        validation_dataset_id: Optional[str],
+        algorithm: Optional[FinetuningAlgorithm],
+        algorithm_config: Optional[LoraFinetuningConfig],
+        optimizer_config: Optional[OptimizerConfig],
+        training_config: Optional[TrainingConfig],
+        hyperparam_search_config: Optional[Dict[str, Any]],
+        logger_config: Optional[Dict[str, Any]],
     ) -> PostTrainingJob: ...

@webmethod(route="/post-training/preference-optimize") @@ -206,7 +200,7 @@ class PostTraining(Protocol): dataset_id: str, validation_dataset_id: str, algorithm: RLHFAlgorithm, - algorithm_config: Union[DPOAlignmentConfig], + algorithm_config: DPOAlignmentConfig, optimizer_config: OptimizerConfig, training_config: TrainingConfig, hyperparam_search_config: Dict[str, Any], diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 9b3812e9e..4541b01eb 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -24,6 +24,7 @@ from llama_stack.apis.inspect import Inspect from llama_stack.apis.memory import Memory from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.models import Models +from llama_stack.apis.post_training import PostTraining from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring_functions import ScoringFunctions @@ -58,6 +59,7 @@ def api_protocol_map() -> Dict[Api, Any]: Api.scoring_functions: ScoringFunctions, Api.eval: Eval, Api.eval_tasks: EvalTasks, + Api.post_training: PostTraining, } diff --git a/llama_stack/providers/inline/post_training/meta_reference/post_training.py b/llama_stack/providers/inline/post_training/meta_reference/post_training.py index 31ff9786c..5f7a70742 100644 --- a/llama_stack/providers/inline/post_training/meta_reference/post_training.py +++ b/llama_stack/providers/inline/post_training/meta_reference/post_training.py @@ -20,17 +20,46 @@ class MetaReferencePostTrainingImpl: self.config = config self.datasetio_api = datasetio_api + LoraFinetuningConfig( + lora_attn_modules=["q_proj", "v_proj", "output_proj"], + apply_lora_to_mlp=True, + apply_lora_to_output=False, + rank=8, + alpha=16, + ) + + OptimizerConfig( + optimizer_type=OptimizerType.adamw, + lr=3e-4, + lr_min=3e-5, + weight_decay=0.1, + num_warmup_steps=100, + ) + + TrainingConfig( + dtype="bf16", + n_epochs=1, + max_steps_per_epoch=10, + gradient_accumulation_steps=1, + batch_size=1, + shuffle=1, + enable_activation_checkpointing=False, + memory_efficient_fsdp_wrap=False, + fsdp_cpu_offload=False, + ) + def supervised_fine_tune( self, - job_uuid: str, - model: str, - dataset_id: str, - validation_dataset_id: str, - algorithm: FinetuningAlgorithm, - algorithm_config: LoraFinetuningConfig, - optimizer_config: OptimizerConfig, - training_config: TrainingConfig, - logger_config: Dict[str, Any], + job_uuid: Optional[str] = "1234", + model: Optional[str] = " meta-llama/Llama-3.2-3B-Instruct", + dataset_id: Optional[str] = "alpaca", + validation_dataset_id: Optional[str] = "alpaca", + algorithm: Optional[FinetuningAlgorithm] = FinetuningAlgorithm.lora, + algorithm_config: Optional[LoraFinetuningConfig] = LoraFinetuningConfig, + optimizer_config: Optional[OptimizerConfig] = OptimizerConfig, + training_config: Optional[TrainingConfig] = TrainingConfig, + hyperparam_search_config: Optional[Dict[str, Any]] = {}, + logger_config: Optional[Dict[str, Any]] = {}, ) -> PostTrainingJob: # wrapper request to make it easier to pass around (internal only, not exposed to API) request = PostTrainingSFTRequest( @@ -54,3 +83,36 @@ class MetaReferencePostTrainingImpl: raise NotImplementedError() return PostTrainingJob(job_uuid=job_uuid) + + def preference_optimize( + self, + job_uuid: str, + finetuned_model: URL, + dataset_id: str, + validation_dataset_id: str, + algorithm: RLHFAlgorithm, + algorithm_config: DPOAlignmentConfig, + optimizer_config: OptimizerConfig, 
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
+    ) -> PostTrainingJob: ...
+
+    def get_training_jobs(self) -> List[PostTrainingJob]: ...
+
+    # sends SSE stream of logs
+    @webmethod(route="/post-training/job/logs")
+    def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
+
+    @webmethod(route="/post-training/job/status")
+    def get_training_job_status(
+        self, job_uuid: str
+    ) -> PostTrainingJobStatusResponse: ...
+
+    @webmethod(route="/post-training/job/cancel")
+    def cancel_training_job(self, job_uuid: str) -> None: ...
+
+    @webmethod(route="/post-training/job/artifacts")
+    def get_training_job_artifacts(
+        self, job_uuid: str
+    ) -> PostTrainingJobArtifactsResponse: ...
diff --git a/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py
index acf302220..7bf99ad21 100644
--- a/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py
@@ -38,7 +38,7 @@ from torchtune.modules.peft import (
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
-from torchtune.training.lr_scheduler import get_cosine_schedule_with_warmup
+from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup

 log = logging.getLogger(__name__)

diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py
index 6c4554bb7..67fcba0ef 100644
--- a/llama_stack/providers/registry/post_training.py
+++ b/llama_stack/providers/registry/post_training.py
@@ -12,6 +12,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403

 META_REFERENCE_DEPS = [
     "torch",
     "torchtune",
+    "torchao",
     "numpy",
 ]
@@ -24,5 +25,8 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=META_REFERENCE_DEPS,
             module="llama_stack.providers.inline.post_training.meta_reference",
             config_class="llama_stack.providers.inline.post_training.meta_reference.MetaReferencePostTrainingConfig",
+            api_dependencies=[
+                Api.datasetio,
+            ],
         ),
     ]
diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml
index 459a1b96c..469d3b71b 100644
--- a/llama_stack/templates/meta-reference-gpu/build.yaml
+++ b/llama_stack/templates/meta-reference-gpu/build.yaml
@@ -6,6 +6,8 @@ distribution_spec:
   providers:
     post_training:
     - inline::meta-reference
+    datasetio:
+    - remote::huggingface
     inference:
     - inline::meta-reference
     memory:
diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml
index d8bc99811..8cd71b7b1 100644
--- a/llama_stack/templates/meta-reference-gpu/run.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run.yaml
@@ -8,6 +8,8 @@ apis:
 - memory
 - safety
 - telemetry
+- datasetio
+- post_training
 providers:
   inference:
   - provider_id: meta-reference-inference
@@ -16,6 +18,10 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+  datasetio:
+  - provider_id: huggingface-0
+    provider_type: remote::huggingface
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
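
Note on the `dtype: torch.dtype` to `dtype: str` change in `TrainingConfig`: it removes the `torch` import from the API package so the JSON schema stays framework-neutral, which means the provider must map the wire string back to a real dtype before training. A minimal sketch of that mapping follows; the helper name and the accepted strings are illustrative assumptions rather than part of this diff (recent torchtune releases ship a comparable `training.get_dtype` utility):

```python
import torch

# Illustrative wire-format strings; the defaults above use "bf16".
_DTYPE_MAP = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def resolve_dtype(name: str) -> torch.dtype:
    """Map an API-level dtype string to a torch.dtype inside the provider."""
    try:
        return _DTYPE_MAP[name]
    except KeyError:
        raise ValueError(
            f"unsupported dtype {name!r}; expected one of {sorted(_DTYPE_MAP)}"
        ) from None
```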
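
With the class-level defaults above, the endpoint can be exercised with no arguments at all, but callers can still pass everything explicitly. Below is a sketch of one such call against a resolved provider instance; the `impl` handle, the import path, and the dataset id are placeholders, not anything this diff guarantees:

```python
from llama_stack.apis.post_training import (
    FinetuningAlgorithm,
    LoraFinetuningConfig,
    OptimizerConfig,
    OptimizerType,
    TrainingConfig,
)

# `impl` is a resolved PostTraining provider (e.g. MetaReferencePostTrainingImpl);
# how you obtain it depends on your distribution setup.
job = impl.supervised_fine_tune(
    job_uuid="sft-job-0001",
    model="meta-llama/Llama-3.2-3B-Instruct",
    dataset_id="alpaca",
    validation_dataset_id="alpaca",
    algorithm=FinetuningAlgorithm.lora,
    algorithm_config=LoraFinetuningConfig(
        lora_attn_modules=["q_proj", "v_proj", "output_proj"],
        apply_lora_to_mlp=True,
        apply_lora_to_output=False,
        rank=8,
        alpha=16,
    ),
    optimizer_config=OptimizerConfig(
        optimizer_type=OptimizerType.adamw,
        lr=3e-4,
        lr_min=3e-5,
        weight_decay=0.1,
        num_warmup_steps=100,
    ),
    training_config=TrainingConfig(
        dtype="bf16",  # plain string now that torch no longer leaks into the API
        n_epochs=1,
        max_steps_per_epoch=10,
        gradient_accumulation_steps=1,
        batch_size=1,
        shuffle=True,
        enable_activation_checkpointing=False,
        memory_efficient_fsdp_wrap=False,
        fsdp_cpu_offload=False,
    ),
    hyperparam_search_config={},
    logger_config={},
)
print(job.job_uuid)
```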
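
The job-management methods added to the meta-reference impl are stubs for now (`...` bodies), but the intended usage pattern follows from the routes. A hypothetical polling loop against them, with the status field name and its terminal values assumed rather than taken from this diff:

```python
import time


def wait_for_job(impl, job_uuid: str, poll_seconds: float = 30.0):
    """Poll the (currently stubbed) status endpoint until the job finishes.

    The `status` field and its terminal values are assumptions; confirm them
    against PostTrainingJobStatusResponse once the stubs are implemented.
    """
    while True:
        response = impl.get_training_job_status(job_uuid=job_uuid)
        if response.status in ("completed", "failed"):  # assumed terminal states
            return impl.get_training_job_artifacts(job_uuid=job_uuid)
        time.sleep(poll_seconds)
```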