fix: Restore discriminator for AlgorithmConfig (#1706)

This commit is contained in:
Ihar Hrachyshka 2025-03-20 10:33:26 -04:00 committed by GitHub
parent af8b4484a3
commit 5403582582
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 31 additions and 14 deletions

View file

@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@ -88,7 +88,7 @@ class QATFinetuningConfig(BaseModel):
group_size: int
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@ -182,7 +182,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST")