mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
feat: add post_training RuntimeConfig
Certain APIs require a bunch of runtime arguments per provider. Currently, the best way to pass these arguments in is via the provider config. This is tricky because it requires a provider to be pre-configured with certain arguments that a client-side user should be able to pass in at runtime. Especially with the advent of out-of-tree providers, it would be great for a generic RuntimeConfig class to allow providers to add and validate their own runtime arguments for things like supervised_fine_tune. For example, https://github.com/opendatahub-io/llama-stack-provider-kft has things like `input-pvc`, `model-path`, etc. in the provider config. This is not sustainable, nor is adding each and every needed field to the post_training API spec. RuntimeConfig has a sub-class called Config which allows extra fields to be specified arbitrarily. It is the provider's job to create its own class based on this one and add valid options, parse them, etc. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
bb1a85c9a0
commit
0ec5151ab5
4 changed files with 37 additions and 0 deletions
11
docs/_static/llama-stack-spec.html
vendored
11
docs/_static/llama-stack-spec.html
vendored
|
@ -10170,6 +10170,11 @@
|
|||
],
|
||||
"title": "OptimizerType"
|
||||
},
|
||||
"RuntimeConfig": {
|
||||
"type": "object",
|
||||
"title": "RuntimeConfig",
|
||||
"description": "Provider-specific runtime configuration. Providers should document and parse their own expected fields. This model allows arbitrary extra fields for maximum flexibility."
|
||||
},
|
||||
"TrainingConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -10274,6 +10279,9 @@
|
|||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"runtime_config": {
|
||||
"$ref": "#/components/schemas/RuntimeConfig"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -11375,6 +11383,9 @@
|
|||
},
|
||||
"algorithm_config": {
|
||||
"$ref": "#/components/schemas/AlgorithmConfig"
|
||||
},
|
||||
"runtime_config": {
|
||||
"$ref": "#/components/schemas/RuntimeConfig"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
|
11
docs/_static/llama-stack-spec.yaml
vendored
11
docs/_static/llama-stack-spec.yaml
vendored
|
@ -7000,6 +7000,13 @@ components:
|
|||
- adamw
|
||||
- sgd
|
||||
title: OptimizerType
|
||||
RuntimeConfig:
|
||||
type: object
|
||||
title: RuntimeConfig
|
||||
description: >-
|
||||
Provider-specific runtime configuration. Providers should document and parse
|
||||
their own expected fields. This model allows arbitrary extra fields for maximum
|
||||
flexibility.
|
||||
TrainingConfig:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -7060,6 +7067,8 @@ components:
|
|||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
runtime_config:
|
||||
$ref: '#/components/schemas/RuntimeConfig'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- job_uuid
|
||||
|
@ -7755,6 +7764,8 @@ components:
|
|||
type: string
|
||||
algorithm_config:
|
||||
$ref: '#/components/schemas/AlgorithmConfig'
|
||||
runtime_config:
|
||||
$ref: '#/components/schemas/RuntimeConfig'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- job_uuid
|
||||
|
|
|
@ -169,6 +169,17 @@ class PostTrainingJobArtifactsResponse(BaseModel):
|
|||
# TODO(ashwin): metrics, evals
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RuntimeConfig(BaseModel):
|
||||
"""
|
||||
Provider-specific runtime configuration. Providers should document and parse their own expected fields.
|
||||
This model allows arbitrary extra fields for maximum flexibility.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class PostTraining(Protocol):
|
||||
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
|
||||
async def supervised_fine_tune(
|
||||
|
@ -183,6 +194,7 @@ class PostTraining(Protocol):
|
|||
),
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
algorithm_config: Optional[AlgorithmConfig] = None,
|
||||
runtime_config: Optional[RuntimeConfig] = None,
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||
|
@ -194,6 +206,7 @@ class PostTraining(Protocol):
|
|||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: Dict[str, Any],
|
||||
logger_config: Dict[str, Any],
|
||||
runtime_config: Optional[RuntimeConfig] = None,
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/jobs", method="GET")
|
||||
|
|
|
@ -18,6 +18,7 @@ from llama_stack.apis.post_training import (
|
|||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
RuntimeConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
|
@ -80,6 +81,7 @@ class TorchtunePostTrainingImpl:
|
|||
model: str,
|
||||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[AlgorithmConfig],
|
||||
runtime_config: Optional[RuntimeConfig] = None,
|
||||
) -> PostTrainingJob:
|
||||
if isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue