chore: fix mypy violations in post_training modules (#1548)

# What does this PR do?

Fixes a bunch of violations.

Note: this patch touches all files but post_training.py that will be
significantly changed by #1437, hence leaving it out of the picture for
now.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

Testing with https://github.com/meta-llama/llama-stack/pull/1543

Also checked that GPU training works with the change:

```
INFO:     ::1:53316 - "POST /v1/post-training/supervised-fine-tune HTTP/1.1" 200 OK
INFO:     ::1:53316 - "GET /v1/post-training/job/status?job_uuid=test-jobb5ca2d84-d541-42f8-883b-762828b4c0e7 HTTP/1.1" 200 OK
INFO:     ::1:53316 - "GET /v1/post-training/job/artifacts?job_uuid=test-jobb5ca2d84-d541-42f8-883b-762828b4c0e7 HTTP/1.1" 200 OK
21:24:01.161 [END] /v1/post-training/supervised-fine-tune [StatusCode.OK] (32526.75ms)
 21:23:28.769 [DEBUG] Setting manual seed to local seed 3918872849. Local seed is seed + rank = 3918872849 + 0
 21:23:28.996 [INFO] Identified model_type = Llama3_2. Ignoring output.weight in checkpoint in favor of the tok_embedding.weight tied weights.
 21:23:29.933 [INFO] Memory stats after model init:
        GPU peak memory allocation: 6.05 GiB
        GPU peak memory reserved: 6.10 GiB
        GPU peak memory active: 6.05 GiB
 21:23:29.934 [INFO] Model is initialized with precision torch.bfloat16.
 21:23:30.115 [INFO] Tokenizer is initialized.
 21:23:30.118 [INFO] Optimizer is initialized.
 21:23:30.119 [INFO] Loss is initialized.
 21:23:30.896 [INFO] Dataset and Sampler are initialized.
 21:23:30.898 [INFO] Learning rate scheduler is initialized.
 21:23:31.618 [INFO] Memory stats after model init:
        GPU peak memory allocation: 6.24 GiB
        GPU peak memory reserved: 6.30 GiB
        GPU peak memory active: 6.24 GiB
 21:23:31.620 [INFO] Starting checkpoint save...
 21:23:59.428 [INFO] Model checkpoint of size 6.43 GB saved to /home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/consolidated.00.pth
 21:23:59.445 [INFO] Adapter checkpoint of size 0.00 GB saved to /home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/adapter/adapter.pth

```

[//]: # (## Documentation)

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-18 17:58:16 -04:00 committed by GitHub
parent f86f3cf878
commit 0cbb7f7f21
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 56 additions and 69 deletions

View file

@ -37,7 +37,7 @@ class TorchtuneCheckpointer:
checkpoint_files: List[str],
output_dir: str,
model_type: str,
) -> None:
):
# Fail fast if ``checkpoint_files`` is invalid
# TODO: support loading more than one file
if len(checkpoint_files) != 1:
@ -58,7 +58,7 @@ class TorchtuneCheckpointer:
"""
Load Meta checkpoint from file. Currently only loading from a single file is supported.
"""
state_dict: Dict[str:Any] = {}
state_dict: Dict[str, Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
@ -85,10 +85,10 @@ class TorchtuneCheckpointer:
state_dict: Dict[str, Any],
epoch: int,
adapter_only: bool = False,
checkpoint_format: str = "meta",
checkpoint_format: str | None = None,
) -> str:
model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"
if checkpoint_format == "meta":
if checkpoint_format == "meta" or checkpoint_format is None:
self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
elif checkpoint_format == "huggingface":
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now

View file

@ -10,7 +10,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Callable, Dict
from typing import Callable, Dict
import torch
from pydantic import BaseModel
@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import resolve_model
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
class ModelConfig(BaseModel):
model_definition: Any
tokenizer_type: Any
model_definition: BuildLoraModelCallable
tokenizer_type: BuildTokenizerCallable
checkpoint_type: str
@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = {
}
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
def _validate_model_id(model_id: str) -> Model:
model = resolve_model(model_id)
if model is None or model.core_model_id.value not in MODEL_CONFIGS: