temp commit

This commit is contained in:
Botao Chen 2024-12-04 13:59:40 -08:00
parent 41cf2bb0a7
commit 2a15a8a005
12 changed files with 142 additions and 77 deletions

View file

@ -7,7 +7,7 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union from typing import Any, Dict, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
@ -115,21 +115,6 @@ class DPOAlignmentConfig(BaseModel):
gamma: float gamma: float
@json_schema_type
class PostTrainingSFTRequest(BaseModel):
"""Request to finetune a model."""
job_uuid: str
model: str
algorithm: FinetuningAlgorithm
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
@json_schema_type @json_schema_type
class PostTrainingRLHFRequest(BaseModel): class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model.""" """Request to finetune a model."""

View file

@ -8,16 +8,16 @@ from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferencePostTrainingConfig from .config import TorchtunePostTrainingConfig
async def get_provider_impl( async def get_provider_impl(
config: MetaReferencePostTrainingConfig, config: TorchtunePostTrainingConfig,
deps: Dict[Api, ProviderSpec], deps: Dict[Api, ProviderSpec],
): ):
from .post_training import MetaReferencePostTrainingImpl from .post_training import TorchtunePostTrainingImpl
impl = MetaReferencePostTrainingImpl( impl = TorchtunePostTrainingImpl(
config, config,
deps[Api.datasetio], deps[Api.datasetio],
) )

View file

@ -9,7 +9,7 @@ from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class MetaReferencePostTrainingConfig(BaseModel): class TorchtunePostTrainingConfig(BaseModel):
model: str = Field( model: str = Field(
default="Llama3.2-3B-Instruct", default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`", description="Model descriptor from `llama model list`",

View file

@ -4,18 +4,30 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.providers.inline.post_training.meta_reference.config import ( from llama_stack.providers.inline.post_training.torchtune.config import (
MetaReferencePostTrainingConfig, TorchtunePostTrainingConfig,
) )
from llama_stack.apis.post_training import * # noqa from llama_stack.apis.post_training import * # noqa
from llama_stack.providers.inline.post_training.meta_reference.recipes.lora_finetuning_single_device import ( from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice, LoraFinetuningSingleDevice,
) )
class MetaReferencePostTrainingImpl: class PostTrainingSFTRequest(BaseModel):
job_uuid: str
model: str
algorithm: FinetuningAlgorithm
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
class TorchtunePostTrainingImpl:
def __init__( def __init__(
self, config: MetaReferencePostTrainingConfig, datasetio_api: DatasetIO self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
) -> None: ) -> None:
self.config = config self.config = config
self.datasetio_api = datasetio_api self.datasetio_api = datasetio_api

View file

@ -18,15 +18,15 @@ from torch import nn
from torchtune import utils as torchtune_utils from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger from torchtune.training.metric_logging import DiskLogger
from llama_stack.apis.post_training import * # noqa from llama_stack.apis.post_training import * # noqa
from llama_stack.apis.post_training import PostTrainingSFTRequest
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.meta_reference import utils from llama_stack.providers.inline.post_training.torchtune import utils
from llama_stack.providers.inline.post_training.meta_reference.config import ( from llama_stack.providers.inline.post_training.torchtune.config import (
MetaReferencePostTrainingConfig, MetaReferencePostTrainingConfig,
) )
from llama_stack.providers.inline.post_training.meta_reference.datasets.sft import ( from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
SFTDataset, from llama_stack.providers.inline.post_training.torchtune.post_training import (
PostTrainingSFTRequest,
) )
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DataLoader, DistributedSampler

View file

@ -14,6 +14,7 @@ from typing import Any, Callable, Dict
import torch import torch
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
from torchtune.models.llama3._tokenizer import Llama3Tokenizer from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_2 import lora_llama3_2_3b from torchtune.models.llama3_2 import lora_llama3_2_3b

View file

@ -9,22 +9,14 @@ from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
META_REFERENCE_DEPS = [
"torch",
"torchtune",
"torchao",
"numpy",
]
def available_providers() -> List[ProviderSpec]: def available_providers() -> List[ProviderSpec]:
return [ return [
InlineProviderSpec( InlineProviderSpec(
api=Api.post_training, api=Api.post_training,
provider_type="inline::meta-reference", provider_type="inline::torchtune",
pip_packages=META_REFERENCE_DEPS, pip_packages=["torch", "torchtune", "torchao", "numpy"],
module="llama_stack.providers.inline.post_training.meta_reference", module="llama_stack.providers.inline.post_training.torchtune",
config_class="llama_stack.providers.inline.post_training.meta_reference.MetaReferencePostTrainingConfig", config_class="llama_stack.providers.inline.post_training.torchtune.torchtunePostTrainingConfig",
api_dependencies=[ api_dependencies=[
Api.datasetio, Api.datasetio,
], ],

View file

@ -0,0 +1,23 @@
version: '2'
name: experimental-post-training
distribution_spec:
description: Experimental template for post training
docker_image: null
providers:
post_training:
- inline::torchtune
datasetio:
- remote::huggingface
inference:
- inline::meta-reference
memory:
- inline::faiss
- remote::chromadb
- remote::pgvector
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
image_type: conda

View file

@ -0,0 +1,86 @@
version: '2'
image_name: experimental-post-training
docker_image: null
conda_env: experimental-post-training
apis:
- agents
- inference
- memory
- safety
- telemetry
- datasetio
- post_training
providers:
inference:
- provider_id: meta-reference-inference
provider_type: inline::meta-reference
config:
model: ${env.INFERENCE_MODEL}
max_seq_len: 4096
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
datasetio:
- provider_id: huggingface-0
provider_type: remote::huggingface
config: {}
memory:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config: {}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
post_training:
- provider_id: meta-reference-post-training
provider_type: inline::torchtune
config:
model: ${env.POST_TRAINING_MODEL}
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: meta-reference-inference
provider_model_id: null
shields: []
memory_banks: []
datasets:
- dataset_id: alpaca
provider_id: huggingface-0
url:
uri: https://huggingface.co/datasets/tatsu-lab/alpaca
metadata:
path: tatsu-lab/alpaca
name:
split: train
dataset_schema:
instruction:
type: string
input:
type: string
output:
type: string
text:
type: string
scoring_fns: []
eval_tasks: []

View file

@ -4,10 +4,6 @@ distribution_spec:
description: Use Meta Reference for running LLM inference description: Use Meta Reference for running LLM inference
docker_image: null docker_image: null
providers: providers:
post_training:
- inline::meta-reference
datasetio:
- remote::huggingface
inference: inference:
- inline::meta-reference - inline::meta-reference
memory: memory:

View file

@ -8,8 +8,6 @@ apis:
- memory - memory
- safety - safety
- telemetry - telemetry
- datasetio
- post_training
providers: providers:
inference: inference:
- provider_id: meta-reference-inference - provider_id: meta-reference-inference
@ -18,10 +16,6 @@ providers:
model: ${env.INFERENCE_MODEL} model: ${env.INFERENCE_MODEL}
max_seq_len: 4096 max_seq_len: 4096
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null} checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
datasetio:
- provider_id: huggingface-0
provider_type: remote::huggingface
config: {}
memory: memory:
- provider_id: faiss - provider_id: faiss
provider_type: inline::faiss provider_type: inline::faiss
@ -46,12 +40,6 @@ providers:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
config: {} config: {}
post_training:
- provider_id: meta-reference-post-training
provider_type: inline::meta-reference
config:
model: ${env.INFERENCE_MODEL}
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
metadata_store: metadata_store:
namespace: null namespace: null
@ -64,23 +52,5 @@ models:
provider_model_id: null provider_model_id: null
shields: [] shields: []
memory_banks: [] memory_banks: []
datasets:
- dataset_id: alpaca
provider_id: huggingface-0
url:
uri: https://huggingface.co/datasets/tatsu-lab/alpaca
metadata:
path: tatsu-lab/alpaca
name:
split: train
dataset_schema:
instruction:
type: string
input:
type: string
output:
type: string
text:
type: string
scoring_fns: [] scoring_fns: []
eval_tasks: [] eval_tasks: []