feat: add post_training RuntimeConfig

Certain APIs require a number of per-provider runtime arguments. Currently, the only practical way to pass these arguments in is via the
provider config. This is awkward because it forces a provider to be pre-configured with arguments that a client-side user should be able to supply at runtime.

Especially with the advent of out-of-tree providers, a generic RuntimeConfig class would let providers add and validate their own runtime arguments for calls like supervised_fine_tune.

For example, https://github.com/opendatahub-io/llama-stack-provider-kft currently carries fields like `input-pvc` and `model-path` in its provider config.
That is not sustainable, and neither is adding each provider's fields to the post_training API spec. Instead, RuntimeConfig contains a nested Config class that allows arbitrary extra fields to be specified; it is each provider's job to derive its own class from this one, declare the options it accepts, and parse them.
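
As a sketch of the intended usage (the KFTRuntimeConfig name and its fields are hypothetical, not part of this commit; it assumes the pydantic v1-style Config shown in the diff below):

    from typing import Optional

    from pydantic import BaseModel


    class RuntimeConfig(BaseModel):
        """Base class from this commit: arbitrary extra fields are allowed."""

        class Config:
            extra = "allow"


    class KFTRuntimeConfig(RuntimeConfig):
        """Hypothetical out-of-tree subclass declaring the options it accepts."""

        input_pvc: str
        model_path: str
        num_workers: Optional[int] = 1

A provider would then re-validate the generic RuntimeConfig it receives, e.g. KFTRuntimeConfig(**runtime_config.dict()), and fail fast if required options are missing.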

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Charlie Doern 2025-04-26 10:40:58 -04:00
parent bb1a85c9a0
commit 0ec5151ab5
4 changed files with 37 additions and 0 deletions


@@ -10170,6 +10170,11 @@
        ],
        "title": "OptimizerType"
      },
      "RuntimeConfig": {
        "type": "object",
        "title": "RuntimeConfig",
        "description": "Provider-specific runtime configuration. Providers should document and parse their own expected fields. This model allows arbitrary extra fields for maximum flexibility."
      },
      "TrainingConfig": {
        "type": "object",
        "properties": {
@@ -10274,6 +10279,9 @@
                }
              ]
            }
          },
          "runtime_config": {
            "$ref": "#/components/schemas/RuntimeConfig"
          }
        },
        "additionalProperties": false,
@@ -11375,6 +11383,9 @@
          },
          "algorithm_config": {
            "$ref": "#/components/schemas/AlgorithmConfig"
          },
          "runtime_config": {
            "$ref": "#/components/schemas/RuntimeConfig"
          }
        },
        "additionalProperties": false,


@@ -7000,6 +7000,13 @@ components:
        - adamw
        - sgd
      title: OptimizerType
    RuntimeConfig:
      type: object
      title: RuntimeConfig
      description: >-
        Provider-specific runtime configuration. Providers should document and parse
        their own expected fields. This model allows arbitrary extra fields for maximum
        flexibility.
    TrainingConfig:
      type: object
      properties:
@@ -7060,6 +7067,8 @@ components:
              - type: string
              - type: array
              - type: object
        runtime_config:
          $ref: '#/components/schemas/RuntimeConfig'
      additionalProperties: false
      required:
        - job_uuid
@@ -7755,6 +7764,8 @@ components:
          type: string
        algorithm_config:
          $ref: '#/components/schemas/AlgorithmConfig'
        runtime_config:
          $ref: '#/components/schemas/RuntimeConfig'
      additionalProperties: false
      required:
        - job_uuid


@@ -169,6 +169,17 @@ class PostTrainingJobArtifactsResponse(BaseModel):
    # TODO(ashwin): metrics, evals


@json_schema_type
class RuntimeConfig(BaseModel):
    """
    Provider-specific runtime configuration. Providers should document and parse their own expected fields.
    This model allows arbitrary extra fields for maximum flexibility.
    """

    class Config:
        extra = "allow"


class PostTraining(Protocol):
    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
    async def supervised_fine_tune(
@@ -183,6 +194,7 @@ class PostTraining(Protocol):
        ),
        checkpoint_dir: Optional[str] = None,
        algorithm_config: Optional[AlgorithmConfig] = None,
        runtime_config: Optional[RuntimeConfig] = None,
    ) -> PostTrainingJob: ...

    @webmethod(route="/post-training/preference-optimize", method="POST")
@@ -194,6 +206,7 @@ class PostTraining(Protocol):
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
        runtime_config: Optional[RuntimeConfig] = None,
    ) -> PostTrainingJob: ...

    @webmethod(route="/post-training/jobs", method="GET")


@@ -18,6 +18,7 @@ from llama_stack.apis.post_training import (
    PostTrainingJob,
    PostTrainingJobArtifactsResponse,
    PostTrainingJobStatusResponse,
    RuntimeConfig,
    TrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.config import (
@@ -80,6 +81,7 @@ class TorchtunePostTrainingImpl:
        model: str,
        checkpoint_dir: Optional[str],
        algorithm_config: Optional[AlgorithmConfig],
        runtime_config: Optional[RuntimeConfig] = None,
    ) -> PostTrainingJob:
        if isinstance(algorithm_config, LoraFinetuningConfig):
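
Inside an implementation like this one, the generic payload could be narrowed on entry; a minimal sketch (TorchtuneRuntimeConfig and its field are hypothetical, not part of this commit):

    from typing import Optional

    from llama_stack.apis.post_training import RuntimeConfig


    class TorchtuneRuntimeConfig(RuntimeConfig):
        """Hypothetical: request-time options this provider chooses to accept."""

        dataset_cache_dir: Optional[str] = None


    def parse_runtime_config(runtime_config: Optional[RuntimeConfig]) -> TorchtuneRuntimeConfig:
        # Re-validate the loosely typed payload against the provider's own
        # schema; pydantic surfaces mistyped options as a ValidationError.
        if runtime_config is None:
            return TorchtuneRuntimeConfig()
        return TorchtuneRuntimeConfig(**runtime_config.dict())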