chore: fix mypy violations in post_training modules

Note: this patch touches all files but post_training.py that will be
significantly changed by #1437, hence leaving it out of the picture for
now.

running_loss is now always Tensor (on-device) and doesn't change its
type from int to Tensor (which made mypy unhappy).

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-11 11:19:45 -04:00
parent 3b35a39b8b
commit 8c01246344
9 changed files with 56 additions and 69 deletions

View file

@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from typing import Any, Dict, List, Literal, Optional, Protocol
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@ -89,7 +89,7 @@ class QATFinetuningConfig(BaseModel):
AlgorithmConfig = register_schema(
Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
name="AlgorithmConfig",
)
@ -184,7 +184,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST")