diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 63df97c68..9202e1753 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -7,7 +7,7 @@
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import Any, Dict, List, Optional, Protocol
 
 from llama_models.schema_utils import json_schema_type, webmethod
@@ -115,21 +115,6 @@ class DPOAlignmentConfig(BaseModel):
     gamma: float
 
 
-@json_schema_type
-class PostTrainingSFTRequest(BaseModel):
-    """Request to finetune a model."""
-
-    job_uuid: str
-    model: str
-    algorithm: FinetuningAlgorithm
-    algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
-    training_config: TrainingConfig
-
-    # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
-
-
 @json_schema_type
 class PostTrainingRLHFRequest(BaseModel):
     """Request to finetune a model."""
diff --git a/llama_stack/providers/inline/post_training/meta_reference/__init__.py b/llama_stack/providers/inline/post_training/torchtune/__init__.py
similarity index 69%
rename from llama_stack/providers/inline/post_training/meta_reference/__init__.py
rename to llama_stack/providers/inline/post_training/torchtune/__init__.py
index d700fbb0a..a756385de 100644
--- a/llama_stack/providers/inline/post_training/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/post_training/torchtune/__init__.py
@@ -8,16 +8,16 @@ from typing import Dict
 
 from llama_stack.distribution.datatypes import Api, ProviderSpec
 
-from .config import MetaReferencePostTrainingConfig
+from .config import TorchtunePostTrainingConfig
 
 
 async def get_provider_impl(
-    config: MetaReferencePostTrainingConfig,
+    config: TorchtunePostTrainingConfig,
     deps: Dict[Api, ProviderSpec],
 ):
-    from .post_training import MetaReferencePostTrainingImpl
+    from .post_training import TorchtunePostTrainingImpl
 
-    impl = MetaReferencePostTrainingImpl(
+    impl = TorchtunePostTrainingImpl(
         config,
         deps[Api.datasetio],
     )
diff --git a/llama_stack/providers/inline/post_training/meta_reference/config.py b/llama_stack/providers/inline/post_training/torchtune/config.py
similarity index 92%
rename from llama_stack/providers/inline/post_training/meta_reference/config.py
rename to llama_stack/providers/inline/post_training/torchtune/config.py
index 880fb6070..7bfea0f9d 100644
--- a/llama_stack/providers/inline/post_training/meta_reference/config.py
+++ b/llama_stack/providers/inline/post_training/torchtune/config.py
@@ -9,7 +9,7 @@ from typing import Optional
 from pydantic import BaseModel, Field
 
 
-class MetaReferencePostTrainingConfig(BaseModel):
+class TorchtunePostTrainingConfig(BaseModel):
     model: str = Field(
         default="Llama3.2-3B-Instruct",
         description="Model descriptor from `llama model list`",
diff --git a/llama_stack/providers/inline/post_training/meta_reference/datasets/sft.py b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
similarity index 100%
rename from llama_stack/providers/inline/post_training/meta_reference/datasets/sft.py
rename to llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
diff --git a/llama_stack/providers/inline/post_training/meta_reference/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
similarity index 80%
rename from llama_stack/providers/inline/post_training/meta_reference/post_training.py
rename to llama_stack/providers/inline/post_training/torchtune/post_training.py
index 8ab98f7d4..83a8ef02f 100644
--- a/llama_stack/providers/inline/post_training/meta_reference/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -4,18 +4,30 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.providers.inline.post_training.meta_reference.config import (
-    MetaReferencePostTrainingConfig,
+from llama_stack.providers.inline.post_training.torchtune.config import (
+    TorchtunePostTrainingConfig,
 )
 from llama_stack.apis.post_training import *  # noqa
-from llama_stack.providers.inline.post_training.meta_reference.recipes.lora_finetuning_single_device import (
+from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
     LoraFinetuningSingleDevice,
 )
 
 
-class MetaReferencePostTrainingImpl:
+class PostTrainingSFTRequest(BaseModel):
+    job_uuid: str
+    model: str
+    algorithm: FinetuningAlgorithm
+    algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
+    training_config: TrainingConfig
+
+    # TODO: define these
+    hyperparam_search_config: Dict[str, Any]
+    logger_config: Dict[str, Any]
+
+
+class TorchtunePostTrainingImpl:
     def __init__(
-        self, config: MetaReferencePostTrainingConfig, datasetio_api: DatasetIO
+        self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
diff --git a/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
similarity index 98%
rename from llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py
rename to llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 30b315329..0603347bb 100644
--- a/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -18,15 +18,15 @@ from torch import nn
 from torchtune import utils as torchtune_utils
 from torchtune.training.metric_logging import DiskLogger
 from llama_stack.apis.post_training import *  # noqa
-from llama_stack.apis.post_training import PostTrainingSFTRequest
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.providers.inline.post_training.meta_reference import utils
-from llama_stack.providers.inline.post_training.meta_reference.config import (
-    MetaReferencePostTrainingConfig,
+from llama_stack.providers.inline.post_training.torchtune import utils
+from llama_stack.providers.inline.post_training.torchtune.config import (
+    TorchtunePostTrainingConfig,
 )
-from llama_stack.providers.inline.post_training.meta_reference.datasets.sft import (
-    SFTDataset,
+from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
+from llama_stack.providers.inline.post_training.torchtune.post_training import (
+    PostTrainingSFTRequest,
 )
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
diff --git a/llama_stack/providers/inline/post_training/meta_reference/utils.py b/llama_stack/providers/inline/post_training/torchtune/utils.py
similarity index 99%
rename from llama_stack/providers/inline/post_training/meta_reference/utils.py
rename to llama_stack/providers/inline/post_training/torchtune/utils.py
index 70280081f..93c7ef189 100644
--- a/llama_stack/providers/inline/post_training/meta_reference/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/utils.py
@@ -14,6 +14,7 @@ from typing import Any, Callable, Dict
 
 import torch
 from llama_models.sku_list import resolve_model
+
 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_2 import lora_llama3_2_3b
diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py
index 67fcba0ef..fc4b93c40 100644
--- a/llama_stack/providers/registry/post_training.py
+++ b/llama_stack/providers/registry/post_training.py
@@ -9,22 +9,14 @@ from typing import List
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
 
-META_REFERENCE_DEPS = [
-    "torch",
-    "torchtune",
-    "torchao",
-    "numpy",
-]
-
-
 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.post_training,
-            provider_type="inline::meta-reference",
-            pip_packages=META_REFERENCE_DEPS,
-            module="llama_stack.providers.inline.post_training.meta_reference",
-            config_class="llama_stack.providers.inline.post_training.meta_reference.MetaReferencePostTrainingConfig",
+            provider_type="inline::torchtune",
+            pip_packages=["torch", "torchtune", "torchao", "numpy"],
+            module="llama_stack.providers.inline.post_training.torchtune",
+            config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
             api_dependencies=[
                 Api.datasetio,
             ],
diff --git a/llama_stack/templates/experimental-post-training/build.yaml b/llama_stack/templates/experimental-post-training/build.yaml
new file mode 100644
index 000000000..32afdbc0f
--- /dev/null
+++ b/llama_stack/templates/experimental-post-training/build.yaml
@@ -0,0 +1,23 @@
+version: '2'
+name: experimental-post-training
+distribution_spec:
+  description: Experimental template for post training
+  docker_image: null
+  providers:
+    post_training:
+    - inline::torchtune
+    datasetio:
+    - remote::huggingface
+    inference:
+    - inline::meta-reference
+    memory:
+    - inline::faiss
+    - remote::chromadb
+    - remote::pgvector
+    safety:
+    - inline::llama-guard
+    agents:
+    - inline::meta-reference
+    telemetry:
+    - inline::meta-reference
+image_type: conda
diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml
new file mode 100644
index 000000000..e50280401
--- /dev/null
+++ b/llama_stack/templates/experimental-post-training/run.yaml
@@ -0,0 +1,86 @@
+version: '2'
+image_name: experimental-post-training
+docker_image: null
+conda_env: experimental-post-training
+apis:
+- agents
+- inference
+- memory
+- safety
+- telemetry
+- datasetio
+- post_training
+providers:
+  inference:
+  - provider_id: meta-reference-inference
+    provider_type: inline::meta-reference
+    config:
+      model: ${env.INFERENCE_MODEL}
+      max_seq_len: 4096
+      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+  datasetio:
+  - provider_id: huggingface-0
+    provider_type: remote::huggingface
+    config: {}
+  memory:
+  - provider_id: faiss
+    provider_type: inline::faiss
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config: {}
+  agents:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      persistence_store:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config: {}
+  post_training:
+  - provider_id: meta-reference-post-training
+    provider_type: inline::torchtune
+    config:
+      model: ${env.POST_TRAINING_MODEL}
+      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+
+metadata_store:
+  namespace: null
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
+models:
+- metadata: {}
+  model_id: ${env.INFERENCE_MODEL}
+  provider_id: meta-reference-inference
+  provider_model_id: null
+shields: []
+memory_banks: []
+datasets:
+  - dataset_id: alpaca
+    provider_id: huggingface-0
+    url:
+      uri: https://huggingface.co/datasets/tatsu-lab/alpaca
+    metadata:
+      path: tatsu-lab/alpaca
+      name:
+      split: train
+    dataset_schema:
+      instruction:
+        type: string
+      input:
+        type: string
+      output:
+        type: string
+      text:
+        type: string
+scoring_fns: []
+eval_tasks: []
diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml
index 469d3b71b..ef075d098 100644
--- a/llama_stack/templates/meta-reference-gpu/build.yaml
+++ b/llama_stack/templates/meta-reference-gpu/build.yaml
@@ -4,10 +4,6 @@ distribution_spec:
   description: Use Meta Reference for running LLM inference
   docker_image: null
   providers:
-    post_training:
-    - inline::meta-reference
-    datasetio:
-    - remote::huggingface
     inference:
     - inline::meta-reference
     memory:
diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml
index f19aa180e..85f5f21f6 100644
--- a/llama_stack/templates/meta-reference-gpu/run.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run.yaml
@@ -8,8 +8,6 @@ apis:
 - memory
 - safety
 - telemetry
-- datasetio
-- post_training
 providers:
   inference:
   - provider_id: meta-reference-inference
@@ -18,10 +16,6 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
-  datasetio:
-  - provider_id: huggingface-0
-    provider_type: remote::huggingface
-    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -46,12 +40,6 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config: {}
-  post_training:
-  - provider_id: meta-reference-post-training
-    provider_type: inline::meta-reference
-    config:
-      model: ${env.INFERENCE_MODEL}
-      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
 
 metadata_store:
   namespace: null
@@ -64,23 +52,5 @@ models:
   provider_model_id: null
 shields: []
 memory_banks: []
-datasets:
-  - dataset_id: alpaca
-    provider_id: huggingface-0
-    url:
-      uri: https://huggingface.co/datasets/tatsu-lab/alpaca
-    metadata:
-      path: tatsu-lab/alpaca
-      name:
-      split: train
-    dataset_schema:
-      instruction:
-        type: string
-      input:
-        type: string
-      output:
-        type: string
-      text:
-        type: string
 scoring_fns: []
 eval_tasks: []
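
Note on the registry hunk above: `config_class` is a dotted string that is resolved to the actual class at runtime, so it must match the exported name exactly, including case, and the class must be importable from the named module (the `__init__.py` hunk's `from .config import TorchtunePostTrainingConfig` is what makes the package-level path resolve). A minimal sketch of that kind of dynamic lookup, for illustration only; the real llama-stack resolver is structured differently, and `resolve_config_class` is a hypothetical helper name:

```python
import importlib


def resolve_config_class(config_class: str):
    """Import 'pkg.module.ClassName' and return the class object.

    getattr is case-sensitive, so a registry entry like
    '...torchtune.torchtunePostTrainingConfig' would raise
    AttributeError, while '...torchtune.TorchtunePostTrainingConfig'
    resolves to the pydantic model defined in config.py.
    """
    # Split the dotted path into module path and class name.
    module_path, _, class_name = config_class.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# Example (assumes llama_stack and its dependencies are installed):
# cls = resolve_config_class(
#     "llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig"
# )
# cls() then constructs the config with its declared defaults.
```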