diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 72b2e6b17..2362dfa53 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -9847,23 +9847,6 @@
],
"title": "ScoreBatchResponse"
},
- "AlgorithmConfig": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/LoraFinetuningConfig"
- },
- {
- "$ref": "#/components/schemas/QATFinetuningConfig"
- }
- ],
- "discriminator": {
- "propertyName": "type",
- "mapping": {
- "LoRA": "#/components/schemas/LoraFinetuningConfig",
- "QAT": "#/components/schemas/QATFinetuningConfig"
- }
- }
- },
"LoraFinetuningConfig": {
"type": "object",
"properties": {
@@ -9999,7 +9982,14 @@
"type": "string"
},
"algorithm_config": {
- "$ref": "#/components/schemas/AlgorithmConfig"
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/LoraFinetuningConfig"
+ },
+ {
+ "$ref": "#/components/schemas/QATFinetuningConfig"
+ }
+ ]
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 6f4a9528b..38e08e41c 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6678,15 +6678,6 @@ components:
required:
- results
title: ScoreBatchResponse
- AlgorithmConfig:
- oneOf:
- - $ref: '#/components/schemas/LoraFinetuningConfig'
- - $ref: '#/components/schemas/QATFinetuningConfig'
- discriminator:
- propertyName: type
- mapping:
- LoRA: '#/components/schemas/LoraFinetuningConfig'
- QAT: '#/components/schemas/QATFinetuningConfig'
LoraFinetuningConfig:
type: object
properties:
@@ -6770,7 +6761,9 @@ components:
checkpoint_dir:
type: string
algorithm_config:
- $ref: '#/components/schemas/AlgorithmConfig'
+ oneOf:
+ - $ref: '#/components/schemas/LoraFinetuningConfig'
+ - $ref: '#/components/schemas/QATFinetuningConfig'
additionalProperties: false
required:
- job_uuid
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 636eb7e7b..362f87a26 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Any, Dict, List, Literal, Optional, Protocol
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@@ -89,7 +89,7 @@ class QATFinetuningConfig(BaseModel):
AlgorithmConfig = register_schema(
- Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
+ Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
name="AlgorithmConfig",
)
@@ -184,7 +184,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
- algorithm_config: Optional[AlgorithmConfig] = None,
+ algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST")
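
Context for the typing change above (not part of the patch): a minimal, self-contained sketch of how a pydantic discriminated union such as `Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]` behaves; the `Cat`/`Dog` models here are hypothetical stand-ins. Pydantic selects the concrete model from the `type` field, which is also why the OpenAPI spec now renders `algorithm_config` as a `oneOf` over the two config schemas.

```python
from typing import Literal

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated


class Cat(BaseModel):
    type: Literal["cat"] = "cat"
    lives: int = 9


class Dog(BaseModel):
    type: Literal["dog"] = "dog"
    good_boy: bool = True


# Discriminated union: pydantic dispatches on the "type" field.
Pet = Annotated[Cat | Dog, Field(discriminator="type")]

pet = TypeAdapter(Pet).validate_python({"type": "dog"})
assert isinstance(pet, Dog)

# The generated JSON schema is a oneOf with a "type" discriminator,
# mirroring the spec change above.
print(TypeAdapter(Pet).json_schema())
```
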
diff --git a/llama_stack/providers/inline/post_training/common/validator.py b/llama_stack/providers/inline/post_training/common/validator.py
index e76edf3a0..b0aec6187 100644
--- a/llama_stack/providers/inline/post_training/common/validator.py
+++ b/llama_stack/providers/inline/post_training/common/validator.py
@@ -9,6 +9,9 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+
+from typing import Any
+
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
DialogType,
@@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
validate_dataset_schema,
)
-EXPECTED_DATASET_SCHEMA = {
+EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
"instruct": [
{
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
@@ -41,6 +44,9 @@ async def validate_input_dataset_schema(
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
+ if not dataset_def:
+ raise ValueError(f"Dataset {dataset_id} does not exist.")
+
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
index 64d61b053..fcadd0884 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
@@ -37,7 +37,7 @@ class TorchtuneCheckpointer:
checkpoint_files: List[str],
output_dir: str,
model_type: str,
- ) -> None:
+ ):
# Fail fast if ``checkpoint_files`` is invalid
# TODO: support loading more than one file
if len(checkpoint_files) != 1:
@@ -58,7 +58,7 @@ class TorchtuneCheckpointer:
"""
Load Meta checkpoint from file. Currently only loading from a single file is supported.
"""
- state_dict: Dict[str:Any] = {}
+ state_dict: Dict[str, Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
@@ -85,10 +85,10 @@ class TorchtuneCheckpointer:
state_dict: Dict[str, Any],
epoch: int,
adapter_only: bool = False,
- checkpoint_format: str = "meta",
+ checkpoint_format: str | None = None,
) -> str:
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
- if checkpoint_format == "meta":
+ if checkpoint_format == "meta" or checkpoint_format is None:
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
elif checkpoint_format == "huggingface":
# Note: for saving hugging face format checkpoints, we only support saving adapter weights now
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
index 98e16f9d7..f8a1c0436 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -10,7 +10,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from typing import Any, Callable, Dict
+from typing import Callable, Dict
import torch
from pydantic import BaseModel
@@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import resolve_model
+BuildLoraModelCallable = Callable[..., torch.nn.Module]
+BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
+
class ModelConfig(BaseModel):
- model_definition: Any
- tokenizer_type: Any
+ model_definition: BuildLoraModelCallable
+ tokenizer_type: BuildTokenizerCallable
checkpoint_type: str
@@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = {
}
-BuildLoraModelCallable = Callable[..., torch.nn.Module]
-BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
-
-
def _validate_model_id(model_id: str) -> Model:
model = resolve_model(model_id)
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
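
Aside (not part of the patch): a small sketch of the pattern `ModelConfig` now uses, typing factory callables with a `Callable[..., T]` alias instead of `Any`; `build_toy_model` and `ToyModelConfig` are hypothetical. Pydantic only verifies that the value is callable, but mypy can check call sites against the alias:

```python
from typing import Callable

from pydantic import BaseModel

# Alias for a factory that builds some model object; here it just returns a string.
BuildModelCallable = Callable[..., str]


def build_toy_model(name: str) -> str:
    return f"model:{name}"


class ToyModelConfig(BaseModel):
    builder: BuildModelCallable
    checkpoint_type: str


cfg = ToyModelConfig(builder=build_toy_model, checkpoint_type="meta")
print(cfg.builder("llama3"))  # -> model:llama3
```
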
diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
index b556b59a6..050996860 100644
--- a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
+++ b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
@@ -55,7 +55,7 @@ class SFTDataset(Dataset):
if "messages" in transformed_sample:
validate_messages(transformed_sample["messages"])
- tokenized_dict = self._model_transform(transformed_sample)
+ tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample)
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys())
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 0f89b4064..edc1ceb90 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -37,10 +37,10 @@ from llama_stack.apis.common.training_types import PostTrainingMetric
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
- AlgorithmConfig,
Checkpoint,
LoraFinetuningConfig,
OptimizerConfig,
+ QATFinetuningConfig,
TrainingConfig,
)
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
@@ -73,6 +73,9 @@ class LoraFinetuningSingleDevice:
# Currently logging only logs limited training metrics to local disk
# will figure out more logging and how it works with telemetry in future PRs
+
+ _checkpointer: TorchtuneCheckpointer
+
def __init__(
self,
config: TorchtunePostTrainingConfig,
@@ -82,7 +85,7 @@ class LoraFinetuningSingleDevice:
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
- algorithm_config: Optional[AlgorithmConfig],
+ algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
@@ -109,12 +112,12 @@ class LoraFinetuningSingleDevice:
return str(checkpoint_dir)
if checkpoint_dir and checkpoint_dir != "null":
- self.checkpoint_dir = config.checkpoint_dir
+ self.checkpoint_dir = checkpoint_dir
else:
- model = resolve_model(self.model_id)
- if model is None:
+ model_obj = resolve_model(self.model_id)
+ if model_obj is None:
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
- self.checkpoint_dir = model_checkpoint_dir(model)
+ self.checkpoint_dir = model_checkpoint_dir(model_obj)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
self._checkpoint_format = config.checkpoint_format
@@ -135,16 +138,16 @@ class LoraFinetuningSingleDevice:
self.max_validation_steps = training_config.max_validation_steps
self._clip_grad_norm = 1.0
- self._enable_activation_checkpointing = (
- (training_config.efficiency_config.enable_activation_checkpointing)
- if training_config.efficiency_config
- else False
- )
- self._enable_activation_offloading = (
- (training_config.efficiency_config.enable_activation_offloading)
- if training_config.efficiency_config
- else False
- )
+
+ self._enable_activation_checkpointing = False
+ self._enable_activation_offloading = False
+ if training_config.efficiency_config:
+ if training_config.efficiency_config.enable_activation_checkpointing:
+ self._enable_activation_checkpointing = (
+ training_config.efficiency_config.enable_activation_checkpointing
+ )
+ if training_config.efficiency_config.enable_activation_offloading:
+ self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
@@ -451,12 +454,12 @@ class LoraFinetuningSingleDevice:
"""
# Initialize tokens count and running loss (for grad accumulation)
t0 = time.perf_counter()
- running_loss = 0
+ running_loss: float = 0.0
num_tokens = 0
# training artifacts
checkpoints = []
- memory_stats = {}
+ memory_stats: Dict[str, Any] = {}
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
@@ -484,7 +487,7 @@ class LoraFinetuningSingleDevice:
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = await self._loss_step(batch) * current_num_tokens
- running_loss += current_loss
+ running_loss += current_loss.detach().item()
current_loss.backward()
# Step with optimizer
@@ -500,7 +503,7 @@ class LoraFinetuningSingleDevice:
# Update the number of steps when the weights are updated
self.global_step += 1
- loss_to_log = running_loss.item() / num_tokens
+ loss_to_log = running_loss / num_tokens
pbar.update(1)
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
@@ -523,7 +526,7 @@ class LoraFinetuningSingleDevice:
)
# Reset running stats for the next step
- running_loss = 0
+ running_loss = 0.0
num_tokens = 0
t0 = time.perf_counter()
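
Aside (not part of the patch): a toy sketch of why `running_loss` is now accumulated as a plain Python float via `.detach().item()` rather than as a tensor; the linear model and MSE loss below are hypothetical. Detaching before logging keeps the bookkeeping sum from retaining each step's autograd graph, and the final division needs no `.item()` call:

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

running_loss: float = 0.0
num_tokens = 0
for _ in range(8):
    batch = torch.randn(2, 4)
    target = torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(batch), target)
    running_loss += loss.detach().item()  # plain float, no graph retained
    num_tokens += batch.numel()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

loss_to_log = running_loss / num_tokens  # float division, no .item() needed
print(loss_to_log)
```
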
diff --git a/pyproject.toml b/pyproject.toml
index f57b91462..107150cee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -228,10 +228,6 @@ exclude = [
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
"^llama_stack/providers/inline/inference/vllm/",
"^llama_stack/providers/inline/post_training/common/validator\\.py$",
- "^llama_stack/providers/inline/post_training/torchtune/common/checkpointer\\.py$",
- "^llama_stack/providers/inline/post_training/torchtune/common/utils\\.py$",
- "^llama_stack/providers/inline/post_training/torchtune/datasets/sft\\.py$",
- "^llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device\\.py$",
"^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$",
"^llama_stack/providers/inline/safety/code_scanner/",
"^llama_stack/providers/inline/safety/llama_guard/",