Mirror of https://github.com/meta-llama/llama-stack.git, synced 2026-01-02 21:00:01 +00:00
chore: fix mypy violations in post_training modules

Note: this patch touches all files except post_training.py, which will be significantly changed by #1437, so it is left out of the picture for now.

running_loss is now always a Tensor (on-device) and no longer changes its type from int to Tensor (which made mypy unhappy).

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
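For context, a minimal sketch of the running_loss pattern the message describes; the variable names and the loop are illustrative stand-ins, not the provider's actual training code:

import torch

# Before (illustrative): running_loss started life as an int and was
# rebound to a Tensor on the first accumulation, so mypy inferred two
# conflicting types for the same variable.
#   running_loss = 0
#   running_loss += loss.detach()   # int -> Tensor

# After (illustrative): keep running_loss a Tensor on the target device
# for its whole lifetime, so its type never changes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
running_loss = torch.tensor(0.0, device=device)
num_steps = 10
for _ in range(num_steps):
    loss = torch.rand((), device=device)  # stand-in for the real loss
    running_loss += loss.detach()
mean_loss = running_loss.item() / num_steps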
This commit is contained in:
parent: 3b35a39b8b
commit: 8c01246344
9 changed files with 56 additions and 69 deletions
@@ -6,7 +6,7 @@
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Any, Dict, List, Literal, Optional, Protocol
 
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
@@ -89,7 +89,7 @@ class QATFinetuningConfig(BaseModel):
 
 AlgorithmConfig = register_schema(
-    Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
+    Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
     name="AlgorithmConfig",
 )
 
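The hunk above swaps typing.Union for PEP 604 syntax inside the discriminated union, which is why the Union import could be dropped. A minimal sketch of why the two spellings are interchangeable, assuming pydantic v2; the config classes here are simplified stand-ins for the real models:

from typing import Annotated, Literal
from pydantic import BaseModel, Field, TypeAdapter

class LoraFinetuningConfig(BaseModel):
    type: Literal["LoRA"] = "LoRA"   # discriminator field (simplified)
    rank: int = 8

class QATFinetuningConfig(BaseModel):
    type: Literal["QAT"] = "QAT"
    quantizer_name: str = "int4"

# X | Y and Union[X, Y] produce equivalent annotations at runtime, so the
# discriminated union behaves identically with either spelling.
AlgorithmConfig = Annotated[
    LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")
]

adapter = TypeAdapter(AlgorithmConfig)
cfg = adapter.validate_python({"type": "LoRA", "rank": 16})
assert isinstance(cfg, LoraFinetuningConfig)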
@@ -184,7 +184,7 @@ class PostTraining(Protocol):
             description="Model descriptor from `llama model list`",
         ),
         checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[AlgorithmConfig] = None,
+        algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/preference-optimize", method="POST")
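This last hunk spells the union out in the method signature. A plausible reading of why that satisfies mypy, sketched with a hypothetical stand-in helper rather than code from the commit: AlgorithmConfig is assigned from a register_schema(...) call, so mypy treats it as a variable rather than a type alias and rejects it in annotations.

from typing import Annotated, Optional

def register_schema(schema, name: str):  # hypothetical stand-in for the real helper
    return schema

# mypy sees AlgorithmConfig as a variable (the result of a call), not a type alias:
AlgorithmConfig = register_schema(Annotated[int | str, "metadata"], name="AlgorithmConfig")

# def tune(config: Optional[AlgorithmConfig] = None) -> None: ...
#   error (roughly): Variable "AlgorithmConfig" is not valid as a type  [valid-type]

# Writing the union out keeps the annotation a genuine type:
def tune(config: Optional[int | str] = None) -> None: ...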