temp commit

Botao Chen 2024-12-04 13:59:40 -08:00
parent 41cf2bb0a7
commit 2a15a8a005
12 changed files with 142 additions and 77 deletions

View file

@@ -7,7 +7,7 @@
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import Any, Dict, List, Optional, Protocol

 from llama_models.schema_utils import json_schema_type, webmethod
@@ -115,21 +115,6 @@ class DPOAlignmentConfig(BaseModel):
     gamma: float

-@json_schema_type
-class PostTrainingSFTRequest(BaseModel):
-    """Request to finetune a model."""
-
-    job_uuid: str
-    model: str
-    algorithm: FinetuningAlgorithm
-    algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
-    training_config: TrainingConfig
-
-    # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]

 @json_schema_type
 class PostTrainingRLHFRequest(BaseModel):
     """Request to finetune a model."""

View file

@@ -8,16 +8,16 @@ from typing import Dict
 from llama_stack.distribution.datatypes import Api, ProviderSpec

-from .config import MetaReferencePostTrainingConfig
+from .config import TorchtunePostTrainingConfig


 async def get_provider_impl(
-    config: MetaReferencePostTrainingConfig,
+    config: TorchtunePostTrainingConfig,
     deps: Dict[Api, ProviderSpec],
 ):
-    from .post_training import MetaReferencePostTrainingImpl
+    from .post_training import TorchtunePostTrainingImpl

-    impl = MetaReferencePostTrainingImpl(
+    impl = TorchtunePostTrainingImpl(
         config,
         deps[Api.datasetio],
     )
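A hedged sketch of exercising the renamed entry point directly, assuming the hunk above is the provider's __init__.py and that get_provider_impl returns the impl; the datasetio stand-in is hypothetical:

import asyncio

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.inline.post_training.torchtune import get_provider_impl
from llama_stack.providers.inline.post_training.torchtune.config import (
    TorchtunePostTrainingConfig,
)


class _StubDatasetIO:
    """Hypothetical stand-in for a real datasetio implementation."""


async def main() -> None:
    impl = await get_provider_impl(
        TorchtunePostTrainingConfig(),      # defaults; see the config hunk below
        {Api.datasetio: _StubDatasetIO()},  # hypothetical dependency wiring
    )
    print(type(impl).__name__)  # TorchtunePostTrainingImpl


asyncio.run(main())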

View file

@@ -9,7 +9,7 @@ from typing import Optional
 from pydantic import BaseModel, Field


-class MetaReferencePostTrainingConfig(BaseModel):
+class TorchtunePostTrainingConfig(BaseModel):
     model: str = Field(
         default="Llama3.2-3B-Instruct",
         description="Model descriptor from `llama model list`",
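The renamed config is a plain pydantic model, so the rename can be sanity-checked in isolation; a minimal sketch, assuming the fields not shown in this hunk keep their defaults:

from llama_stack.providers.inline.post_training.torchtune.config import (
    TorchtunePostTrainingConfig,
)

config = TorchtunePostTrainingConfig()  # falls back to Llama3.2-3B-Instruct
custom = TorchtunePostTrainingConfig(model="Llama3.2-1B-Instruct")  # illustrative override
print(config.model, custom.model)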

View file

@ -4,18 +4,30 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.providers.inline.post_training.meta_reference.config import (
MetaReferencePostTrainingConfig,
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.apis.post_training import * # noqa
from llama_stack.providers.inline.post_training.meta_reference.recipes.lora_finetuning_single_device import (
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
class MetaReferencePostTrainingImpl:
class PostTrainingSFTRequest(BaseModel):
job_uuid: str
model: str
algorithm: FinetuningAlgorithm
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
class TorchtunePostTrainingImpl:
def __init__(
self, config: MetaReferencePostTrainingConfig, datasetio_api: DatasetIO
self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
) -> None:
self.config = config
self.datasetio_api = datasetio_api
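With PostTrainingSFTRequest now owned by the provider rather than the public API, its schema can still be inspected from the new location; a minimal sketch (assumes pydantic v2; use .schema() on v1):

from llama_stack.providers.inline.post_training.torchtune.post_training import (
    PostTrainingSFTRequest,
)

# Prints the JSON schema derived from the fields in the hunk above.
print(PostTrainingSFTRequest.model_json_schema())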

View file

@@ -18,15 +18,15 @@ from torch import nn
 from torchtune import utils as torchtune_utils
 from torchtune.training.metric_logging import DiskLogger

 from llama_stack.apis.post_training import *  # noqa
-from llama_stack.apis.post_training import PostTrainingSFTRequest
 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.providers.inline.post_training.meta_reference import utils
-from llama_stack.providers.inline.post_training.meta_reference.config import (
-    MetaReferencePostTrainingConfig,
-)
-from llama_stack.providers.inline.post_training.meta_reference.datasets.sft import (
-    SFTDataset,
-)
+from llama_stack.providers.inline.post_training.torchtune import utils
+from llama_stack.providers.inline.post_training.torchtune.config import (
+    TorchtunePostTrainingConfig,
+)
+from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
+from llama_stack.providers.inline.post_training.torchtune.post_training import (
+    PostTrainingSFTRequest,
+)
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
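Among the imports this recipe keeps, torchtune's DiskLogger is the on-disk metric sink the training loop presumably writes to; an illustrative usage, with a hypothetical log directory:

from torchtune.training.metric_logging import DiskLogger

logger = DiskLogger(log_dir="/tmp/sft-logs")  # hypothetical path
logger.log("loss", 1.23, step=0)                     # single scalar
logger.log_dict({"loss": 1.17, "lr": 3e-4}, step=1)  # batch of scalars
logger.close()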

View file

@@ -14,6 +14,7 @@ from typing import Any, Callable, Dict
 import torch
 from llama_models.sku_list import resolve_model
 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
+from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_2 import lora_llama3_2_3b
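The newly imported Llama3Tokenizer is the concrete type that llama3_tokenizer returns, which makes annotations possible in the code below this hunk; a hedged sketch, with a hypothetical checkpoint path:

from pathlib import Path

from torchtune.models.llama3 import llama3_tokenizer
from torchtune.models.llama3._tokenizer import Llama3Tokenizer

tok_path = Path("~/.llama/checkpoints/Llama3.2-3B-Instruct/tokenizer.model").expanduser()  # hypothetical
tokenizer: Llama3Tokenizer = llama3_tokenizer(path=str(tok_path))
ids = tokenizer.encode("hello world", add_bos=True, add_eos=True)
print(tokenizer.decode(ids))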

View file

@@ -9,22 +9,14 @@ from typing import List
 from llama_stack.distribution.datatypes import *  # noqa: F403

-META_REFERENCE_DEPS = [
-    "torch",
-    "torchtune",
-    "torchao",
-    "numpy",
-]
-

 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.post_training,
-            provider_type="inline::meta-reference",
-            pip_packages=META_REFERENCE_DEPS,
-            module="llama_stack.providers.inline.post_training.meta_reference",
-            config_class="llama_stack.providers.inline.post_training.meta_reference.MetaReferencePostTrainingConfig",
+            provider_type="inline::torchtune",
+            pip_packages=["torch", "torchtune", "torchao", "numpy"],
+            module="llama_stack.providers.inline.post_training.torchtune",
+            config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
             api_dependencies=[
                 Api.datasetio,
             ],
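A quick way to confirm the registry now advertises inline::torchtune; a minimal sketch, assuming this hunk is llama_stack/providers/registry/post_training.py:

from llama_stack.providers.registry.post_training import available_providers

for spec in available_providers():
    print(spec.provider_type, spec.module)
# expected to include: inline::torchtune llama_stack.providers.inline.post_training.torchtune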

View file

@@ -0,0 +1,23 @@
+version: '2'
+name: experimental-post-training
+distribution_spec:
+  description: Experimental template for post training
+  docker_image: null
+  providers:
+    post_training:
+    - inline::torchtune
+    datasetio:
+    - remote::huggingface
+    inference:
+    - inline::meta-reference
+    memory:
+    - inline::faiss
+    - remote::chromadb
+    - remote::pgvector
+    safety:
+    - inline::llama-guard
+    agents:
+    - inline::meta-reference
+    telemetry:
+    - inline::meta-reference
+image_type: conda
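An illustrative check of the new template (assumes PyYAML and a guessed file path based on the template name):

import yaml

# Path is an assumption; adjust to wherever the template lives in the repo.
with open("llama_stack/templates/experimental-post-training/build.yaml") as f:
    spec = yaml.safe_load(f)

for api, providers in spec["distribution_spec"]["providers"].items():
    print(f"{api}: {providers}")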

View file

@@ -0,0 +1,86 @@
+version: '2'
+image_name: experimental-post-training
+docker_image: null
+conda_env: experimental-post-training
+apis:
+- agents
+- inference
+- memory
+- safety
+- telemetry
+- datasetio
+- post_training
+providers:
+  inference:
+  - provider_id: meta-reference-inference
+    provider_type: inline::meta-reference
+    config:
+      model: ${env.INFERENCE_MODEL}
+      max_seq_len: 4096
+      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+  datasetio:
+  - provider_id: huggingface-0
+    provider_type: remote::huggingface
+    config: {}
+  memory:
+  - provider_id: faiss
+    provider_type: inline::faiss
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config: {}
+  agents:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      persistence_store:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config: {}
+  post_training:
+  - provider_id: meta-reference-post-training
+    provider_type: inline::torchtune
+    config:
+      model: ${env.POST_TRAINING_MODEL}
+      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+metadata_store:
+  namespace: null
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
+models:
+- metadata: {}
+  model_id: ${env.INFERENCE_MODEL}
+  provider_id: meta-reference-inference
+  provider_model_id: null
+shields: []
+memory_banks: []
+datasets:
+- dataset_id: alpaca
+  provider_id: huggingface-0
+  url:
+    uri: https://huggingface.co/datasets/tatsu-lab/alpaca
+  metadata:
+    path: tatsu-lab/alpaca
+    name:
+    split: train
+  dataset_schema:
+    instruction:
+      type: string
+    input:
+      type: string
+    output:
+      type: string
+    text:
+      type: string
+scoring_fns: []
+eval_tasks: []
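The ${env.VAR} and ${env.VAR:default} placeholders in this run config are substituted from the environment at startup; the toy resolver below only sketches that convention and is not llama-stack's implementation:

import os
import re

_ENV_REF = re.compile(r"\$\{env\.([A-Za-z0-9_]+)(?::([^}]*))?\}")


def resolve_env_refs(value: str) -> str:
    # Replace each ${env.NAME} or ${env.NAME:default} with the variable's
    # value, the default, or an empty string, in that order of preference.
    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        return os.environ.get(name, default or "")

    return _ENV_REF.sub(_sub, value)


print(resolve_env_refs("${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db"))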

View file

@@ -4,10 +4,6 @@ distribution_spec:
   description: Use Meta Reference for running LLM inference
   docker_image: null
   providers:
-    post_training:
-    - inline::meta-reference
-    datasetio:
-    - remote::huggingface
     inference:
     - inline::meta-reference
     memory:

View file

@@ -8,8 +8,6 @@ apis:
 - memory
 - safety
 - telemetry
-- datasetio
-- post_training
 providers:
   inference:
   - provider_id: meta-reference-inference
@@ -18,10 +16,6 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
-  datasetio:
-  - provider_id: huggingface-0
-    provider_type: remote::huggingface
-    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -46,12 +40,6 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config: {}
-  post_training:
-  - provider_id: meta-reference-post-training
-    provider_type: inline::meta-reference
-    config:
-      model: ${env.INFERENCE_MODEL}
-      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
 metadata_store:
   namespace: null
@@ -64,23 +52,5 @@ models:
   provider_model_id: null
 shields: []
 memory_banks: []
-datasets:
-- dataset_id: alpaca
-  provider_id: huggingface-0
-  url:
-    uri: https://huggingface.co/datasets/tatsu-lab/alpaca
-  metadata:
-    path: tatsu-lab/alpaca
-    name:
-    split: train
-  dataset_schema:
-    instruction:
-      type: string
-    input:
-      type: string
-    output:
-      type: string
-    text:
-      type: string
 scoring_fns: []
 eval_tasks: []