diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 4c5393947..41f240da1 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -10170,6 +10170,11 @@ ], "title": "OptimizerType" }, + "RuntimeConfig": { + "type": "object", + "title": "RuntimeConfig", + "description": "Provider-specific runtime configuration. Providers should document and parse their own expected fields. This model allows arbitrary extra fields for maximum flexibility." + }, "TrainingConfig": { "type": "object", "properties": { @@ -10274,6 +10279,9 @@ } ] } + }, + "runtime_config": { + "$ref": "#/components/schemas/RuntimeConfig" } }, "additionalProperties": false, @@ -11375,6 +11383,9 @@ }, "algorithm_config": { "$ref": "#/components/schemas/AlgorithmConfig" + }, + "runtime_config": { + "$ref": "#/components/schemas/RuntimeConfig" } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index a24f1a9db..5aebfd030 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -7000,6 +7000,13 @@ components: - adamw - sgd title: OptimizerType + RuntimeConfig: + type: object + title: RuntimeConfig + description: >- + Provider-specific runtime configuration. Providers should document and parse + their own expected fields. This model allows arbitrary extra fields for maximum + flexibility. TrainingConfig: type: object properties: @@ -7060,6 +7067,8 @@ components: - type: string - type: array - type: object + runtime_config: + $ref: '#/components/schemas/RuntimeConfig' additionalProperties: false required: - job_uuid @@ -7755,6 +7764,8 @@ components: type: string algorithm_config: $ref: '#/components/schemas/AlgorithmConfig' + runtime_config: + $ref: '#/components/schemas/RuntimeConfig' additionalProperties: false required: - job_uuid diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index e5f1bcb65..29a29dc28 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -169,6 +169,17 @@ class PostTrainingJobArtifactsResponse(BaseModel): # TODO(ashwin): metrics, evals +@json_schema_type +class RuntimeConfig(BaseModel): + """ + Provider-specific runtime configuration. Providers should document and parse their own expected fields. + This model allows arbitrary extra fields for maximum flexibility. + """ + + class Config: + extra = "allow" + + class PostTraining(Protocol): @webmethod(route="/post-training/supervised-fine-tune", method="POST") async def supervised_fine_tune( @@ -183,6 +194,7 @@ class PostTraining(Protocol): ), checkpoint_dir: Optional[str] = None, algorithm_config: Optional[AlgorithmConfig] = None, + runtime_config: Optional[RuntimeConfig] = None, ) -> PostTrainingJob: ... @webmethod(route="/post-training/preference-optimize", method="POST") @@ -194,6 +206,7 @@ class PostTraining(Protocol): training_config: TrainingConfig, hyperparam_search_config: Dict[str, Any], logger_config: Dict[str, Any], + runtime_config: Optional[RuntimeConfig] = None, ) -> PostTrainingJob: ... @webmethod(route="/post-training/jobs", method="GET") diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py index cc1a6a5fe..7a17b9c4c 100644 --- a/llama_stack/providers/inline/post_training/torchtune/post_training.py +++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py @@ -18,6 +18,7 @@ from llama_stack.apis.post_training import ( PostTrainingJob, PostTrainingJobArtifactsResponse, PostTrainingJobStatusResponse, + RuntimeConfig, TrainingConfig, ) from llama_stack.providers.inline.post_training.torchtune.config import ( @@ -80,6 +81,7 @@ class TorchtunePostTrainingImpl: model: str, checkpoint_dir: Optional[str], algorithm_config: Optional[AlgorithmConfig], + runtime_config: Optional[RuntimeConfig] = None, ) -> PostTrainingJob: if isinstance(algorithm_config, LoraFinetuningConfig):