fix: Restore discriminator for AlgorithmConfig

It was lost during refactoring in #1548.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-19 16:41:58 -04:00
parent af8b4484a3
commit bc85d16ff0
3 changed files with 31 additions and 14 deletions

View file

@ -9863,6 +9863,23 @@
], ],
"title": "ScoreBatchResponse" "title": "ScoreBatchResponse"
}, },
"AlgorithmConfig": {
"oneOf": [
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"LoRA": "#/components/schemas/LoraFinetuningConfig",
"QAT": "#/components/schemas/QATFinetuningConfig"
}
}
},
"LoraFinetuningConfig": { "LoraFinetuningConfig": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9998,14 +10015,7 @@
"type": "string" "type": "string"
}, },
"algorithm_config": { "algorithm_config": {
"oneOf": [ "$ref": "#/components/schemas/AlgorithmConfig"
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
]
} }
}, },
"additionalProperties": false, "additionalProperties": false,

View file

@ -6689,6 +6689,15 @@ components:
required: required:
- results - results
title: ScoreBatchResponse title: ScoreBatchResponse
AlgorithmConfig:
oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
discriminator:
propertyName: type
mapping:
LoRA: '#/components/schemas/LoraFinetuningConfig'
QAT: '#/components/schemas/QATFinetuningConfig'
LoraFinetuningConfig: LoraFinetuningConfig:
type: object type: object
properties: properties:
@ -6772,9 +6781,7 @@ components:
checkpoint_dir: checkpoint_dir:
type: string type: string
algorithm_config: algorithm_config:
oneOf: $ref: '#/components/schemas/AlgorithmConfig'
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
additionalProperties: false additionalProperties: false
required: required:
- job_uuid - job_uuid

View file

@ -6,7 +6,7 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
@ -88,7 +88,7 @@ class QATFinetuningConfig(BaseModel):
group_size: int group_size: int
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")] AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig") register_schema(AlgorithmConfig, name="AlgorithmConfig")
@ -182,7 +182,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`", description="Model descriptor from `llama model list`",
), ),
checkpoint_dir: Optional[str] = None, checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None, algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST") @webmethod(route="/post-training/preference-optimize", method="POST")