diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 72b2e6b17..2362dfa53 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9847,23 +9847,6 @@ ], "title": "ScoreBatchResponse" }, - "AlgorithmConfig": { - "oneOf": [ - { - "$ref": "#/components/schemas/LoraFinetuningConfig" - }, - { - "$ref": "#/components/schemas/QATFinetuningConfig" - } - ], - "discriminator": { - "propertyName": "type", - "mapping": { - "LoRA": "#/components/schemas/LoraFinetuningConfig", - "QAT": "#/components/schemas/QATFinetuningConfig" - } - } - }, "LoraFinetuningConfig": { "type": "object", "properties": { @@ -9999,7 +9982,14 @@ "type": "string" }, "algorithm_config": { - "$ref": "#/components/schemas/AlgorithmConfig" + "oneOf": [ + { + "$ref": "#/components/schemas/LoraFinetuningConfig" + }, + { + "$ref": "#/components/schemas/QATFinetuningConfig" + } + ] } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 6f4a9528b..38e08e41c 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6678,15 +6678,6 @@ components: required: - results title: ScoreBatchResponse - AlgorithmConfig: - oneOf: - - $ref: '#/components/schemas/LoraFinetuningConfig' - - $ref: '#/components/schemas/QATFinetuningConfig' - discriminator: - propertyName: type - mapping: - LoRA: '#/components/schemas/LoraFinetuningConfig' - QAT: '#/components/schemas/QATFinetuningConfig' LoraFinetuningConfig: type: object properties: @@ -6770,7 +6761,9 @@ components: checkpoint_dir: type: string algorithm_config: - $ref: '#/components/schemas/AlgorithmConfig' + oneOf: + - $ref: '#/components/schemas/LoraFinetuningConfig' + - $ref: '#/components/schemas/QATFinetuningConfig' additionalProperties: false required: - job_uuid diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index 636eb7e7b..362f87a26 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -6,7 +6,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Protocol, Union +from typing import Any, Dict, List, Literal, Optional, Protocol from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -89,7 +89,7 @@ class QATFinetuningConfig(BaseModel): AlgorithmConfig = register_schema( - Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")], + Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")], name="AlgorithmConfig", ) @@ -184,7 +184,7 @@ class PostTraining(Protocol): description="Model descriptor from `llama model list`", ), checkpoint_dir: Optional[str] = None, - algorithm_config: Optional[AlgorithmConfig] = None, + algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None, ) -> PostTrainingJob: ... @webmethod(route="/post-training/preference-optimize", method="POST") diff --git a/llama_stack/providers/inline/post_training/common/validator.py b/llama_stack/providers/inline/post_training/common/validator.py index e76edf3a0..b0aec6187 100644 --- a/llama_stack/providers/inline/post_training/common/validator.py +++ b/llama_stack/providers/inline/post_training/common/validator.py @@ -9,6 +9,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+ +from typing import Any + from llama_stack.apis.common.type_system import ( ChatCompletionInputType, DialogType, @@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import ( validate_dataset_schema, ) -EXPECTED_DATASET_SCHEMA = { +EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = { "instruct": [ { ColumnName.chat_completion_input.value: ChatCompletionInputType(), @@ -41,6 +44,9 @@ async def validate_input_dataset_schema( dataset_type: str, ) -> None: dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id) + if not dataset_def: + raise ValueError(f"Dataset {dataset_id} does not exist.") + if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: raise ValueError(f"Dataset {dataset_id} does not have a schema defined.") diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py index 64d61b053..fcadd0884 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py @@ -37,7 +37,7 @@ class TorchtuneCheckpointer: checkpoint_files: List[str], output_dir: str, model_type: str, - ) -> None: + ): # Fail fast if ``checkpoint_files`` is invalid # TODO: support loading more than one file if len(checkpoint_files) != 1: @@ -58,7 +58,7 @@ class TorchtuneCheckpointer: """ Load Meta checkpoint from file. Currently only loading from a single file is supported. """ - state_dict: Dict[str:Any] = {} + state_dict: Dict[str, Any] = {} model_state_dict = safe_torch_load(self._checkpoint_path) if self._model_type == ModelType.LLAMA3_VISION: from torchtune.models.llama3_2_vision._convert_weights import ( @@ -85,10 +85,10 @@ class TorchtuneCheckpointer: state_dict: Dict[str, Any], epoch: int, adapter_only: bool = False, - checkpoint_format: str = "meta", + checkpoint_format: str | None = None, ) -> str: model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}" - if checkpoint_format == "meta": + if checkpoint_format == "meta" or checkpoint_format is None: self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only) elif checkpoint_format == "huggingface": # Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index 98e16f9d7..f8a1c0436 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -10,7 +10,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Callable, Dict +from typing import Callable, Dict import torch from pydantic import BaseModel @@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat from llama_stack.models.llama.datatypes import Model from llama_stack.models.llama.sku_list import resolve_model +BuildLoraModelCallable = Callable[..., torch.nn.Module] +BuildTokenizerCallable = Callable[..., Llama3Tokenizer] + class ModelConfig(BaseModel): - model_definition: Any - tokenizer_type: Any + model_definition: BuildLoraModelCallable + tokenizer_type: BuildTokenizerCallable checkpoint_type: str @@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = { } -BuildLoraModelCallable = Callable[..., torch.nn.Module] -BuildTokenizerCallable = Callable[..., Llama3Tokenizer] - - def _validate_model_id(model_id: str) -> Model: model = resolve_model(model_id) if model is None or model.core_model_id.value not in MODEL_CONFIGS: diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py index b556b59a6..050996860 100644 --- a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py +++ b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py @@ -55,7 +55,7 @@ class SFTDataset(Dataset): if "messages" in transformed_sample: validate_messages(transformed_sample["messages"]) - tokenized_dict = self._model_transform(transformed_sample) + tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample) if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): keys_str = ", ".join(tokenized_dict.keys()) diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 0f89b4064..edc1ceb90 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -37,10 +37,10 @@ from llama_stack.apis.common.training_types import PostTrainingMetric from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.post_training import ( - AlgorithmConfig, Checkpoint, LoraFinetuningConfig, OptimizerConfig, + QATFinetuningConfig, TrainingConfig, ) from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR @@ -73,6 +73,9 @@ class LoraFinetuningSingleDevice: # Currently logging only logs limited training metrics to local disk # will figure out more loggings and how it works with telemetry in future PRs + + _checkpointer: TorchtuneCheckpointer + def __init__( self, config: TorchtunePostTrainingConfig, @@ -82,7 +85,7 @@ class LoraFinetuningSingleDevice: logger_config: Dict[str, Any], model: str, checkpoint_dir: Optional[str], - algorithm_config: Optional[AlgorithmConfig], + algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None, datasetio_api: DatasetIO, datasets_api: Datasets, ) -> None: @@ -109,12 +112,12 @@ class LoraFinetuningSingleDevice: return str(checkpoint_dir) if checkpoint_dir and checkpoint_dir != "null": - self.checkpoint_dir = config.checkpoint_dir + self.checkpoint_dir = checkpoint_dir else: - model = resolve_model(self.model_id) - if model is None: + model_obj = resolve_model(self.model_id) + if model_obj is None: raise ValueError(f"{self.model_id} not found. 
Your model id should be in the llama models SKU list") - self.checkpoint_dir = model_checkpoint_dir(model) + self.checkpoint_dir = model_checkpoint_dir(model_obj) self._output_dir = str(DEFAULT_CHECKPOINT_DIR) self._checkpoint_format = config.checkpoint_format @@ -135,16 +138,16 @@ class LoraFinetuningSingleDevice: self.max_validation_steps = training_config.max_validation_steps self._clip_grad_norm = 1.0 - self._enable_activation_checkpointing = ( - (training_config.efficiency_config.enable_activation_checkpointing) - if training_config.efficiency_config - else False - ) - self._enable_activation_offloading = ( - (training_config.efficiency_config.enable_activation_offloading) - if training_config.efficiency_config - else False - ) + + self._enable_activation_checkpointing = False + self._enable_activation_offloading = False + if training_config.efficiency_config: + if training_config.efficiency_config.enable_activation_checkpointing: + self._enable_activation_checkpointing = ( + training_config.efficiency_config.enable_activation_checkpointing + ) + if training_config.efficiency_config.enable_activation_offloading: + self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading self.datasetio_api = datasetio_api self.datasets_api = datasets_api @@ -451,12 +454,12 @@ class LoraFinetuningSingleDevice: """ # Initialize tokens count and running loss (for grad accumulation) t0 = time.perf_counter() - running_loss = 0 + running_loss: float = 0.0 num_tokens = 0 # training artifacts checkpoints = [] - memory_stats = {} + memory_stats: Dict[str, Any] = {} # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): @@ -484,7 +487,7 @@ class LoraFinetuningSingleDevice: # Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients current_loss = await self._loss_step(batch) * current_num_tokens - running_loss += current_loss + running_loss += current_loss.detach().item() current_loss.backward() # Step with optimizer @@ -500,7 +503,7 @@ class LoraFinetuningSingleDevice: # Update the number of steps when the weights are updated self.global_step += 1 - loss_to_log = running_loss.item() / num_tokens + loss_to_log = running_loss / num_tokens pbar.update(1) pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}") @@ -523,7 +526,7 @@ class LoraFinetuningSingleDevice: ) # Reset running stats for the next step - running_loss = 0 + running_loss = 0.0 num_tokens = 0 t0 = time.perf_counter() diff --git a/pyproject.toml b/pyproject.toml index f57b91462..107150cee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -228,10 +228,6 @@ exclude = [ "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", "^llama_stack/providers/inline/inference/vllm/", "^llama_stack/providers/inline/post_training/common/validator\\.py$", - "^llama_stack/providers/inline/post_training/torchtune/common/checkpointer\\.py$", - "^llama_stack/providers/inline/post_training/torchtune/common/utils\\.py$", - "^llama_stack/providers/inline/post_training/torchtune/datasets/sft\\.py$", - "^llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device\\.py$", "^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$", "^llama_stack/providers/inline/safety/code_scanner/", "^llama_stack/providers/inline/safety/llama_guard/",
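
For reference, here is a minimal, self-contained sketch of the discriminated-union pattern this patch keeps in `post_training.py` (the `Annotated[... | ..., Field(discriminator="type")]` form) and which the spec changes above now inline as a plain `oneOf` instead of a registered `AlgorithmConfig` schema. The two model classes below are simplified stand-ins, not the real llama-stack definitions: only the `type` discriminator values (`"LoRA"`, `"QAT"`) come from the diff, and the other fields (`rank`, `quantizer_name`) are placeholders.

```python
# Sketch only: simplified stand-ins for LoraFinetuningConfig / QATFinetuningConfig.
# The "type" Literal is the discriminator; all other fields are hypothetical.
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class LoraFinetuningConfig(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    rank: int = 8  # placeholder hyperparameter


class QATFinetuningConfig(BaseModel):
    type: Literal["QAT"] = "QAT"
    quantizer_name: str = "int4"  # placeholder hyperparameter


# PEP 604 union inside Annotated, as in the updated post_training.py
AlgorithmConfig = Annotated[
    LoraFinetuningConfig | QATFinetuningConfig,
    Field(discriminator="type"),
]

adapter = TypeAdapter(AlgorithmConfig)

# The "type" field selects which concrete config class validates the payload.
lora = adapter.validate_python({"type": "LoRA", "rank": 16})
qat = adapter.validate_python({"type": "QAT", "quantizer_name": "int8"})
assert isinstance(lora, LoraFinetuningConfig)
assert isinstance(qat, QATFinetuningConfig)
```

Validation behavior is unchanged by the spec edits: callers of `supervised_fine_tune` still pass either config, and the `type` value routes to the matching model; the OpenAPI documents just express the union inline rather than through a separately registered `AlgorithmConfig` component.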