mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
temp commit
This commit is contained in:
parent
41cf2bb0a7
commit
2a15a8a005
12 changed files with 142 additions and 77 deletions
|
@ -7,7 +7,7 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
from typing import Any, Dict, List, Optional, Protocol
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
@ -115,21 +115,6 @@ class DPOAlignmentConfig(BaseModel):
|
||||||
gamma: float
|
gamma: float
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class PostTrainingSFTRequest(BaseModel):
|
|
||||||
"""Request to finetune a model."""
|
|
||||||
|
|
||||||
job_uuid: str
|
|
||||||
model: str
|
|
||||||
algorithm: FinetuningAlgorithm
|
|
||||||
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
|
|
||||||
training_config: TrainingConfig
|
|
||||||
|
|
||||||
# TODO: define these
|
|
||||||
hyperparam_search_config: Dict[str, Any]
|
|
||||||
logger_config: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class PostTrainingRLHFRequest(BaseModel):
|
class PostTrainingRLHFRequest(BaseModel):
|
||||||
"""Request to finetune a model."""
|
"""Request to finetune a model."""
|
||||||
|
|
|
@ -8,16 +8,16 @@ from typing import Dict
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||||
|
|
||||||
from .config import MetaReferencePostTrainingConfig
|
from .config import TorchtunePostTrainingConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(
|
async def get_provider_impl(
|
||||||
config: MetaReferencePostTrainingConfig,
|
config: TorchtunePostTrainingConfig,
|
||||||
deps: Dict[Api, ProviderSpec],
|
deps: Dict[Api, ProviderSpec],
|
||||||
):
|
):
|
||||||
from .post_training import MetaReferencePostTrainingImpl
|
from .post_training import TorchtunePostTrainingImpl
|
||||||
|
|
||||||
impl = MetaReferencePostTrainingImpl(
|
impl = TorchtunePostTrainingImpl(
|
||||||
config,
|
config,
|
||||||
deps[Api.datasetio],
|
deps[Api.datasetio],
|
||||||
)
|
)
|
|
@ -9,7 +9,7 @@ from typing import Optional
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class MetaReferencePostTrainingConfig(BaseModel):
|
class TorchtunePostTrainingConfig(BaseModel):
|
||||||
model: str = Field(
|
model: str = Field(
|
||||||
default="Llama3.2-3B-Instruct",
|
default="Llama3.2-3B-Instruct",
|
||||||
description="Model descriptor from `llama model list`",
|
description="Model descriptor from `llama model list`",
|
|
@ -4,18 +4,30 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.providers.inline.post_training.meta_reference.config import (
|
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||||
MetaReferencePostTrainingConfig,
|
TorchtunePostTrainingConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.post_training import * # noqa
|
from llama_stack.apis.post_training import * # noqa
|
||||||
from llama_stack.providers.inline.post_training.meta_reference.recipes.lora_finetuning_single_device import (
|
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
|
||||||
LoraFinetuningSingleDevice,
|
LoraFinetuningSingleDevice,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MetaReferencePostTrainingImpl:
|
class PostTrainingSFTRequest(BaseModel):
|
||||||
|
job_uuid: str
|
||||||
|
model: str
|
||||||
|
algorithm: FinetuningAlgorithm
|
||||||
|
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
|
||||||
|
training_config: TrainingConfig
|
||||||
|
|
||||||
|
# TODO: define these
|
||||||
|
hyperparam_search_config: Dict[str, Any]
|
||||||
|
logger_config: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class TorchtunePostTrainingImpl:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: MetaReferencePostTrainingConfig, datasetio_api: DatasetIO
|
self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
|
||||||
) -> None:
|
) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.datasetio_api = datasetio_api
|
self.datasetio_api = datasetio_api
|
|
@ -18,15 +18,15 @@ from torch import nn
|
||||||
from torchtune import utils as torchtune_utils
|
from torchtune import utils as torchtune_utils
|
||||||
from torchtune.training.metric_logging import DiskLogger
|
from torchtune.training.metric_logging import DiskLogger
|
||||||
from llama_stack.apis.post_training import * # noqa
|
from llama_stack.apis.post_training import * # noqa
|
||||||
from llama_stack.apis.post_training import PostTrainingSFTRequest
|
|
||||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||||
|
|
||||||
from llama_stack.providers.inline.post_training.meta_reference import utils
|
from llama_stack.providers.inline.post_training.torchtune import utils
|
||||||
from llama_stack.providers.inline.post_training.meta_reference.config import (
|
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||||
MetaReferencePostTrainingConfig,
|
MetaReferencePostTrainingConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.inline.post_training.meta_reference.datasets.sft import (
|
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
||||||
SFTDataset,
|
from llama_stack.providers.inline.post_training.torchtune.post_training import (
|
||||||
|
PostTrainingSFTRequest,
|
||||||
)
|
)
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.data import DataLoader, DistributedSampler
|
from torch.utils.data import DataLoader, DistributedSampler
|
|
@ -14,6 +14,7 @@ from typing import Any, Callable, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
|
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
|
||||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||||
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
|
@ -9,22 +9,14 @@ from typing import List
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
META_REFERENCE_DEPS = [
|
|
||||||
"torch",
|
|
||||||
"torchtune",
|
|
||||||
"torchao",
|
|
||||||
"numpy",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def available_providers() -> List[ProviderSpec]:
|
def available_providers() -> List[ProviderSpec]:
|
||||||
return [
|
return [
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.post_training,
|
api=Api.post_training,
|
||||||
provider_type="inline::meta-reference",
|
provider_type="inline::torchtune",
|
||||||
pip_packages=META_REFERENCE_DEPS,
|
pip_packages=["torch", "torchtune", "torchao", "numpy"],
|
||||||
module="llama_stack.providers.inline.post_training.meta_reference",
|
module="llama_stack.providers.inline.post_training.torchtune",
|
||||||
config_class="llama_stack.providers.inline.post_training.meta_reference.MetaReferencePostTrainingConfig",
|
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
|
||||||
api_dependencies=[
|
api_dependencies=[
|
||||||
Api.datasetio,
|
Api.datasetio,
|
||||||
],
|
],
|
||||||
|
|
23
llama_stack/templates/experimental-post-training/build.yaml
Normal file
23
llama_stack/templates/experimental-post-training/build.yaml
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
version: '2'
|
||||||
|
name: experimental-post-training
|
||||||
|
distribution_spec:
|
||||||
|
description: Experimental template for post training
|
||||||
|
docker_image: null
|
||||||
|
providers:
|
||||||
|
post_training:
|
||||||
|
- inline::torchtune
|
||||||
|
datasetio:
|
||||||
|
- remote::huggingface
|
||||||
|
inference:
|
||||||
|
- inline::meta-reference
|
||||||
|
memory:
|
||||||
|
- inline::faiss
|
||||||
|
- remote::chromadb
|
||||||
|
- remote::pgvector
|
||||||
|
safety:
|
||||||
|
- inline::llama-guard
|
||||||
|
agents:
|
||||||
|
- inline::meta-reference
|
||||||
|
telemetry:
|
||||||
|
- inline::meta-reference
|
||||||
|
image_type: conda
|
86
llama_stack/templates/experimental-post-training/run.yaml
Normal file
86
llama_stack/templates/experimental-post-training/run.yaml
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
version: '2'
|
||||||
|
image_name: experimental-post-training
|
||||||
|
docker_image: null
|
||||||
|
conda_env: experimental-post-training
|
||||||
|
apis:
|
||||||
|
- agents
|
||||||
|
- inference
|
||||||
|
- memory
|
||||||
|
- safety
|
||||||
|
- telemetry
|
||||||
|
- datasetio
|
||||||
|
- post_training
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- provider_id: meta-reference-inference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
model: ${env.INFERENCE_MODEL}
|
||||||
|
max_seq_len: 4096
|
||||||
|
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
|
||||||
|
datasetio:
|
||||||
|
- provider_id: huggingface-0
|
||||||
|
provider_type: remote::huggingface
|
||||||
|
config: {}
|
||||||
|
memory:
|
||||||
|
- provider_id: faiss
|
||||||
|
provider_type: inline::faiss
|
||||||
|
config:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
|
||||||
|
safety:
|
||||||
|
- provider_id: llama-guard
|
||||||
|
provider_type: inline::llama-guard
|
||||||
|
config: {}
|
||||||
|
agents:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
persistence_store:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db
|
||||||
|
telemetry:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config: {}
|
||||||
|
post_training:
|
||||||
|
- provider_id: meta-reference-post-training
|
||||||
|
provider_type: inline::torchtune
|
||||||
|
config:
|
||||||
|
model: ${env.POST_TRAINING_MODEL}
|
||||||
|
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
|
||||||
|
|
||||||
|
metadata_store:
|
||||||
|
namespace: null
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
|
||||||
|
models:
|
||||||
|
- metadata: {}
|
||||||
|
model_id: ${env.INFERENCE_MODEL}
|
||||||
|
provider_id: meta-reference-inference
|
||||||
|
provider_model_id: null
|
||||||
|
shields: []
|
||||||
|
memory_banks: []
|
||||||
|
datasets:
|
||||||
|
- dataset_id: alpaca
|
||||||
|
provider_id: huggingface-0
|
||||||
|
url:
|
||||||
|
uri: https://huggingface.co/datasets/tatsu-lab/alpaca
|
||||||
|
metadata:
|
||||||
|
path: tatsu-lab/alpaca
|
||||||
|
name:
|
||||||
|
split: train
|
||||||
|
dataset_schema:
|
||||||
|
instruction:
|
||||||
|
type: string
|
||||||
|
input:
|
||||||
|
type: string
|
||||||
|
output:
|
||||||
|
type: string
|
||||||
|
text:
|
||||||
|
type: string
|
||||||
|
scoring_fns: []
|
||||||
|
eval_tasks: []
|
|
@ -4,10 +4,6 @@ distribution_spec:
|
||||||
description: Use Meta Reference for running LLM inference
|
description: Use Meta Reference for running LLM inference
|
||||||
docker_image: null
|
docker_image: null
|
||||||
providers:
|
providers:
|
||||||
post_training:
|
|
||||||
- inline::meta-reference
|
|
||||||
datasetio:
|
|
||||||
- remote::huggingface
|
|
||||||
inference:
|
inference:
|
||||||
- inline::meta-reference
|
- inline::meta-reference
|
||||||
memory:
|
memory:
|
||||||
|
|
|
@ -8,8 +8,6 @@ apis:
|
||||||
- memory
|
- memory
|
||||||
- safety
|
- safety
|
||||||
- telemetry
|
- telemetry
|
||||||
- datasetio
|
|
||||||
- post_training
|
|
||||||
providers:
|
providers:
|
||||||
inference:
|
inference:
|
||||||
- provider_id: meta-reference-inference
|
- provider_id: meta-reference-inference
|
||||||
|
@ -18,10 +16,6 @@ providers:
|
||||||
model: ${env.INFERENCE_MODEL}
|
model: ${env.INFERENCE_MODEL}
|
||||||
max_seq_len: 4096
|
max_seq_len: 4096
|
||||||
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
|
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
|
||||||
datasetio:
|
|
||||||
- provider_id: huggingface-0
|
|
||||||
provider_type: remote::huggingface
|
|
||||||
config: {}
|
|
||||||
memory:
|
memory:
|
||||||
- provider_id: faiss
|
- provider_id: faiss
|
||||||
provider_type: inline::faiss
|
provider_type: inline::faiss
|
||||||
|
@ -46,12 +40,6 @@ providers:
|
||||||
- provider_id: meta-reference
|
- provider_id: meta-reference
|
||||||
provider_type: inline::meta-reference
|
provider_type: inline::meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
post_training:
|
|
||||||
- provider_id: meta-reference-post-training
|
|
||||||
provider_type: inline::meta-reference
|
|
||||||
config:
|
|
||||||
model: ${env.INFERENCE_MODEL}
|
|
||||||
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
|
|
||||||
|
|
||||||
metadata_store:
|
metadata_store:
|
||||||
namespace: null
|
namespace: null
|
||||||
|
@ -64,23 +52,5 @@ models:
|
||||||
provider_model_id: null
|
provider_model_id: null
|
||||||
shields: []
|
shields: []
|
||||||
memory_banks: []
|
memory_banks: []
|
||||||
datasets:
|
|
||||||
- dataset_id: alpaca
|
|
||||||
provider_id: huggingface-0
|
|
||||||
url:
|
|
||||||
uri: https://huggingface.co/datasets/tatsu-lab/alpaca
|
|
||||||
metadata:
|
|
||||||
path: tatsu-lab/alpaca
|
|
||||||
name:
|
|
||||||
split: train
|
|
||||||
dataset_schema:
|
|
||||||
instruction:
|
|
||||||
type: string
|
|
||||||
input:
|
|
||||||
type: string
|
|
||||||
output:
|
|
||||||
type: string
|
|
||||||
text:
|
|
||||||
type: string
|
|
||||||
scoring_fns: []
|
scoring_fns: []
|
||||||
eval_tasks: []
|
eval_tasks: []
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue