[2/n][torchtune integration] implement job management and return training artifacts (#593)

### Context 
In this PR, we 
- Implement the post training job management and get training artifacts
apis
  - get_training_jobs
  - get_training_job_status
  - get_training_job_artifacts
- get_training_job_logstream is deleted since the trace can be directly
accessed by UI with Jaeger
https://llama-stack.readthedocs.io/en/latest/building_applications/telemetry.html#jaeger-to-visualize-traces
- Refactor the post training and training types definition to make them
more intuitive.
- Rewrite the checkpointer to make it compatible with llama-stack file
system and can be recognized during inference


### Test
Unit test
`pytest llama_stack/providers/tests/post_training/test_post_training.py
-m "torchtune_post_training_huggingface_datasetio" -v -s --tb=short
--disable-warnings`

<img width="1506" alt="Screenshot 2024-12-10 at 4 06 17 PM"
src="https://github.com/user-attachments/assets/16225029-bdb7-48c4-9d13-e580cc769c0a">


e2e test with client side call

<img width="888" alt="Screenshot 2024-12-10 at 4 09 44 PM"
src="https://github.com/user-attachments/assets/de375e4c-ef67-4dcc-a045-4037d9489191">
This commit is contained in:
Botao Chen 2024-12-13 15:00:04 -08:00 committed by GitHub
parent 5764a95912
commit c294a01c4b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 331 additions and 67 deletions

View file

@ -18,3 +18,5 @@ class Job(BaseModel):
class JobStatus(Enum): class JobStatus(Enum):
completed = "completed" completed = "completed"
in_progress = "in_progress" in_progress = "in_progress"
failed = "failed"
scheduled = "scheduled"

View file

@ -4,13 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_models.llama3.api.datatypes import URL from datetime import datetime
from typing import Optional
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel from pydantic import BaseModel
@json_schema_type
class PostTrainingMetric(BaseModel):
epoch: int
train_loss: float
validation_loss: float
perplexity: float
@json_schema_type(schema={"description": "Checkpoint created during training runs"}) @json_schema_type(schema={"description": "Checkpoint created during training runs"})
class Checkpoint(BaseModel): class Checkpoint(BaseModel):
iters: int identifier: str
path: URL created_at: datetime
epoch: int epoch: int
post_training_job_id: str
path: str
training_metrics: Optional[PostTrainingMetric] = None

View file

@ -6,6 +6,7 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
@ -14,6 +15,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403 from llama_stack.apis.common.training_types import * # noqa: F403
@ -64,6 +66,7 @@ class TrainingConfig(BaseModel):
@json_schema_type @json_schema_type
class LoraFinetuningConfig(BaseModel): class LoraFinetuningConfig(BaseModel):
type: Literal["LoRA"] = "LoRA"
lora_attn_modules: List[str] lora_attn_modules: List[str]
apply_lora_to_mlp: bool apply_lora_to_mlp: bool
apply_lora_to_output: bool apply_lora_to_output: bool
@ -75,12 +78,13 @@ class LoraFinetuningConfig(BaseModel):
@json_schema_type @json_schema_type
class QATFinetuningConfig(BaseModel): class QATFinetuningConfig(BaseModel):
type: Literal["QAT"] = "QAT"
quantizer_name: str quantizer_name: str
group_size: int group_size: int
AlgorithmConfig = Annotated[ AlgorithmConfig = Annotated[
Union[LoraFinetuningConfig, LoraFinetuningConfig], Field(discriminator="type") Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
] ]
@ -92,14 +96,6 @@ class PostTrainingJobLogStream(BaseModel):
log_lines: List[str] log_lines: List[str]
@json_schema_type
class PostTrainingJobStatus(Enum):
running = "running"
completed = "completed"
failed = "failed"
scheduled = "scheduled"
@json_schema_type @json_schema_type
class RLHFAlgorithm(Enum): class RLHFAlgorithm(Enum):
dpo = "dpo" dpo = "dpo"
@ -144,7 +140,7 @@ class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job.""" """Status of a finetuning job."""
job_uuid: str job_uuid: str
status: PostTrainingJobStatus status: JobStatus
scheduled_at: Optional[datetime] = None scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None started_at: Optional[datetime] = None
@ -166,7 +162,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
class PostTraining(Protocol): class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune") @webmethod(route="/post-training/supervised-fine-tune", method="POST")
async def supervised_fine_tune( async def supervised_fine_tune(
self, self,
job_uuid: str, job_uuid: str,
@ -181,7 +177,7 @@ class PostTraining(Protocol):
algorithm_config: Optional[AlgorithmConfig] = None, algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize") @webmethod(route="/post-training/preference-optimize", method="POST")
async def preference_optimize( async def preference_optimize(
self, self,
job_uuid: str, job_uuid: str,
@ -192,24 +188,18 @@ class PostTraining(Protocol):
logger_config: Dict[str, Any], logger_config: Dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs") @webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> List[PostTrainingJob]: ... async def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs @webmethod(route="/post-training/job/status", method="GET")
@webmethod(route="/post-training/job/logs")
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...
@webmethod(route="/post-training/job/status")
async def get_training_job_status( async def get_training_job_status(
self, job_uuid: str self, job_uuid: str
) -> PostTrainingJobStatusResponse: ... ) -> Optional[PostTrainingJobStatusResponse]: ...
@webmethod(route="/post-training/job/cancel") @webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ... async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts") @webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts( async def get_training_job_artifacts(
self, job_uuid: str self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ... ) -> Optional[PostTrainingJobArtifactsResponse]: ...

View file

@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
import torch
from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
from torchtune.utils._logging import get_logger
logger = get_logger("DEBUG")
class TorchtuneCheckpointer:
def __init__(
self,
model_id: str,
training_algorithm: str,
checkpoint_dir: str,
checkpoint_files: List[str],
output_dir: str,
model_type: str,
) -> None:
# Fail fast if ``checkpoint_files`` is invalid
# TODO: support loading more than one file
if len(checkpoint_files) != 1:
raise ValueError(
"Currently we only support reading from a single torchtune checkpoint file. "
f"Got {len(checkpoint_files)} files instead."
)
self._checkpoint_file = checkpoint_files[0]
self._model_id = model_id
self._training_algorithm = training_algorithm
self._checkpoint_dir = Path(checkpoint_dir)
self._model_type = ModelType[model_type]
self._output_dir = output_dir
# get ckpt paths
self._checkpoint_path = Path.joinpath(
self._checkpoint_dir, self._checkpoint_file
)
def load_checkpoint(self) -> Dict[str, Any]:
"""
Load Meta checkpoint from file. Currently only loading from a single file is supported.
"""
state_dict: Dict[str:Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_meta_to_tune,
)
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
model_state_dict
)
else:
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
model_state_dict
)
# llama3_2 has tied weights, so we need to remove the output.weight key
if self._model_type == ModelType.LLAMA3_2:
logger.info(
"Identified model_type = Llama3_2. Ignoring output.weight in"
" checkpoint in favor of the tok_embedding.weight"
" tied weights."
)
state_dict[training.MODEL_KEY].pop("output.weight")
return state_dict
def save_checkpoint(
self,
state_dict: Dict[str, Any],
epoch: int,
adapter_only: bool = False,
) -> str:
model_file_path = (
Path(self._output_dir)
/ f"{self._model_id}-{self._training_algorithm}-{epoch}"
)
model_file_path.mkdir(parents=True, exist_ok=True)
# copy the related files for inference
shutil.copy(
Path.joinpath(self._checkpoint_dir, "params.json"),
Path.joinpath(model_file_path, "params.json"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
Path.joinpath(model_file_path, "tokenizer.model"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "orig_params.json"),
Path.joinpath(model_file_path, "orig_params.json"),
)
if not adapter_only:
model_state_dict = state_dict[training.MODEL_KEY]
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_tune_to_meta,
)
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
model_state_dict
)
else:
# llama3_2 has tied weights, so we need to add the output.weight key
if (
self._model_type == ModelType.LLAMA3_2
and "output.weight" not in model_state_dict
):
model_state_dict["output.weight"] = model_state_dict[
"tok_embeddings.weight"
]
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
model_state_dict
)
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
torch.save(state_dict[training.MODEL_KEY], model_file_name)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
f"saved to {model_file_name}"
)
if training.ADAPTER_KEY in state_dict:
adapter_file_path = model_file_path / "adapter"
adapter_file_path.mkdir(parents=True, exist_ok=True)
adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
f"saved to {adapter_file_name}"
)
elif adapter_only:
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)
print("model_file_path", str(model_file_path))
return str(model_file_path)

View file

@ -24,6 +24,11 @@ class TorchtunePostTrainingImpl:
self.datasetio_api = datasetio_api self.datasetio_api = datasetio_api
self.datasets_api = datasets self.datasets_api = datasets
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs_status = {}
self.jobs_list = []
self.checkpoints_dict = {}
async def supervised_fine_tune( async def supervised_fine_tune(
self, self,
job_uuid: str, job_uuid: str,
@ -32,11 +37,26 @@ class TorchtunePostTrainingImpl:
logger_config: Dict[str, Any], logger_config: Dict[str, Any],
model: str, model: str,
checkpoint_dir: Optional[str], checkpoint_dir: Optional[str],
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]], algorithm_config: Optional[AlgorithmConfig],
) -> PostTrainingJob: ) -> PostTrainingJob:
for job in self.jobs_list:
if job_uuid == job.job_uuid:
raise ValueError(f"Job {job_uuid} already exists")
post_training_job = PostTrainingJob(job_uuid=job_uuid)
job_status_response = PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=JobStatus.scheduled,
scheduled_at=datetime.now(),
)
self.jobs_list.append(post_training_job)
if isinstance(algorithm_config, LoraFinetuningConfig): if isinstance(algorithm_config, LoraFinetuningConfig):
try:
recipe = LoraFinetuningSingleDevice( recipe = LoraFinetuningSingleDevice(
self.config, self.config,
job_uuid,
training_config, training_config,
hyperparam_search_config, hyperparam_search_config,
logger_config, logger_config,
@ -46,12 +66,28 @@ class TorchtunePostTrainingImpl:
self.datasetio_api, self.datasetio_api,
self.datasets_api, self.datasets_api,
) )
job_status_response.status = JobStatus.in_progress
job_status_response.started_at = datetime.now()
await recipe.setup() await recipe.setup()
await recipe.train() resources_allocated, checkpoints = await recipe.train()
self.checkpoints_dict[job_uuid] = checkpoints
job_status_response.resources_allocated = resources_allocated
job_status_response.checkpoints = checkpoints
job_status_response.status = JobStatus.completed
job_status_response.completed_at = datetime.now()
except Exception:
job_status_response.status = JobStatus.failed
raise
else: else:
raise NotImplementedError() raise NotImplementedError()
return PostTrainingJob(job_uuid=job_uuid) self.jobs_status[job_uuid] = job_status_response
return post_training_job
async def preference_optimize( async def preference_optimize(
self, self,
@ -63,24 +99,28 @@ class TorchtunePostTrainingImpl:
logger_config: Dict[str, Any], logger_config: Dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
# TODO @SLR722 impelment below APIs async def get_training_jobs(self) -> List[PostTrainingJob]:
async def get_training_jobs(self) -> List[PostTrainingJob]: ... return self.jobs_list
# sends SSE stream of logs
@webmethod(route="/post-training/job/logs")
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...
@webmethod(route="/post-training/job/status") @webmethod(route="/post-training/job/status")
async def get_training_job_status( async def get_training_job_status(
self, job_uuid: str self, job_uuid: str
) -> PostTrainingJobStatusResponse: ... ) -> Optional[PostTrainingJobStatusResponse]:
if job_uuid in self.jobs_status:
return self.jobs_status[job_uuid]
return None
@webmethod(route="/post-training/job/cancel") @webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None: ... async def cancel_training_job(self, job_uuid: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
@webmethod(route="/post-training/job/artifacts") @webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts( async def get_training_job_artifacts(
self, job_uuid: str self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ... ) -> Optional[PostTrainingJobArtifactsResponse]:
if job_uuid in self.checkpoints_dict:
checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse(
job_uuid=job_uuid, checkpoints=checkpoints
)
return None

View file

@ -13,14 +13,20 @@ from typing import Any, Dict, List, Optional, Tuple
import torch import torch
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
)
from torch import nn from torch import nn
from torchtune import utils as torchtune_utils from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger from torchtune.training.metric_logging import DiskLogger
from llama_stack.apis.post_training import * # noqa from llama_stack.apis.post_training import * # noqa
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.torchtune import utils from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.config import ( from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig, TorchtunePostTrainingConfig,
) )
@ -62,16 +68,22 @@ class LoraFinetuningSingleDevice:
def __init__( def __init__(
self, self,
config: TorchtunePostTrainingConfig, config: TorchtunePostTrainingConfig,
job_uuid: str,
training_config: TrainingConfig, training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any], hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any], logger_config: Dict[str, Any],
model: str, model: str,
checkpoint_dir: Optional[str], checkpoint_dir: Optional[str],
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]], algorithm_config: Optional[AlgorithmConfig],
datasetio_api: DatasetIO, datasetio_api: DatasetIO,
datasets_api: Datasets, datasets_api: Datasets,
) -> None: ) -> None:
self.job_uuid = job_uuid
self.training_config = training_config self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig):
raise ValueError(
"You need to speicifc LoraFinetuningConfig for LoRA finetuning"
)
self.algorithm_config = algorithm_config self.algorithm_config = algorithm_config
self._device = torchtune_utils.get_device(device="cuda") self._device = torchtune_utils.get_device(device="cuda")
self._dtype = training.get_dtype(training_config.dtype, device=self._device) self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@ -99,8 +111,7 @@ class LoraFinetuningSingleDevice:
model = resolve_model(self.model_id) model = resolve_model(self.model_id)
self.checkpoint_dir = model_checkpoint_dir(model) self.checkpoint_dir = model_checkpoint_dir(model)
# TODO @SLR722 make it work with get_training_job_artifacts self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
self._output_dir = self.checkpoint_dir + "/posting_training/"
self.seed = training.set_seed(seed=config.torch_seed) self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0 self.epochs_run = 0
@ -140,7 +151,9 @@ class LoraFinetuningSingleDevice:
except FileNotFoundError: except FileNotFoundError:
return [f"Error: The directory '{checkpoint_dir}' does not exist."] return [f"Error: The directory '{checkpoint_dir}' does not exist."]
self._checkpointer = training.FullModelMetaCheckpointer( self._checkpointer = TorchtuneCheckpointer(
model_id=self.model_id,
training_algorithm="sft",
checkpoint_dir=self.checkpoint_dir, checkpoint_dir=self.checkpoint_dir,
checkpoint_files=get_checkpoint_files(self.checkpoint_dir), checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
output_dir=self._output_dir, output_dir=self._output_dir,
@ -150,8 +163,6 @@ class LoraFinetuningSingleDevice:
return checkpoint_dict return checkpoint_dict
async def setup(self) -> None: async def setup(self) -> None:
self._metric_logger = DiskLogger(log_dir=self._output_dir)
checkpoint_dict = await self.load_checkpoint() checkpoint_dict = await self.load_checkpoint()
self._model = await self._setup_model( self._model = await self._setup_model(
@ -370,7 +381,7 @@ class LoraFinetuningSingleDevice:
) )
return lr_scheduler return lr_scheduler
async def save_checkpoint(self, epoch: int) -> None: async def save_checkpoint(self, epoch: int) -> str:
ckpt_dict = {} ckpt_dict = {}
adapter_state_dict = get_adapter_state_dict(self._model.state_dict()) adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
@ -400,7 +411,7 @@ class LoraFinetuningSingleDevice:
} }
ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config}) ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
self._checkpointer.save_checkpoint( return self._checkpointer.save_checkpoint(
ckpt_dict, ckpt_dict,
epoch=epoch, epoch=epoch,
) )
@ -429,20 +440,26 @@ class LoraFinetuningSingleDevice:
return loss return loss
async def train(self) -> None: async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
""" """
The core training loop. The core training loop.
""" """
# Initialize tokens count and running loss (for grad accumulation) # Initialize tokens count and running loss (for grad accumulation)
# t0 = time.perf_counter()
t0 = time.perf_counter() t0 = time.perf_counter()
running_loss = 0 running_loss = 0
num_tokens = 0 num_tokens = 0
# training artifacts
checkpoints = []
memory_stats = {}
# self.epochs_run should be non-zero when we're resuming from a checkpoint # self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs): for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs # Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True # in case shuffle is True
metric_logger = DiskLogger(
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
)
self._sampler.set_epoch(curr_epoch) self._sampler.set_epoch(curr_epoch)
for idx, batch in enumerate(self._dataloader): for idx, batch in enumerate(self._dataloader):
@ -488,10 +505,14 @@ class LoraFinetuningSingleDevice:
"lr": self._optimizer.param_groups[0]["lr"], "lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step, "tokens_per_second_per_gpu": num_tokens / time_per_step,
} }
log_dict.update(training.get_memory_stats(device=self._device))
memory_stats = training.get_memory_stats(device=self._device)
log_dict.update(memory_stats)
if self._clip_grad_norm is not None: if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm}) log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
metric_logger.log_dict(
log_dict, log_dict,
step=self.global_step, step=self.global_step,
) )
@ -503,4 +524,14 @@ class LoraFinetuningSingleDevice:
self.epochs_run += 1 self.epochs_run += 1
log.info("Starting checkpoint save...") log.info("Starting checkpoint save...")
await self.save_checkpoint(epoch=curr_epoch) checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(),
epoch=curr_epoch,
post_training_job_id=self.job_uuid,
path=checkpoint_path,
)
checkpoints.append(checkpoint)
return (memory_stats, checkpoints)

View file

@ -19,6 +19,7 @@ class TestPostTraining:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_supervised_fine_tune(self, post_training_stack): async def test_supervised_fine_tune(self, post_training_stack):
algorithm_config = LoraFinetuningConfig( algorithm_config = LoraFinetuningConfig(
type="LoRA",
lora_attn_modules=["q_proj", "v_proj", "output_proj"], lora_attn_modules=["q_proj", "v_proj", "output_proj"],
apply_lora_to_mlp=True, apply_lora_to_mlp=True,
apply_lora_to_output=False, apply_lora_to_output=False,
@ -59,3 +60,33 @@ class TestPostTraining:
) )
assert isinstance(response, PostTrainingJob) assert isinstance(response, PostTrainingJob)
assert response.job_uuid == "1234" assert response.job_uuid == "1234"
@pytest.mark.asyncio
async def test_get_training_jobs(self, post_training_stack):
post_training_impl = post_training_stack
jobs_list = await post_training_impl.get_training_jobs()
assert isinstance(jobs_list, List)
assert jobs_list[0].job_uuid == "1234"
@pytest.mark.asyncio
async def test_get_training_job_status(self, post_training_stack):
post_training_impl = post_training_stack
job_status = await post_training_impl.get_training_job_status("1234")
assert isinstance(job_status, PostTrainingJobStatusResponse)
assert job_status.job_uuid == "1234"
assert job_status.status == JobStatus.completed
assert isinstance(job_status.checkpoints[0], Checkpoint)
@pytest.mark.asyncio
async def test_get_training_job_artifacts(self, post_training_stack):
post_training_impl = post_training_stack
job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
assert job_artifacts.job_uuid == "1234"
assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
assert job_artifacts.checkpoints[0].epoch == 0
assert (
"/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0"
in job_artifacts.checkpoints[0].path
)