fix: Restore discriminator for AlgorithmConfig

It was lost during refactoring in #1548.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-19 16:41:58 -04:00
parent af8b4484a3
commit bc85d16ff0
3 changed files with 31 additions and 14 deletions

View file

@ -9863,6 +9863,23 @@
],
"title": "ScoreBatchResponse"
},
"AlgorithmConfig": {
"oneOf": [
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"LoRA": "#/components/schemas/LoraFinetuningConfig",
"QAT": "#/components/schemas/QATFinetuningConfig"
}
}
},
"LoraFinetuningConfig": {
"type": "object",
"properties": {
@ -9998,14 +10015,7 @@
"type": "string"
},
"algorithm_config": {
"oneOf": [
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
]
"$ref": "#/components/schemas/AlgorithmConfig"
}
},
"additionalProperties": false,

View file

@ -6689,6 +6689,15 @@ components:
required:
- results
title: ScoreBatchResponse
AlgorithmConfig:
oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
discriminator:
propertyName: type
mapping:
LoRA: '#/components/schemas/LoraFinetuningConfig'
QAT: '#/components/schemas/QATFinetuningConfig'
LoraFinetuningConfig:
type: object
properties:
@ -6772,9 +6781,7 @@ components:
checkpoint_dir:
type: string
algorithm_config:
oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
$ref: '#/components/schemas/AlgorithmConfig'
additionalProperties: false
required:
- job_uuid

View file

@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@ -88,7 +88,7 @@ class QATFinetuningConfig(BaseModel):
group_size: int
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@ -182,7 +182,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST")