Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02)

commit 2a15a8a005 (parent 41cf2bb0a7)

    temp commit

12 changed files with 142 additions and 77 deletions

The commit renames the inline post_training provider from meta-reference to torchtune, adds an experimental-post-training distribution template, and drops post_training from the meta-reference-gpu template. The scrape lost the per-file headers of the diff, so the file path shown above each hunk below is inferred from that hunk's imports and contents.
llama_stack/apis/post_training/post_training.py

@@ -7,7 +7,7 @@
 from datetime import datetime
 from enum import Enum

-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import Any, Dict, List, Optional, Protocol

 from llama_models.schema_utils import json_schema_type, webmethod

@@ -115,21 +115,6 @@ class DPOAlignmentConfig(BaseModel):
     gamma: float


-@json_schema_type
-class PostTrainingSFTRequest(BaseModel):
-    """Request to finetune a model."""
-
-    job_uuid: str
-    model: str
-    algorithm: FinetuningAlgorithm
-    algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
-    training_config: TrainingConfig
-
-    # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
-
-
 @json_schema_type
 class PostTrainingRLHFRequest(BaseModel):
     """Request to finetune a model."""
llama_stack/providers/inline/post_training/torchtune/__init__.py

@@ -8,16 +8,16 @@ from typing import Dict

 from llama_stack.distribution.datatypes import Api, ProviderSpec

-from .config import MetaReferencePostTrainingConfig
+from .config import TorchtunePostTrainingConfig


 async def get_provider_impl(
-    config: MetaReferencePostTrainingConfig,
+    config: TorchtunePostTrainingConfig,
     deps: Dict[Api, ProviderSpec],
 ):
-    from .post_training import MetaReferencePostTrainingImpl
+    from .post_training import TorchtunePostTrainingImpl

-    impl = MetaReferencePostTrainingImpl(
+    impl = TorchtunePostTrainingImpl(
         config,
         deps[Api.datasetio],
     )
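For orientation, a sketch of calling this entry point directly; in a running stack the resolver supplies deps from the run config, and the stand-in deps mapping here is purely illustrative:

import asyncio

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.inline.post_training.torchtune import get_provider_impl
from llama_stack.providers.inline.post_training.torchtune.config import (
    TorchtunePostTrainingConfig,
)


async def main() -> None:
    config = TorchtunePostTrainingConfig()
    # The real stack passes the resolved datasetio provider here; an opaque
    # placeholder stands in for it in this sketch.
    impl = await get_provider_impl(config, {Api.datasetio: object()})
    print(type(impl).__name__)  # TorchtunePostTrainingImpl


asyncio.run(main())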
llama_stack/providers/inline/post_training/torchtune/config.py

@@ -9,7 +9,7 @@ from typing import Optional
 from pydantic import BaseModel, Field


-class MetaReferencePostTrainingConfig(BaseModel):
+class TorchtunePostTrainingConfig(BaseModel):
     model: str = Field(
         default="Llama3.2-3B-Instruct",
         description="Model descriptor from `llama model list`",
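Since the renamed config is a plain pydantic model, a caller can rely on the default or override it at construction time. A minimal sketch, assuming model is the only field without a default outside this excerpt (the class may define more fields below the hunk):

from llama_stack.providers.inline.post_training.torchtune.config import (
    TorchtunePostTrainingConfig,
)

cfg = TorchtunePostTrainingConfig()
print(cfg.model)  # Llama3.2-3B-Instruct, the Field default from the diff

# Overriding with another descriptor from `llama model list`; the value here
# is illustrative, not taken from the diff.
custom = TorchtunePostTrainingConfig(model="Llama3.2-1B-Instruct")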
llama_stack/providers/inline/post_training/torchtune/post_training.py

@@ -4,18 +4,30 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.providers.inline.post_training.meta_reference.config import (
-    MetaReferencePostTrainingConfig,
+from llama_stack.providers.inline.post_training.torchtune.config import (
+    TorchtunePostTrainingConfig,
 )
 from llama_stack.apis.post_training import * # noqa
-from llama_stack.providers.inline.post_training.meta_reference.recipes.lora_finetuning_single_device import (
+from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
     LoraFinetuningSingleDevice,
 )


-class MetaReferencePostTrainingImpl:
+class PostTrainingSFTRequest(BaseModel):
+    job_uuid: str
+    model: str
+    algorithm: FinetuningAlgorithm
+    algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
+    training_config: TrainingConfig
+
+    # TODO: define these
+    hyperparam_search_config: Dict[str, Any]
+    logger_config: Dict[str, Any]
+
+
+class TorchtunePostTrainingImpl:
     def __init__(
-        self, config: MetaReferencePostTrainingConfig, datasetio_api: DatasetIO
+        self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
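Moving PostTrainingSFTRequest out of llama_stack.apis.post_training and into the provider makes the SFT request shape an implementation detail of the torchtune provider rather than public API surface (the moved class body leans on the wildcard apis import above for BaseModel and the algorithm types). A quick sketch to confirm the field set, assuming only that the moved class is importable as defined above:

from llama_stack.providers.inline.post_training.torchtune.post_training import (
    PostTrainingSFTRequest,
)

# Fields from the diff: job_uuid, model, algorithm, algorithm_config,
# training_config, hyperparam_search_config, logger_config.
print(sorted(PostTrainingSFTRequest.__fields__))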
llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py

@@ -18,15 +18,15 @@ from torch import nn
 from torchtune import utils as torchtune_utils
 from torchtune.training.metric_logging import DiskLogger
-from llama_stack.apis.post_training import * # noqa
+from llama_stack.apis.post_training import PostTrainingSFTRequest
 from llama_stack.distribution.utils.model_utils import model_local_dir

-from llama_stack.providers.inline.post_training.meta_reference import utils
-from llama_stack.providers.inline.post_training.meta_reference.config import (
+from llama_stack.providers.inline.post_training.torchtune import utils
+from llama_stack.providers.inline.post_training.torchtune.config import (
     MetaReferencePostTrainingConfig,
 )
-from llama_stack.providers.inline.post_training.meta_reference.datasets.sft import (
-    SFTDataset,
-)
+from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
+from llama_stack.providers.inline.post_training.torchtune.post_training import (
+    PostTrainingSFTRequest,
+)
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
llama_stack/providers/inline/post_training/torchtune/utils.py

@@ -14,6 +14,7 @@ from typing import Any, Callable, Dict

 import torch
 from llama_models.sku_list import resolve_model

 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
+from torchtune.models.llama3_2 import lora_llama3_2_3b
llama_stack/providers/registry/post_training.py

@@ -9,22 +9,14 @@ from typing import List
 from llama_stack.distribution.datatypes import * # noqa: F403


-META_REFERENCE_DEPS = [
-    "torch",
-    "torchtune",
-    "torchao",
-    "numpy",
-]
-
-
 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.post_training,
-            provider_type="inline::meta-reference",
-            pip_packages=META_REFERENCE_DEPS,
-            module="llama_stack.providers.inline.post_training.meta_reference",
-            config_class="llama_stack.providers.inline.post_training.meta_reference.MetaReferencePostTrainingConfig",
+            provider_type="inline::torchtune",
+            pip_packages=["torch", "torchtune", "torchao", "numpy"],
+            module="llama_stack.providers.inline.post_training.torchtune",
+            config_class="llama_stack.providers.inline.post_training.torchtune.torchtunePostTrainingConfig",
             api_dependencies=[
                 Api.datasetio,
             ],
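Two details worth flagging: the dependency list is inlined into pip_packages, and the new config_class string ends in torchtunePostTrainingConfig with a lower-case t, which does not match the TorchtunePostTrainingConfig class defined in config.py (plausible for a temp commit). A small sketch of reading the registry entry; the registry module path and the attribute access on the spec are assumptions consistent with the kwargs above:

from llama_stack.providers.registry.post_training import available_providers

for spec in available_providers():
    # Each field mirrors a kwarg passed to InlineProviderSpec in the diff.
    print(spec.provider_type)  # inline::torchtune
    print(spec.pip_packages)   # ['torch', 'torchtune', 'torchao', 'numpy']
    print(spec.module)         # llama_stack.providers.inline.post_training.torchtune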
llama_stack/templates/experimental-post-training/build.yaml (new file, 23 lines)

@@ -0,0 +1,23 @@
+version: '2'
+name: experimental-post-training
+distribution_spec:
+  description: Experimental template for post training
+  docker_image: null
+  providers:
+    post_training:
+    - inline::torchtune
+    datasetio:
+    - remote::huggingface
+    inference:
+    - inline::meta-reference
+    memory:
+    - inline::faiss
+    - remote::chromadb
+    - remote::pgvector
+    safety:
+    - inline::llama-guard
+    agents:
+    - inline::meta-reference
+    telemetry:
+    - inline::meta-reference
+image_type: conda
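With the template in place, a distribution can be assembled from it, for example with `llama stack build --template experimental-post-training --image-type conda` (the flag names are an assumption about the CLI of this era, not part of the diff); each entry under providers is resolved through the registry shown above.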
llama_stack/templates/experimental-post-training/run.yaml (new file, 86 lines)

@@ -0,0 +1,86 @@
+version: '2'
+image_name: experimental-post-training
+docker_image: null
+conda_env: experimental-post-training
+apis:
+- agents
+- inference
+- memory
+- safety
+- telemetry
+- datasetio
+- post_training
+providers:
+  inference:
+  - provider_id: meta-reference-inference
+    provider_type: inline::meta-reference
+    config:
+      model: ${env.INFERENCE_MODEL}
+      max_seq_len: 4096
+      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+  datasetio:
+  - provider_id: huggingface-0
+    provider_type: remote::huggingface
+    config: {}
+  memory:
+  - provider_id: faiss
+    provider_type: inline::faiss
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config: {}
+  agents:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      persistence_store:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config: {}
+  post_training:
+  - provider_id: meta-reference-post-training
+    provider_type: inline::torchtune
+    config:
+      model: ${env.POST_TRAINING_MODEL}
+      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+
+metadata_store:
+  namespace: null
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
+models:
+- metadata: {}
+  model_id: ${env.INFERENCE_MODEL}
+  provider_id: meta-reference-inference
+  provider_model_id: null
+shields: []
+memory_banks: []
+datasets:
+- dataset_id: alpaca
+  provider_id: huggingface-0
+  url:
+    uri: https://huggingface.co/datasets/tatsu-lab/alpaca
+  metadata:
+    path: tatsu-lab/alpaca
+    name:
+    split: train
+  dataset_schema:
+    instruction:
+      type: string
+    input:
+      type: string
+    output:
+      type: string
+    text:
+      type: string
+scoring_fns: []
+eval_tasks: []
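The run config leans on ${env.VAR} and ${env.VAR:default} references (INFERENCE_MODEL, POST_TRAINING_MODEL, INFERENCE_CHECKPOINT_DIR, SQLITE_STORE_DIR). A self-contained sketch of how such references resolve; this mimics the syntax and is not llama-stack's implementation:

import os
import re

# Hedged stand-in for the stack's env substitution: ${env.NAME} takes the
# variable's value; ${env.NAME:default} falls back to the default after ':'.
_ENV_REF = re.compile(r"\$\{env\.([A-Za-z0-9_]+)(?::([^}]*))?\}")


def resolve_env_refs(text: str) -> str:
    def replace(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        value = os.environ.get(name)
        if value is not None:
            return value
        return default if default is not None else ""

    return _ENV_REF.sub(replace, text)


os.environ["POST_TRAINING_MODEL"] = "Llama3.2-3B-Instruct"
print(resolve_env_refs("model: ${env.POST_TRAINING_MODEL}"))
# -> model: Llama3.2-3B-Instruct
print(resolve_env_refs("checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}"))
# -> checkpoint_dir: null (default used when the variable is unset)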
llama_stack/templates/meta-reference-gpu/build.yaml

@@ -4,10 +4,6 @@ distribution_spec:
   description: Use Meta Reference for running LLM inference
   docker_image: null
   providers:
-    post_training:
-    - inline::meta-reference
-    datasetio:
-    - remote::huggingface
     inference:
     - inline::meta-reference
     memory:
llama_stack/templates/meta-reference-gpu/run.yaml

@@ -8,8 +8,6 @@ apis:
 - memory
 - safety
 - telemetry
-- datasetio
-- post_training
 providers:
   inference:
   - provider_id: meta-reference-inference

@@ -18,10 +16,6 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
-  datasetio:
-  - provider_id: huggingface-0
-    provider_type: remote::huggingface
-    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss

@@ -46,12 +40,6 @@ providers:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config: {}
-  post_training:
-  - provider_id: meta-reference-post-training
-    provider_type: inline::meta-reference
-    config:
-      model: ${env.INFERENCE_MODEL}
-      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}

 metadata_store:
   namespace: null

@@ -64,23 +52,5 @@ models:
   provider_model_id: null
 shields: []
 memory_banks: []
-datasets:
-- dataset_id: alpaca
-  provider_id: huggingface-0
-  url:
-    uri: https://huggingface.co/datasets/tatsu-lab/alpaca
-  metadata:
-    path: tatsu-lab/alpaca
-    name:
-    split: train
-  dataset_schema:
-    instruction:
-      type: string
-    input:
-      type: string
-    output:
-      type: string
-    text:
-      type: string
 scoring_fns: []
 eval_tasks: []