From 540358258210982c13a6e290343149c4bf89f2ad Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Thu, 20 Mar 2025 10:33:26 -0400 Subject: [PATCH] fix: Restore discriminator for AlgorithmConfig (#1706) --- docs/_static/llama-stack-spec.html | 26 +++++++++++++------ docs/_static/llama-stack-spec.yaml | 13 +++++++--- .../apis/post_training/post_training.py | 6 ++--- 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index c3c18774e..98b495de2 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9863,6 +9863,23 @@ ], "title": "ScoreBatchResponse" }, + "AlgorithmConfig": { + "oneOf": [ + { + "$ref": "#/components/schemas/LoraFinetuningConfig" + }, + { + "$ref": "#/components/schemas/QATFinetuningConfig" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "LoRA": "#/components/schemas/LoraFinetuningConfig", + "QAT": "#/components/schemas/QATFinetuningConfig" + } + } + }, "LoraFinetuningConfig": { "type": "object", "properties": { @@ -9998,14 +10015,7 @@ "type": "string" }, "algorithm_config": { - "oneOf": [ - { - "$ref": "#/components/schemas/LoraFinetuningConfig" - }, - { - "$ref": "#/components/schemas/QATFinetuningConfig" - } - ] + "$ref": "#/components/schemas/AlgorithmConfig" } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 1738788e4..321dfe8e0 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6689,6 +6689,15 @@ components: required: - results title: ScoreBatchResponse + AlgorithmConfig: + oneOf: + - $ref: '#/components/schemas/LoraFinetuningConfig' + - $ref: '#/components/schemas/QATFinetuningConfig' + discriminator: + propertyName: type + mapping: + LoRA: '#/components/schemas/LoraFinetuningConfig' + QAT: '#/components/schemas/QATFinetuningConfig' LoraFinetuningConfig: type: object properties: @@ -6772,9 +6781,7 @@ components: checkpoint_dir: type: string algorithm_config: - oneOf: - - $ref: '#/components/schemas/LoraFinetuningConfig' - - $ref: '#/components/schemas/QATFinetuningConfig' + $ref: '#/components/schemas/AlgorithmConfig' additionalProperties: false required: - job_uuid diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index e61c0e4e4..d49668e23 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -6,7 +6,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Protocol +from typing import Any, Dict, List, Literal, Optional, Protocol, Union from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -88,7 +88,7 @@ class QATFinetuningConfig(BaseModel): group_size: int -AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")] +AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")] register_schema(AlgorithmConfig, name="AlgorithmConfig") @@ -182,7 +182,7 @@ class PostTraining(Protocol): description="Model descriptor from `llama model list`", ), checkpoint_dir: Optional[str] = None, - algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None, + algorithm_config: Optional[AlgorithmConfig] = None, ) -> PostTrainingJob: ... @webmethod(route="/post-training/preference-optimize", method="POST")