From 0ec5151ab58ee6afea99be7572e0fdfda0282b4b Mon Sep 17 00:00:00 2001
From: Charlie Doern
Date: Sat, 26 Apr 2025 10:40:58 -0400
Subject: [PATCH] feat: add post_training RuntimeConfig

Certain APIs require a number of per-provider runtime arguments. Currently,
the main way to pass these arguments in is via the provider config. This is
awkward because it requires the provider to be pre-configured with values
that a client-side user should be able to pass in at runtime.

Especially with the advent of out-of-tree providers, a generic RuntimeConfig
class would let providers add and validate their own runtime arguments for
calls like supervised_fine_tune.

For example, https://github.com/opendatahub-io/llama-stack-provider-kft puts
fields like `input-pvc` and `model-path` in the provider config. That is not
sustainable, and neither is adding each such field to the post_training API
spec.

RuntimeConfig defines an inner `Config` class that allows arbitrary extra
fields. It is the provider's job to derive its own class from this one,
declare its valid options, and parse them.

Signed-off-by: Charlie Doern
---
 docs/_static/llama-stack-spec.html                       | 11 +++++++++++
 docs/_static/llama-stack-spec.yaml                       | 11 +++++++++++
 llama_stack/apis/post_training/post_training.py          | 13 +++++++++++++
 .../inline/post_training/torchtune/post_training.py      |  2 ++
 4 files changed, 37 insertions(+)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 4c5393947..41f240da1 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -10170,6 +10170,11 @@
             ],
             "title": "OptimizerType"
         },
+        "RuntimeConfig": {
+            "type": "object",
+            "title": "RuntimeConfig",
+            "description": "Provider-specific runtime configuration. Providers should document and parse their own expected fields. This model allows arbitrary extra fields for maximum flexibility."
+        },
         "TrainingConfig": {
             "type": "object",
             "properties": {
@@ -10274,6 +10279,9 @@
                         }
                     ]
                 }
+            },
+            "runtime_config": {
+                "$ref": "#/components/schemas/RuntimeConfig"
             }
         },
         "additionalProperties": false,
@@ -11375,6 +11383,9 @@
             },
             "algorithm_config": {
                 "$ref": "#/components/schemas/AlgorithmConfig"
+            },
+            "runtime_config": {
+                "$ref": "#/components/schemas/RuntimeConfig"
             }
         },
         "additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index a24f1a9db..5aebfd030 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -7000,6 +7000,13 @@ components:
       - adamw
       - sgd
       title: OptimizerType
+    RuntimeConfig:
+      type: object
+      title: RuntimeConfig
+      description: >-
+        Provider-specific runtime configuration. Providers should document and parse
+        their own expected fields. This model allows arbitrary extra fields for maximum
+        flexibility.
     TrainingConfig:
       type: object
       properties:
@@ -7060,6 +7067,8 @@ components:
                 - type: string
                 - type: array
                 - type: object
+      runtime_config:
+        $ref: '#/components/schemas/RuntimeConfig'
       additionalProperties: false
       required:
         - job_uuid
@@ -7755,6 +7764,8 @@ components:
         type: string
       algorithm_config:
         $ref: '#/components/schemas/AlgorithmConfig'
+      runtime_config:
+        $ref: '#/components/schemas/RuntimeConfig'
       additionalProperties: false
       required:
         - job_uuid
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index e5f1bcb65..29a29dc28 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -169,6 +169,17 @@ class PostTrainingJobArtifactsResponse(BaseModel):
     # TODO(ashwin): metrics, evals


+@json_schema_type
+class RuntimeConfig(BaseModel):
+    """
+    Provider-specific runtime configuration. Providers should document and parse their own expected fields.
+    This model allows arbitrary extra fields for maximum flexibility.
+    """
+
+    class Config:
+        extra = "allow"
+
+
 class PostTraining(Protocol):
     @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
@@ -183,6 +194,7 @@ class PostTraining(Protocol):
         ),
         checkpoint_dir: Optional[str] = None,
         algorithm_config: Optional[AlgorithmConfig] = None,
+        runtime_config: Optional[RuntimeConfig] = None,
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize", method="POST")
@@ -194,6 +206,7 @@ class PostTraining(Protocol):
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
+        runtime_config: Optional[RuntimeConfig] = None,
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/jobs", method="GET")
diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index cc1a6a5fe..7a17b9c4c 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -18,6 +18,7 @@ from llama_stack.apis.post_training import (
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
+    RuntimeConfig,
     TrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.config import (
@@ -80,6 +81,7 @@ class TorchtunePostTrainingImpl:
         model: str,
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig],
+        runtime_config: Optional[RuntimeConfig] = None,
     ) -> PostTrainingJob:
         if isinstance(algorithm_config, LoraFinetuningConfig):
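
As a usage sketch (not part of this patch), here is roughly how a provider could build
on RuntimeConfig. The KFTRuntimeConfig name, its fields (input_pvc, model_path,
num_workers), and the parse_runtime_config helper are hypothetical, loosely modeled on
the kft provider mentioned in the commit message; only RuntimeConfig itself comes from
this change.

    from typing import Optional

    from pydantic import BaseModel


    class RuntimeConfig(BaseModel):
        # Mirrors the model added by this patch: arbitrary extra fields are allowed.
        class Config:
            extra = "allow"


    class KFTRuntimeConfig(RuntimeConfig):
        # Hypothetical provider-side schema: the provider declares which runtime
        # options it actually accepts and lets pydantic validate them.
        input_pvc: str
        model_path: str
        num_workers: int = 1


    def parse_runtime_config(runtime_config: Optional[RuntimeConfig]) -> Optional[KFTRuntimeConfig]:
        # Inside a provider implementation: re-validate the loosely typed payload
        # against the provider's own stricter model.
        if runtime_config is None:
            return None
        # .dict() is the pydantic v1 API; on pydantic v2 use model_dump().
        return KFTRuntimeConfig(**runtime_config.dict())


    # Client side: extra fields pass through the generic RuntimeConfig untouched.
    cfg = RuntimeConfig(input_pvc="pvc-123", model_path="/mnt/models/llama", num_workers=2)
    print(parse_runtime_config(cfg))

The point of the extra = "allow" config is that the stack-level API stays generic while
each provider keeps full control over validating its own runtime arguments.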