Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)
refine

commit 9c1ae088f9
parent 5a628d32c8

4 changed files with 7 additions and 16 deletions

@@ -182,12 +182,8 @@ class PostTraining(Protocol):
     async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: URL,
-        dataset_id: str,
-        validation_dataset_id: str,
-        algorithm: RLHFAlgorithm,
+        finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
-        optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
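
For illustration only, not part of the commit: a minimal caller-side sketch of the narrowed preference_optimize signature, where finetuned_model is now a plain string identifier rather than a URL and the dataset ids, RLHF algorithm enum, and optimizer config are no longer passed at this level. The variable names and the model identifier are assumptions; only the parameter list follows the hunk above.

    # Hypothetical sketch; post_training_impl and the config objects are assumed
    # to be constructed elsewhere and are not part of this diff.
    job = await post_training_impl.preference_optimize(
        job_uuid="dpo-job-001",                     # made-up job identifier
        finetuned_model="Llama3.2-3B-Instruct",     # str now, previously a URL
        algorithm_config=dpo_alignment_config,      # a DPOAlignmentConfig instance
        training_config=training_config,            # a TrainingConfig instance
        hyperparam_search_config={},
        logger_config={},
    )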

@@ -10,6 +10,8 @@ from llama_stack.distribution.datatypes import Api, ProviderSpec
 
 from .config import TorchtunePostTrainingConfig
 
+# post_training api and the torchtune provider is still experimental and under heavy development
+
 
 async def get_provider_impl(
     config: TorchtunePostTrainingConfig,

@@ -21,5 +23,4 @@ async def get_provider_impl(
         config,
         deps[Api.datasetio],
     )
-    # await impl.initialize()
     return impl
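
Also purely illustrative: a sketch of how the stack might resolve this provider through get_provider_impl. The default-constructed config, the shape of the deps mapping, and the surrounding setup are assumptions; only the (config, deps[Api.datasetio]) wiring is visible in the hunks above.

    # Hypothetical wiring sketch, not part of the commit.
    from llama_stack.distribution.datatypes import Api

    config = TorchtunePostTrainingConfig()           # assumes defaults are sufficient
    deps = {Api.datasetio: datasetio_impl}           # datasetio_impl supplied by the running stack
    impl = await get_provider_impl(config, deps)     # assumes the second argument is the deps mapping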

@@ -51,12 +51,8 @@ class TorchtunePostTrainingImpl:
     async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: URL,
-        dataset_id: str,
-        validation_dataset_id: str,
-        algorithm: RLHFAlgorithm,
+        finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
-        optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],

@@ -57,8 +57,8 @@ class LoraFinetuningSingleDevice:
     # Resume from checkpoint hasn't been supported yet
     # Validation hasn't been supported yet
 
-    # TODO @markchen1015 figure out the logging for this training recipe
-    # and make it work with telemetry
+    # Currently logging only logs limited training metrics to local disk
+    # will figure out more loggings and how it works with telemetry in future PRs
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,

@@ -70,7 +70,6 @@ class LoraFinetuningSingleDevice:
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
-        # Assume the training only happens on GPU
         self.training_config = training_config
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")

@@ -149,7 +148,6 @@ class LoraFinetuningSingleDevice:
         return checkpoint_dict
 
     async def setup(self) -> None:
-        # temporily log to local disk, will figure out how to interop with telemetry
         self._metric_logger = DiskLogger(log_dir=self._output_dir)
 
         checkpoint_dict = await self.load_checkpoint()
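
A purely illustrative sketch of the local-disk metric logging this recipe currently relies on, assuming torchtune's DiskLogger interface; the import path can differ across torchtune versions, and the directory, metric names, and values below are made up.

    # Hypothetical sketch of DiskLogger usage, assuming torchtune's metric_logging API.
    from torchtune.training.metric_logging import DiskLogger

    metric_logger = DiskLogger(log_dir="/tmp/torchtune-output")   # made-up log directory
    metric_logger.log_dict({"loss": 1.23, "lr": 3e-4}, step=0)    # made-up metrics
    metric_logger.close()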

@@ -316,7 +314,7 @@ class LoraFinetuningSingleDevice:
         all_rows = await fetch_rows()
         rows = all_rows.rows
 
-        # Curretly only support instruct dataset
+        # Curretly only support alpaca instruct dataset
         # TODO @markchen1015 make the message_transform swappable and support more dataset types
         ds = SFTDataset(
             rows,