fix: Restore discriminator for AlgorithmConfig (#1706)

This commit is contained in:
Ihar Hrachyshka 2025-03-20 10:33:26 -04:00 committed by GitHub
parent af8b4484a3
commit 5403582582
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 31 additions and 14 deletions

View file

@ -9863,6 +9863,23 @@
], ],
"title": "ScoreBatchResponse" "title": "ScoreBatchResponse"
}, },
"AlgorithmConfig": {
"oneOf": [
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"LoRA": "#/components/schemas/LoraFinetuningConfig",
"QAT": "#/components/schemas/QATFinetuningConfig"
}
}
},
"LoraFinetuningConfig": { "LoraFinetuningConfig": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9998,14 +10015,7 @@
"type": "string" "type": "string"
}, },
"algorithm_config": { "algorithm_config": {
"oneOf": [ "$ref": "#/components/schemas/AlgorithmConfig"
{
"$ref": "#/components/schemas/LoraFinetuningConfig"
},
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
]
} }
}, },
"additionalProperties": false, "additionalProperties": false,

View file

@ -6689,6 +6689,15 @@ components:
required: required:
- results - results
title: ScoreBatchResponse title: ScoreBatchResponse
AlgorithmConfig:
oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
discriminator:
propertyName: type
mapping:
LoRA: '#/components/schemas/LoraFinetuningConfig'
QAT: '#/components/schemas/QATFinetuningConfig'
LoraFinetuningConfig: LoraFinetuningConfig:
type: object type: object
properties: properties:
@ -6772,9 +6781,7 @@ components:
checkpoint_dir: checkpoint_dir:
type: string type: string
algorithm_config: algorithm_config:
oneOf: $ref: '#/components/schemas/AlgorithmConfig'
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
additionalProperties: false additionalProperties: false
required: required:
- job_uuid - job_uuid

View file

@ -6,7 +6,7 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
@ -88,7 +88,7 @@ class QATFinetuningConfig(BaseModel):
group_size: int group_size: int
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")] AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig") register_schema(AlgorithmConfig, name="AlgorithmConfig")
@ -182,7 +182,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`", description="Model descriptor from `llama model list`",
), ),
checkpoint_dir: Optional[str] = None, checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None, algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST") @webmethod(route="/post-training/preference-optimize", method="POST")