diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index c3c18774e..98b495de2 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -9863,6 +9863,23 @@
],
"title": "ScoreBatchResponse"
},
+ "AlgorithmConfig": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/LoraFinetuningConfig"
+ },
+ {
+ "$ref": "#/components/schemas/QATFinetuningConfig"
+ }
+ ],
+ "discriminator": {
+ "propertyName": "type",
+ "mapping": {
+ "LoRA": "#/components/schemas/LoraFinetuningConfig",
+ "QAT": "#/components/schemas/QATFinetuningConfig"
+ }
+ }
+ },
"LoraFinetuningConfig": {
"type": "object",
"properties": {
@@ -9998,14 +10015,7 @@
"type": "string"
},
"algorithm_config": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/LoraFinetuningConfig"
- },
- {
- "$ref": "#/components/schemas/QATFinetuningConfig"
- }
- ]
+ "$ref": "#/components/schemas/AlgorithmConfig"
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 1738788e4..321dfe8e0 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6689,6 +6689,15 @@ components:
required:
- results
title: ScoreBatchResponse
+ AlgorithmConfig:
+ oneOf:
+ - $ref: '#/components/schemas/LoraFinetuningConfig'
+ - $ref: '#/components/schemas/QATFinetuningConfig'
+ discriminator:
+ propertyName: type
+ mapping:
+ LoRA: '#/components/schemas/LoraFinetuningConfig'
+ QAT: '#/components/schemas/QATFinetuningConfig'
LoraFinetuningConfig:
type: object
properties:
@@ -6772,9 +6781,7 @@ components:
checkpoint_dir:
type: string
algorithm_config:
- oneOf:
- - $ref: '#/components/schemas/LoraFinetuningConfig'
- - $ref: '#/components/schemas/QATFinetuningConfig'
+ $ref: '#/components/schemas/AlgorithmConfig'
additionalProperties: false
required:
- job_uuid
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index e61c0e4e4..d49668e23 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol
+from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@@ -88,7 +88,7 @@ class QATFinetuningConfig(BaseModel):
group_size: int
-AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
+AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@@ -182,7 +182,7 @@ class PostTraining(Protocol):
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
- algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
+ algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST")