mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00

Commit 9c1ae088f9 ("refine"), parent 5a628d32c8.
4 changed files with 7 additions and 16 deletions.
@@ -182,12 +182,8 @@ class PostTraining(Protocol):
     async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: URL,
-        dataset_id: str,
-        validation_dataset_id: str,
-        algorithm: RLHFAlgorithm,
+        finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
-        optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
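Pieced together from the hunk above, the protocol method after this commit should read roughly as follows. This is a sketch rather than the real API file: decorators and the return annotation are omitted because they are not shown in the diff, and the config classes here are empty stand-ins for the real post_training datatypes.

from typing import Any, Dict, Protocol

# Stand-ins for the real post_training datatypes; their fields live elsewhere
# in llama-stack and are not part of this diff.
class DPOAlignmentConfig: ...
class TrainingConfig: ...

class PostTraining(Protocol):
    async def preference_optimize(
        self,
        job_uuid: str,
        finetuned_model: str,  # changed from URL to a plain string by this commit
        algorithm_config: DPOAlignmentConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
    ): ...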
@@ -10,6 +10,8 @@ from llama_stack.distribution.datatypes import Api, ProviderSpec
 
 from .config import TorchtunePostTrainingConfig
 
+# post_training api and the torchtune provider is still experimental and under heavy development
+
 
 async def get_provider_impl(
     config: TorchtunePostTrainingConfig,
@@ -21,5 +23,4 @@ async def get_provider_impl(
         config,
         deps[Api.datasetio],
     )
-    # await impl.initialize()
     return impl
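Reading the two hunks in this file together, the provider factory presumably ends up reading roughly like the sketch below. Only the lines shown in the hunks are certain; the deps parameter name and type and the local import of TorchtunePostTrainingImpl are assumptions inferred from the surrounding imports.

# Sketch of the torchtune provider __init__ module after this commit; assumed
# pieces are marked. Requires llama-stack to be installed for the imports.
from typing import Dict

from llama_stack.distribution.datatypes import Api, ProviderSpec

from .config import TorchtunePostTrainingConfig

# post_training api and the torchtune provider is still experimental and under heavy development


async def get_provider_impl(
    config: TorchtunePostTrainingConfig,
    deps: Dict[Api, ProviderSpec],  # assumption: name and type inferred from the import above
):
    from .post_training import TorchtunePostTrainingImpl  # assumption: local import, not shown in the diff

    impl = TorchtunePostTrainingImpl(
        config,
        deps[Api.datasetio],
    )
    # the commented-out `await impl.initialize()` line was dropped by this commit
    return impl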
@@ -51,12 +51,8 @@ class TorchtunePostTrainingImpl:
     async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: URL,
-        dataset_id: str,
-        validation_dataset_id: str,
-        algorithm: RLHFAlgorithm,
+        finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
-        optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
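The torchtune provider implementation mirrors the protocol change above, so call sites shrink accordingly. Below is a hypothetical invocation sketch; impl and the two config objects stand for instances that llama-stack wires up elsewhere, and every literal value is illustrative only.

# Hedged usage sketch: drive the slimmed-down method. Dataset and optimizer details
# are no longer separate arguments, so they are presumably carried by the config
# objects the caller builds elsewhere.
async def run_dpo(impl, dpo_config, training_config):
    return await impl.preference_optimize(
        job_uuid="dpo-job-001",               # illustrative identifier
        finetuned_model="sft-checkpoint-0",   # plain string instead of a URL
        algorithm_config=dpo_config,          # DPOAlignmentConfig instance
        training_config=training_config,      # TrainingConfig instance
        hyperparam_search_config={},
        logger_config={},
    )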
@@ -57,8 +57,8 @@ class LoraFinetuningSingleDevice:
     # Resume from checkpoint hasn't been supported yet
     # Validation hasn't been supported yet
 
-    # TODO @markchen1015 figure out the logging for this training recipe
-    # and make it work with telemetry
+    # Currently logging only logs limited training metrics to local disk
+    # will figure out more loggings and how it works with telemetry in future PRs
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,
@@ -70,7 +70,6 @@ class LoraFinetuningSingleDevice:
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
-        # Assume the training only happens on GPU
         self.training_config = training_config
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
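The "GPU only" comment is gone, but the recipe still pins itself to CUDA via torchtune. A minimal sketch of that device selection, assuming torchtune is installed and that torchtune_utils is the recipe's alias for torchtune's utils module:

# Resolve the training device the same way the recipe does. torchtune validates the
# request, so on a CPU-only host this is expected to raise rather than silently
# fall back to CPU; GPU training is still assumed.
from torchtune import utils as torchtune_utils

device = torchtune_utils.get_device(device="cuda")
print(device)  # e.g. cuda:0 on a single-GPU machine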
@@ -149,7 +148,6 @@ class LoraFinetuningSingleDevice:
         return checkpoint_dict
 
     async def setup(self) -> None:
-        # temporily log to local disk, will figure out how to interop with telemetry
         self._metric_logger = DiskLogger(log_dir=self._output_dir)
 
         checkpoint_dict = await self.load_checkpoint()
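The "temporarily log to local disk" comment is removed (the same note now lives at the top of the class), but the DiskLogger itself stays. A small sketch of what that logger does, assuming a recent torchtune where DiskLogger lives under torchtune.training.metric_logging (older releases exposed it via torchtune.utils.metric_logging):

import tempfile

from torchtune.training.metric_logging import DiskLogger  # import path varies by torchtune version

# DiskLogger appends scalar metrics to a plain-text log file under log_dir,
# which is all "logging to local disk" means in this recipe.
log_dir = tempfile.mkdtemp()
metric_logger = DiskLogger(log_dir=log_dir)
metric_logger.log("loss", 1.23, step=0)
metric_logger.log_dict({"lr": 3e-4, "tokens_per_second": 512.0}, step=0)
metric_logger.close()
print(f"metrics written under {log_dir}")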
@@ -316,7 +314,7 @@ class LoraFinetuningSingleDevice:
         all_rows = await fetch_rows()
         rows = all_rows.rows
 
-        # Curretly only support instruct dataset
+        # Curretly only support alpaca instruct dataset
         # TODO @markchen1015 make the message_transform swappable and support more dataset types
         ds = SFTDataset(
             rows,
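The hunk is truncated here, but the shape is clear: rows fetched through the datasetio API are wrapped in the provider's SFTDataset with a (currently alpaca-specific) message transform. A heavily hedged sketch of that pattern follows; SFTDatasetSketch and its arguments are placeholders, not the provider's real class or signature.

from typing import Any, Callable, Dict, List

from torch.utils.data import Dataset


class SFTDatasetSketch(Dataset):
    """Placeholder mirroring the idea behind the provider's SFTDataset: wrap rows
    that were already fetched via the datasetio API and apply a message transform
    lazily, one sample at a time."""

    def __init__(
        self,
        rows: List[Dict[str, Any]],
        message_transform: Callable[[Dict[str, Any]], Dict[str, Any]],
    ):
        self._rows = rows
        self._message_transform = message_transform

    def __len__(self) -> int:
        return len(self._rows)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        return self._message_transform(self._rows[index])


# Alpaca-style rows carry instruction / input / output columns; the identity
# transform below stands in for the real alpaca message transform.
rows = [{"instruction": "Say hi", "input": "", "output": "Hi!"}]
ds = SFTDatasetSketch(rows, message_transform=lambda row: row)
print(len(ds), ds[0])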