Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)

Commit 68ebf8a8da ("address comments")
Parent: 9c1ae088f9
6 changed files with 97 additions and 29 deletions
@@ -6,12 +6,12 @@
 from datetime import datetime
 from enum import Enum

 from typing import Any, Dict, List, Optional, Protocol, Union

 from llama_models.schema_utils import json_schema_type, webmethod

 from pydantic import BaseModel, Field
-
+from typing_extensions import Annotated

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
@@ -79,6 +79,11 @@ class QATFinetuningConfig(BaseModel):
     group_size: int


+AlgorithmConfig = Annotated[
+    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
+]
+
+
 @json_schema_type
 class PostTrainingJobLogStream(BaseModel):
     """Stream of logs from a finetuning job."""
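The AlgorithmConfig alias added above is a Pydantic discriminated union: the `type` field on each config decides which class validates an incoming payload. A minimal, self-contained sketch of the pattern, assuming Pydantic v2 and using simplified, illustrative fields and literal values rather than the exact llama_stack definitions:

from typing import List, Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class LoraFinetuningConfig(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    lora_attn_modules: List[str] = ["q_proj", "v_proj"]
    rank: int = 8
    alpha: int = 16


class QATFinetuningConfig(BaseModel):
    type: Literal["QAT"] = "QAT"
    quantizer_name: str = "int8"
    group_size: int = 256


AlgorithmConfig = Annotated[
    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
]


class TrainingRequest(BaseModel):
    algorithm_config: AlgorithmConfig


# The "type" value picks the concrete config class during validation (Pydantic v2 API).
req = TrainingRequest.model_validate(
    {"algorithm_config": {"type": "QAT", "quantizer_name": "int4", "group_size": 32}}
)
assert isinstance(req.algorithm_config, QATFinetuningConfig)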
@@ -173,9 +178,7 @@ class PostTraining(Protocol):
             description="Model descriptor from `llama model list`",
         ),
         checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[
-            Union[LoraFinetuningConfig, QATFinetuningConfig]
-        ] = None,
+        algorithm_config: Optional[AlgorithmConfig] = None,
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize")
@@ -22,5 +22,6 @@ async def get_provider_impl(
     impl = TorchtunePostTrainingImpl(
         config,
         deps[Api.datasetio],
+        deps[Api.datasets],
     )
     return impl
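For context, the provider is handed its API dependencies as a mapping keyed by `Api`. A self-contained toy illustration of that wiring, with stand-in classes (synchronous, and not llama_stack's real resolver):

from enum import Enum
from typing import Any, Dict


class Api(Enum):
    datasetio = "datasetio"
    datasets = "datasets"


class ToyPostTrainingImpl:
    def __init__(
        self, config: Dict[str, Any], datasetio_api: Any, datasets_api: Any
    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api


def get_provider_impl(config: Dict[str, Any], deps: Dict[Api, Any]) -> ToyPostTrainingImpl:
    # Mirrors the hunk above: the newly declared Api.datasets dependency is
    # forwarded to the implementation alongside Api.datasetio.
    return ToyPostTrainingImpl(config, deps[Api.datasetio], deps[Api.datasets])


impl = get_provider_impl({}, {Api.datasetio: object(), Api.datasets: object()})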
@@ -15,10 +15,14 @@ from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetunin

 class TorchtunePostTrainingImpl:
     def __init__(
-        self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
+        self,
+        config: TorchtunePostTrainingConfig,
+        datasetio_api: DatasetIO,
+        datasets: Datasets,
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
+        self.datasets_api = datasets

     async def supervised_fine_tune(
         self,
@@ -40,6 +44,7 @@ class TorchtunePostTrainingImpl:
                 checkpoint_dir,
                 algorithm_config,
                 self.datasetio_api,
+                self.datasets_api,
             )
             await recipe.setup()
             await recipe.train()
@@ -58,7 +63,7 @@ class TorchtunePostTrainingImpl:
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...

-    # TODO @markchen1015 implement below APIs
+    # TODO @SLR722 implement below APIs
     async def get_training_jobs(self) -> List[PostTrainingJob]: ...

     # sends SSE stream of logs
@@ -69,6 +69,7 @@ class LoraFinetuningSingleDevice:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
+        datasets_api: Datasets,
     ) -> None:
         self.training_config = training_config
         self.algorithm_config = algorithm_config
@@ -98,7 +99,7 @@ class LoraFinetuningSingleDevice:
             model = resolve_model(self.model_id)
             self.checkpoint_dir = model_checkpoint_dir(model)

-        # TODO @markchen1015 make it work with get_training_job_artifacts
+        # TODO @SLR722 make it work with get_training_job_artifacts
         self._output_dir = self.checkpoint_dir + "/posting_training/"

         self.seed = training.set_seed(seed=config.torch_seed)
@@ -126,6 +127,7 @@ class LoraFinetuningSingleDevice:
         )

         self.datasetio_api = datasetio_api
+        self.datasets_api = datasets_api

     async def load_checkpoint(self):
         def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
@@ -142,7 +144,7 @@ class LoraFinetuningSingleDevice:
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
-            model_type=utils.get_checkpointer_model_type(self.model_id),
+            model_type=await utils.get_checkpointer_model_type(self.model_id),
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
         return checkpoint_dict
@@ -222,7 +224,7 @@ class LoraFinetuningSingleDevice:
         self._use_dora = self.algorithm_config.use_dora or False

         with training.set_default_dtype(self._dtype), self._device:
-            model_type = utils.get_model_type(self.model_id)
+            model_type = await utils.get_model_definition(self.model_id)
             model = model_type(
                 lora_attn_modules=self._lora_attn_modules,
                 apply_lora_to_mlp=self._apply_lora_to_mlp,
@@ -289,7 +291,7 @@ class LoraFinetuningSingleDevice:
         self,
     ) -> Llama3Tokenizer:
         tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
-        tokenizer_type = utils.get_tokenizer_type(self.model_id)
+        tokenizer_type = await utils.get_tokenizer_type(self.model_id)
         return tokenizer_type(path=tokenizer_path)

     async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
@@ -305,9 +307,11 @@ class LoraFinetuningSingleDevice:
     async def _setup_data(
         self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
     ) -> Tuple[DistributedSampler, DataLoader]:
+        dataset_id = self.training_config.data_config.dataset_id
+
         async def fetch_rows():
             return await self.datasetio_api.get_rows_paginated(
-                dataset_id=self.training_config.data_config.dataset_id,
+                dataset_id=dataset_id,
                 rows_in_page=-1,
             )

@@ -315,7 +319,13 @@ class LoraFinetuningSingleDevice:
         rows = all_rows.rows

         # Currently only support alpaca instruct dataset
-        # TODO @markchen1015 make the message_transform swappable and support more dataset types
+        # TODO @SLR722 make the message_transform swappable and support more dataset types
+        # TODO @SLR722 make the input dataset schema more flexible by exposing column_map
+        await utils.validate_input_dataset_schema(
+            datasets_api=self.datasets_api,
+            dataset_id=dataset_id,
+            dataset_type="alpaca",
+        )
         ds = SFTDataset(
             rows,
             message_transform=AlpacaToMessages(train_on_input=False),
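The new `validate_input_dataset_schema` call (defined in the utils hunk further down) rejects datasets whose registered column types do not match one of the accepted alpaca layouts. A simplified, self-contained sketch of that comparison, using a stand-in `StringType` and a trimmed schema table rather than the real llama_stack types:

from typing import Any, Dict, List


class StringType:
    # Stand-in for llama_stack's StringType; equality by type so schema dicts compare.
    def __eq__(self, other: Any) -> bool:
        return isinstance(other, StringType)


EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, Any]]] = {
    "alpaca": [
        {"instruction": StringType(), "input": StringType(), "output": StringType()},
        {"instruction": StringType(), "output": StringType()},
    ],
}


def validate_input_dataset_schema(dataset_schema: Dict[str, Any], dataset_type: str) -> None:
    if not dataset_schema:
        raise ValueError("Dataset does not have a schema defined.")
    # The registered column -> type mapping must match one accepted layout exactly.
    if dataset_schema not in EXPECTED_DATASET_SCHEMA[dataset_type]:
        raise ValueError(f"Schema {dataset_schema} is not a valid {dataset_type} schema.")


validate_input_dataset_schema(
    {"instruction": StringType(), "input": StringType(), "output": StringType()},
    "alpaca",
)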
@@ -10,49 +10,97 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Callable, Dict
+from enum import Enum
+from typing import Any, Callable, Dict, List

 import torch
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.common.type_system import *  # noqa
 from llama_models.sku_list import resolve_model
+from llama_stack.apis.common.type_system import ParamType

 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_2 import lora_llama3_2_3b

-LORA_MODEL_TYPES: Dict[str, Any] = {
-    "Llama3.2-3B-Instruct": lora_llama3_2_3b,
-    "Llama-3-8B-Instruct": lora_llama3_8b,
+class ColumnName(Enum):
+    instruction = "instruction"
+    input = "input"
+    output = "output"
+    text = "text"
+
+
+MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
+    "Llama3.2-3B-Instruct": {
+        "model_definition": lora_llama3_2_3b,
+        "tokenizer_type": llama3_tokenizer,
+        "checkpoint_type": "LLAMA3_2",
+    },
+    "Llama-3-8B-Instruct": {
+        "model_definition": lora_llama3_8b,
+        "tokenizer_type": llama3_tokenizer,
+        "checkpoint_type": "LLAMA3",
+    },
 }

-TOKENIZER_TYPES: Dict[str, Any] = {
-    "Llama3.2-3B-Instruct": llama3_tokenizer,
-    "Llama-3-8B-Instruct": llama3_tokenizer,
-}
-
-CHECKPOINT_MODEL_TYPES: Dict[str, str] = {
-    "Llama3.2-3B-Instruct": "LLAMA3_2",
+EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = {
+    "alpaca": [
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+            ColumnName.text.value: StringType(),
+        },
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+        },
+        {
+            ColumnName.instruction.value: StringType(),
+            ColumnName.output.value: StringType(),
+        },
+    ],
 }

 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]


-def get_model_type(
+async def get_model_definition(
     model_id: str,
 ) -> BuildLoraModelCallable:
     model = resolve_model(model_id)
-    return LORA_MODEL_TYPES[model.core_model_id.value]
+    if model is None or model.core_model_id.value not in MODEL_CONFIGS:
+        raise ValueError(f"Model {model_id} is not supported.")
+    return MODEL_CONFIGS[model.core_model_id.value]["model_definition"]


-def get_tokenizer_type(
+async def get_tokenizer_type(
     model_id: str,
 ) -> BuildTokenizerCallable:
     model = resolve_model(model_id)
-    return TOKENIZER_TYPES[model.core_model_id.value]
+    return MODEL_CONFIGS[model.core_model_id.value]["tokenizer_type"]


-def get_checkpointer_model_type(
+async def get_checkpointer_model_type(
     model_id: str,
 ) -> str:
     model = resolve_model(model_id)
-    return CHECKPOINT_MODEL_TYPES[model.core_model_id.value]
+    return MODEL_CONFIGS[model.core_model_id.value]["checkpoint_type"]
+
+
+async def validate_input_dataset_schema(
+    datasets_api: Datasets,
+    dataset_id: str,
+    dataset_type: str,
+) -> None:
+    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
+    if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
+
+    if dataset_def.dataset_schema not in EXPECTED_DATASET_SCHEMA[dataset_type]:
+        raise ValueError(
+            f"Dataset {dataset_id} does not have a correct input schema in {EXPECTED_DATASET_SCHEMA[dataset_type]}"
+        )
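The three parallel lookup tables (LORA_MODEL_TYPES, TOKENIZER_TYPES, CHECKPOINT_MODEL_TYPES) collapse into a single MODEL_CONFIGS dict above, and lookups now fail loudly for unsupported models. A condensed, self-contained sketch of that lookup pattern, with string placeholders standing in for the torchtune builder callables:

from typing import Any, Dict

# Placeholders for the torchtune builders referenced in the real table.
MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
    "Llama3.2-3B-Instruct": {
        "model_definition": "lora_llama3_2_3b",
        "tokenizer_type": "llama3_tokenizer",
        "checkpoint_type": "LLAMA3_2",
    },
}


def get_model_config(model_descriptor: str, key: str) -> Any:
    # One table, one validation path: unknown models raise a clear error
    # instead of a KeyError deep inside the training loop.
    if model_descriptor not in MODEL_CONFIGS:
        raise ValueError(f"Model {model_descriptor} is not supported.")
    return MODEL_CONFIGS[model_descriptor][key]


print(get_model_config("Llama3.2-3B-Instruct", "checkpoint_type"))  # LLAMA3_2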
@@ -19,6 +19,7 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
             api_dependencies=[
                 Api.datasetio,
+                Api.datasets,
             ],
         ),
     ]