forked from phoenix-oss/llama-stack-mirror
[2/n][torchtune integration] implement job management and return training artifacts (#593)
### Context

In this PR, we:

- Implement the post-training job management and training-artifact retrieval APIs:
  - `get_training_jobs`
  - `get_training_job_status`
  - `get_training_job_artifacts`
- Delete `get_training_job_logstream`, since the trace can be accessed directly from the UI with Jaeger: https://llama-stack.readthedocs.io/en/latest/building_applications/telemetry.html#jaeger-to-visualize-traces
- Refactor the post-training and training type definitions to make them more intuitive.
- Rewrite the checkpointer so it is compatible with the llama-stack file system and its output can be recognized during inference.

### Test

Unit test:

`pytest llama_stack/providers/tests/post_training/test_post_training.py -m "torchtune_post_training_huggingface_datasetio" -v -s --tb=short --disable-warnings`

<img width="1506" alt="Screenshot 2024-12-10 at 4 06 17 PM" src="https://github.com/user-attachments/assets/16225029-bdb7-48c4-9d13-e580cc769c0a">

End-to-end test with a client-side call:

<img width="888" alt="Screenshot 2024-12-10 at 4 09 44 PM" src="https://github.com/user-attachments/assets/de375e4c-ef67-4dcc-a045-4037d9489191">
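For orientation, here is a rough sketch of how the new job-management surface can be exercised against the provider implementation (the unit tests at the bottom of this diff drive it the same way). The `post_training_impl` handle and the job UUID are illustrative assumptions, not part of this PR:

```python
# Hypothetical usage sketch; `post_training_impl` and the job UUID are assumptions.
from llama_stack.apis.common.job_types import JobStatus


async def inspect_job(post_training_impl, job_uuid: str):
    # List every job the provider has seen so far.
    jobs = await post_training_impl.get_training_jobs()
    print([job.job_uuid for job in jobs])

    # The status response carries scheduled/started/completed timestamps and,
    # once the (currently synchronous) job has finished, its checkpoints.
    status = await post_training_impl.get_training_job_status(job_uuid)
    if status is not None and status.status == JobStatus.completed:
        artifacts = await post_training_impl.get_training_job_artifacts(job_uuid)
        for ckpt in artifacts.checkpoints:
            print(ckpt.identifier, ckpt.epoch, ckpt.path)
```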
parent 5764a95912
commit c294a01c4b

8 changed files with 331 additions and 67 deletions
@@ -18,3 +18,5 @@ class Job(BaseModel):
 class JobStatus(Enum):
     completed = "completed"
     in_progress = "in_progress"
+    failed = "failed"
+    scheduled = "scheduled"
@@ -4,13 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_models.llama3.api.datatypes import URL
+from datetime import datetime
+from typing import Optional
 
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
 
 
+@json_schema_type
+class PostTrainingMetric(BaseModel):
+    epoch: int
+    train_loss: float
+    validation_loss: float
+    perplexity: float
+
+
 @json_schema_type(schema={"description": "Checkpoint created during training runs"})
 class Checkpoint(BaseModel):
-    iters: int
-    path: URL
+    identifier: str
+    created_at: datetime
     epoch: int
+    post_training_job_id: str
+    path: str
+    training_metrics: Optional[PostTrainingMetric] = None
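To make the refactored checkpoint shape concrete, here is an illustrative instance. Every concrete value is made up for the example (the identifier and path merely echo the unit test at the end of this diff), assuming the models are re-exported from `llama_stack.apis.common.training_types` as the API diff below does:

```python
# Illustrative only: what one refactored Checkpoint record could look like.
from datetime import datetime

from llama_stack.apis.common.training_types import Checkpoint, PostTrainingMetric

ckpt = Checkpoint(
    identifier="Llama3.2-3B-Instruct-sft-0",
    created_at=datetime.now(),
    epoch=0,
    post_training_job_id="1234",
    path="/home/user/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0",
    training_metrics=PostTrainingMetric(
        epoch=0, train_loss=1.23, validation_loss=1.45, perplexity=4.27
    ),
)
```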
@@ -6,6 +6,7 @@
 
 from datetime import datetime
 from enum import Enum
 
 from typing import Any, Dict, List, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type, webmethod

@@ -14,6 +15,7 @@ from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
 

@@ -64,6 +66,7 @@ class TrainingConfig(BaseModel):
 
 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
+    type: Literal["LoRA"] = "LoRA"
     lora_attn_modules: List[str]
     apply_lora_to_mlp: bool
     apply_lora_to_output: bool

@@ -75,12 +78,13 @@ class LoraFinetuningConfig(BaseModel):
 
 @json_schema_type
 class QATFinetuningConfig(BaseModel):
+    type: Literal["QAT"] = "QAT"
     quantizer_name: str
     group_size: int
 
 
 AlgorithmConfig = Annotated[
-    Union[LoraFinetuningConfig, LoraFinetuningConfig], Field(discriminator="type")
+    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
 ]
 
 
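The new `type` literals exist so that `AlgorithmConfig` can act as a tagged union (the old annotation also accidentally listed `LoraFinetuningConfig` twice). A minimal sketch of how the discriminator resolves, assuming pydantic v2's `TypeAdapter`; the classes here are trimmed-down stand-ins for the real models and the quantizer values are invented:

```python
# Sketch of the discriminated union, with trimmed-down stand-in models.
from typing import List, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated


class LoraFinetuningConfig(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    lora_attn_modules: List[str]
    apply_lora_to_mlp: bool
    apply_lora_to_output: bool


class QATFinetuningConfig(BaseModel):
    type: Literal["QAT"] = "QAT"
    quantizer_name: str
    group_size: int


AlgorithmConfig = Annotated[
    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
]

# The "type" field selects the concrete model during validation.
cfg = TypeAdapter(AlgorithmConfig).validate_python(
    {"type": "QAT", "quantizer_name": "int4_weight_only", "group_size": 32}
)
assert isinstance(cfg, QATFinetuningConfig)
```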
@@ -92,14 +96,6 @@ class PostTrainingJobLogStream(BaseModel):
     log_lines: List[str]
 
 
-@json_schema_type
-class PostTrainingJobStatus(Enum):
-    running = "running"
-    completed = "completed"
-    failed = "failed"
-    scheduled = "scheduled"
-
-
 @json_schema_type
 class RLHFAlgorithm(Enum):
     dpo = "dpo"

@@ -144,7 +140,7 @@ class PostTrainingJobStatusResponse(BaseModel):
     """Status of a finetuning job."""
 
     job_uuid: str
-    status: PostTrainingJobStatus
+    status: JobStatus
 
     scheduled_at: Optional[datetime] = None
     started_at: Optional[datetime] = None

@@ -166,7 +162,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
 
 
 class PostTraining(Protocol):
-    @webmethod(route="/post-training/supervised-fine-tune")
+    @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
         self,
         job_uuid: str,

@@ -181,7 +177,7 @@ class PostTraining(Protocol):
         algorithm_config: Optional[AlgorithmConfig] = None,
     ) -> PostTrainingJob: ...
 
-    @webmethod(route="/post-training/preference-optimize")
+    @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
         self,
         job_uuid: str,

@@ -192,24 +188,18 @@ class PostTraining(Protocol):
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
-    @webmethod(route="/post-training/jobs")
+    @webmethod(route="/post-training/jobs", method="GET")
     async def get_training_jobs(self) -> List[PostTrainingJob]: ...
 
-    # sends SSE stream of logs
-    @webmethod(route="/post-training/job/logs")
-    async def get_training_job_logstream(
-        self, job_uuid: str
-    ) -> PostTrainingJobLogStream: ...
-
-    @webmethod(route="/post-training/job/status")
+    @webmethod(route="/post-training/job/status", method="GET")
     async def get_training_job_status(
         self, job_uuid: str
-    ) -> PostTrainingJobStatusResponse: ...
+    ) -> Optional[PostTrainingJobStatusResponse]: ...
 
-    @webmethod(route="/post-training/job/cancel")
+    @webmethod(route="/post-training/job/cancel", method="POST")
     async def cancel_training_job(self, job_uuid: str) -> None: ...
 
-    @webmethod(route="/post-training/job/artifacts")
+    @webmethod(route="/post-training/job/artifacts", method="GET")
     async def get_training_job_artifacts(
         self, job_uuid: str
-    ) -> PostTrainingJobArtifactsResponse: ...
+    ) -> Optional[PostTrainingJobArtifactsResponse]: ...
@@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import shutil
from pathlib import Path
from typing import Any, Dict, List

import torch
from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
from torchtune.utils._logging import get_logger

logger = get_logger("DEBUG")


class TorchtuneCheckpointer:
    def __init__(
        self,
        model_id: str,
        training_algorithm: str,
        checkpoint_dir: str,
        checkpoint_files: List[str],
        output_dir: str,
        model_type: str,
    ) -> None:
        # Fail fast if ``checkpoint_files`` is invalid
        # TODO: support loading more than one file
        if len(checkpoint_files) != 1:
            raise ValueError(
                "Currently we only support reading from a single torchtune checkpoint file. "
                f"Got {len(checkpoint_files)} files instead."
            )
        self._checkpoint_file = checkpoint_files[0]
        self._model_id = model_id
        self._training_algorithm = training_algorithm
        self._checkpoint_dir = Path(checkpoint_dir)
        self._model_type = ModelType[model_type]
        self._output_dir = output_dir
        # get ckpt paths
        self._checkpoint_path = Path.joinpath(
            self._checkpoint_dir, self._checkpoint_file
        )

    def load_checkpoint(self) -> Dict[str, Any]:
        """
        Load Meta checkpoint from file. Currently only loading from a single file is supported.
        """
        state_dict: Dict[str, Any] = {}
        model_state_dict = safe_torch_load(self._checkpoint_path)
        if self._model_type == ModelType.LLAMA3_VISION:
            from torchtune.models.llama3_2_vision._convert_weights import (
                llama3_vision_meta_to_tune,
            )

            state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
                model_state_dict
            )
        else:
            state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
                model_state_dict
            )

        # llama3_2 has tied weights, so we need to remove the output.weight key
        if self._model_type == ModelType.LLAMA3_2:
            logger.info(
                "Identified model_type = Llama3_2. Ignoring output.weight in"
                " checkpoint in favor of the tok_embedding.weight"
                " tied weights."
            )
            state_dict[training.MODEL_KEY].pop("output.weight")

        return state_dict

    def save_checkpoint(
        self,
        state_dict: Dict[str, Any],
        epoch: int,
        adapter_only: bool = False,
    ) -> str:
        model_file_path = (
            Path(self._output_dir)
            / f"{self._model_id}-{self._training_algorithm}-{epoch}"
        )

        model_file_path.mkdir(parents=True, exist_ok=True)

        # copy the related files for inference
        shutil.copy(
            Path.joinpath(self._checkpoint_dir, "params.json"),
            Path.joinpath(model_file_path, "params.json"),
        )
        shutil.copy(
            Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
            Path.joinpath(model_file_path, "tokenizer.model"),
        )
        shutil.copy(
            Path.joinpath(self._checkpoint_dir, "orig_params.json"),
            Path.joinpath(model_file_path, "orig_params.json"),
        )

        if not adapter_only:
            model_state_dict = state_dict[training.MODEL_KEY]
            if self._model_type == ModelType.LLAMA3_VISION:
                from torchtune.models.llama3_2_vision._convert_weights import (
                    llama3_vision_tune_to_meta,
                )

                state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
                    model_state_dict
                )
            else:
                # llama3_2 has tied weights, so we need to add the output.weight key
                if (
                    self._model_type == ModelType.LLAMA3_2
                    and "output.weight" not in model_state_dict
                ):
                    model_state_dict["output.weight"] = model_state_dict[
                        "tok_embeddings.weight"
                    ]

                state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
                    model_state_dict
                )

            model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")

            torch.save(state_dict[training.MODEL_KEY], model_file_name)
            logger.info(
                "Model checkpoint of size "
                f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
                f"saved to {model_file_name}"
            )

        if training.ADAPTER_KEY in state_dict:
            adapter_file_path = model_file_path / "adapter"
            adapter_file_path.mkdir(parents=True, exist_ok=True)
            adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
            torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
            logger.info(
                "Adapter checkpoint of size "
                f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
                f"saved to {adapter_file_name}"
            )

        elif adapter_only:
            raise ValueError(
                "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
            )

        print("model_file_path", str(model_file_path))

        return str(model_file_path)
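For illustration, a minimal sketch of driving the new checkpointer on its own. The model id, directory layout, and file names below are assumptions for the example (the checkpoint directory must already contain the single weight file plus `params.json`, `tokenizer.model`, and `orig_params.json`), not values mandated by this PR:

```python
# Hypothetical usage sketch of TorchtuneCheckpointer; paths and the model id
# are illustrative assumptions.
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
    TorchtuneCheckpointer,
)

checkpointer = TorchtuneCheckpointer(
    model_id="Llama3.2-3B-Instruct",  # assumed model id
    training_algorithm="sft",
    checkpoint_dir="/home/user/.llama/checkpoints/Llama3.2-3B-Instruct",  # assumed layout
    checkpoint_files=["consolidated.00.pth"],  # exactly one file is supported
    output_dir="/tmp/post_training_out",  # assumed output location
    model_type="LLAMA3_2",
)

# Convert the Meta-format weights into torchtune's native keys ...
state_dict = checkpointer.load_checkpoint()

# ... and write them back out (plus params.json / tokenizer.model copies)
# in a layout that llama-stack inference can pick up.
saved_path = checkpointer.save_checkpoint(state_dict, epoch=0)
print(saved_path)
```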
@@ -24,6 +24,11 @@ class TorchtunePostTrainingImpl:
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets
 
+        # TODO: assume sync job, will need jobs API for async scheduling
+        self.jobs_status = {}
+        self.jobs_list = []
+        self.checkpoints_dict = {}
+
     async def supervised_fine_tune(
         self,
         job_uuid: str,

@@ -32,11 +37,26 @@ class TorchtunePostTrainingImpl:
         logger_config: Dict[str, Any],
         model: str,
         checkpoint_dir: Optional[str],
-        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
+        algorithm_config: Optional[AlgorithmConfig],
     ) -> PostTrainingJob:
+        for job in self.jobs_list:
+            if job_uuid == job.job_uuid:
+                raise ValueError(f"Job {job_uuid} already exists")
+
+        post_training_job = PostTrainingJob(job_uuid=job_uuid)
+
+        job_status_response = PostTrainingJobStatusResponse(
+            job_uuid=job_uuid,
+            status=JobStatus.scheduled,
+            scheduled_at=datetime.now(),
+        )
+
+        self.jobs_list.append(post_training_job)
         if isinstance(algorithm_config, LoraFinetuningConfig):
+            try:
                 recipe = LoraFinetuningSingleDevice(
                     self.config,
+                    job_uuid,
                     training_config,
                     hyperparam_search_config,
                     logger_config,

@@ -46,12 +66,28 @@ class TorchtunePostTrainingImpl:
                     self.datasetio_api,
                     self.datasets_api,
                 )
+
+                job_status_response.status = JobStatus.in_progress
+                job_status_response.started_at = datetime.now()
+
                 await recipe.setup()
-                await recipe.train()
+                resources_allocated, checkpoints = await recipe.train()
+
+                self.checkpoints_dict[job_uuid] = checkpoints
+                job_status_response.resources_allocated = resources_allocated
+                job_status_response.checkpoints = checkpoints
+                job_status_response.status = JobStatus.completed
+                job_status_response.completed_at = datetime.now()
+
+            except Exception:
+                job_status_response.status = JobStatus.failed
+                raise
         else:
             raise NotImplementedError()
 
-        return PostTrainingJob(job_uuid=job_uuid)
+        self.jobs_status[job_uuid] = job_status_response
+
+        return post_training_job
 
     async def preference_optimize(
         self,

@@ -63,24 +99,28 @@ class TorchtunePostTrainingImpl:
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
-    # TODO @SLR722 impelment below APIs
-    async def get_training_jobs(self) -> List[PostTrainingJob]: ...
+    async def get_training_jobs(self) -> List[PostTrainingJob]:
+        return self.jobs_list
 
-    # sends SSE stream of logs
-    @webmethod(route="/post-training/job/logs")
-    async def get_training_job_logstream(
-        self, job_uuid: str
-    ) -> PostTrainingJobLogStream: ...
-
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(
         self, job_uuid: str
-    ) -> PostTrainingJobStatusResponse: ...
+    ) -> Optional[PostTrainingJobStatusResponse]:
+        if job_uuid in self.jobs_status:
+            return self.jobs_status[job_uuid]
+        return None
 
     @webmethod(route="/post-training/job/cancel")
-    async def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None:
+        raise NotImplementedError("Job cancel is not implemented yet")
 
     @webmethod(route="/post-training/job/artifacts")
     async def get_training_job_artifacts(
         self, job_uuid: str
-    ) -> PostTrainingJobArtifactsResponse: ...
+    ) -> Optional[PostTrainingJobArtifactsResponse]:
+        if job_uuid in self.checkpoints_dict:
+            checkpoints = self.checkpoints_dict.get(job_uuid, [])
+            return PostTrainingJobArtifactsResponse(
+                job_uuid=job_uuid, checkpoints=checkpoints
+            )
+        return None
@@ -13,14 +13,20 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.datasetio import DatasetIO
 
+from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
+from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
+    TorchtuneCheckpointer,
+)
 from torch import nn
 from torchtune import utils as torchtune_utils
 from torchtune.training.metric_logging import DiskLogger
 from llama_stack.apis.post_training import *  # noqa
 from llama_stack.distribution.utils.model_utils import model_local_dir
 
-from llama_stack.providers.inline.post_training.torchtune import utils
+from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )

@@ -62,16 +68,22 @@ class LoraFinetuningSingleDevice:
     def __init__(
         self,
        config: TorchtunePostTrainingConfig,
+        job_uuid: str,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
         model: str,
         checkpoint_dir: Optional[str],
-        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
+        algorithm_config: Optional[AlgorithmConfig],
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:
+        self.job_uuid = job_uuid
         self.training_config = training_config
+        if not isinstance(algorithm_config, LoraFinetuningConfig):
+            raise ValueError(
+                "You need to specify LoraFinetuningConfig for LoRA finetuning"
+            )
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
         self._dtype = training.get_dtype(training_config.dtype, device=self._device)

@@ -99,8 +111,7 @@ class LoraFinetuningSingleDevice:
             model = resolve_model(self.model_id)
             self.checkpoint_dir = model_checkpoint_dir(model)
 
-        # TODO @SLR722 make it work with get_training_job_artifacts
-        self._output_dir = self.checkpoint_dir + "/posting_training/"
+        self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
 
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0

@@ -140,7 +151,9 @@ class LoraFinetuningSingleDevice:
         except FileNotFoundError:
             return [f"Error: The directory '{checkpoint_dir}' does not exist."]
 
-        self._checkpointer = training.FullModelMetaCheckpointer(
+        self._checkpointer = TorchtuneCheckpointer(
+            model_id=self.model_id,
+            training_algorithm="sft",
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,

@@ -150,8 +163,6 @@ class LoraFinetuningSingleDevice:
         return checkpoint_dict
 
     async def setup(self) -> None:
-        self._metric_logger = DiskLogger(log_dir=self._output_dir)
-
         checkpoint_dict = await self.load_checkpoint()
 
         self._model = await self._setup_model(

@@ -370,7 +381,7 @@ class LoraFinetuningSingleDevice:
         )
         return lr_scheduler
 
-    async def save_checkpoint(self, epoch: int) -> None:
+    async def save_checkpoint(self, epoch: int) -> str:
         ckpt_dict = {}
 
         adapter_state_dict = get_adapter_state_dict(self._model.state_dict())

@@ -400,7 +411,7 @@ class LoraFinetuningSingleDevice:
         }
         ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
 
-        self._checkpointer.save_checkpoint(
+        return self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
         )

@@ -429,20 +440,26 @@ class LoraFinetuningSingleDevice:
 
         return loss
 
-    async def train(self) -> None:
+    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
         """
         The core training loop.
         """
         # Initialize tokens count and running loss (for grad accumulation)
-        # t0 = time.perf_counter()
         t0 = time.perf_counter()
         running_loss = 0
         num_tokens = 0
 
+        # training artifacts
+        checkpoints = []
+        memory_stats = {}
+
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
+            metric_logger = DiskLogger(
+                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
+            )
             self._sampler.set_epoch(curr_epoch)
 
             for idx, batch in enumerate(self._dataloader):

@@ -488,10 +505,14 @@ class LoraFinetuningSingleDevice:
                         "lr": self._optimizer.param_groups[0]["lr"],
                         "tokens_per_second_per_gpu": num_tokens / time_per_step,
                     }
-                    log_dict.update(training.get_memory_stats(device=self._device))
+
+                    memory_stats = training.get_memory_stats(device=self._device)
+                    log_dict.update(memory_stats)
+
                     if self._clip_grad_norm is not None:
                         log_dict.update({"grad_norm": grad_norm})
-                    self._metric_logger.log_dict(
+                    metric_logger.log_dict(
                         log_dict,
                         step=self.global_step,
                     )

@@ -503,4 +524,14 @@ class LoraFinetuningSingleDevice:
 
             self.epochs_run += 1
             log.info("Starting checkpoint save...")
-            await self.save_checkpoint(epoch=curr_epoch)
+            checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
+            checkpoint = Checkpoint(
+                identifier=f"{self.model_id}-sft-{curr_epoch}",
+                created_at=datetime.now(),
+                epoch=curr_epoch,
+                post_training_job_id=self.job_uuid,
+                path=checkpoint_path,
+            )
+            checkpoints.append(checkpoint)
+
+        return (memory_stats, checkpoints)
@@ -19,6 +19,7 @@ class TestPostTraining:
     @pytest.mark.asyncio
     async def test_supervised_fine_tune(self, post_training_stack):
         algorithm_config = LoraFinetuningConfig(
+            type="LoRA",
             lora_attn_modules=["q_proj", "v_proj", "output_proj"],
             apply_lora_to_mlp=True,
             apply_lora_to_output=False,

@@ -59,3 +60,33 @@ class TestPostTraining:
         )
         assert isinstance(response, PostTrainingJob)
         assert response.job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_jobs(self, post_training_stack):
+        post_training_impl = post_training_stack
+        jobs_list = await post_training_impl.get_training_jobs()
+        assert isinstance(jobs_list, List)
+        assert jobs_list[0].job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_status(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_status = await post_training_impl.get_training_job_status("1234")
+        assert isinstance(job_status, PostTrainingJobStatusResponse)
+        assert job_status.job_uuid == "1234"
+        assert job_status.status == JobStatus.completed
+        assert isinstance(job_status.checkpoints[0], Checkpoint)
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_artifacts(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
+        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
+        assert job_artifacts.job_uuid == "1234"
+        assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
+        assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
+        assert job_artifacts.checkpoints[0].epoch == 0
+        assert (
+            "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0"
+            in job_artifacts.checkpoints[0].path
+        )