mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 18:52:38 +00:00
temp commit
This commit is contained in:
parent
41cf2bb0a7
commit
2a15a8a005
12 changed files with 142 additions and 77 deletions
|
|
@ -8,16 +8,16 @@ from typing import Dict
|
|||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MetaReferencePostTrainingConfig
|
||||
from .config import TorchtunePostTrainingConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferencePostTrainingConfig,
|
||||
config: TorchtunePostTrainingConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
):
|
||||
from .post_training import MetaReferencePostTrainingImpl
|
||||
from .post_training import TorchtunePostTrainingImpl
|
||||
|
||||
impl = MetaReferencePostTrainingImpl(
|
||||
impl = TorchtunePostTrainingImpl(
|
||||
config,
|
||||
deps[Api.datasetio],
|
||||
)
|
||||
|
|
@ -9,7 +9,7 @@ from typing import Optional
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MetaReferencePostTrainingConfig(BaseModel):
|
||||
class TorchtunePostTrainingConfig(BaseModel):
|
||||
model: str = Field(
|
||||
default="Llama3.2-3B-Instruct",
|
||||
description="Model descriptor from `llama model list`",
|
||||
|
|
@ -4,18 +4,30 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.providers.inline.post_training.meta_reference.config import (
|
||||
MetaReferencePostTrainingConfig,
|
||||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.apis.post_training import * # noqa
|
||||
from llama_stack.providers.inline.post_training.meta_reference.recipes.lora_finetuning_single_device import (
|
||||
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
|
||||
LoraFinetuningSingleDevice,
|
||||
)
|
||||
|
||||
|
||||
class MetaReferencePostTrainingImpl:
|
||||
class PostTrainingSFTRequest(BaseModel):
|
||||
job_uuid: str
|
||||
model: str
|
||||
algorithm: FinetuningAlgorithm
|
||||
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
|
||||
training_config: TrainingConfig
|
||||
|
||||
# TODO: define these
|
||||
hyperparam_search_config: Dict[str, Any]
|
||||
logger_config: Dict[str, Any]
|
||||
|
||||
|
||||
class TorchtunePostTrainingImpl:
|
||||
def __init__(
|
||||
self, config: MetaReferencePostTrainingConfig, datasetio_api: DatasetIO
|
||||
self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
|
|
@ -18,15 +18,15 @@ from torch import nn
|
|||
from torchtune import utils as torchtune_utils
|
||||
from torchtune.training.metric_logging import DiskLogger
|
||||
from llama_stack.apis.post_training import * # noqa
|
||||
from llama_stack.apis.post_training import PostTrainingSFTRequest
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
from llama_stack.providers.inline.post_training.meta_reference import utils
|
||||
from llama_stack.providers.inline.post_training.meta_reference.config import (
|
||||
from llama_stack.providers.inline.post_training.torchtune import utils
|
||||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
MetaReferencePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.meta_reference.datasets.sft import (
|
||||
SFTDataset,
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
||||
from llama_stack.providers.inline.post_training.torchtune.post_training import (
|
||||
PostTrainingSFTRequest,
|
||||
)
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
|
|
@ -14,6 +14,7 @@ from typing import Any, Callable, Dict
|
|||
|
||||
import torch
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
||||
|
|
@ -9,22 +9,14 @@ from typing import List
|
|||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
META_REFERENCE_DEPS = [
|
||||
"torch",
|
||||
"torchtune",
|
||||
"torchao",
|
||||
"numpy",
|
||||
]
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.post_training,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=META_REFERENCE_DEPS,
|
||||
module="llama_stack.providers.inline.post_training.meta_reference",
|
||||
config_class="llama_stack.providers.inline.post_training.meta_reference.MetaReferencePostTrainingConfig",
|
||||
provider_type="inline::torchtune",
|
||||
pip_packages=["torch", "torchtune", "torchao", "numpy"],
|
||||
module="llama_stack.providers.inline.post_training.torchtune",
|
||||
config_class="llama_stack.providers.inline.post_training.torchtune.torchtunePostTrainingConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue