Fix precommit check after moving to ruff (#927)
The lint check on the main branch is failing. This fixes it after the move to ruff in https://github.com/meta-llama/llama-stack/pull/921: the lint configuration moves to a `ruff.toml` file, and a few additional checks are fixed or ignored.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
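For context, the move amounts to carrying the lint and format settings in a top-level `ruff.toml`. The sketch below is illustrative only; the line length, rule selection, and ignore list are assumptions for illustration, not the contents of the actual file in this PR:

```toml
# Hypothetical ruff.toml sketch -- values are illustrative, not the PR's actual file.
line-length = 120        # a longer limit would explain the single-line rewrites in the diff below

[lint]
select = ["E", "F", "I"] # pycodestyle errors, pyflakes, import sorting (assumed selection)
ignore = ["E721"]        # example of ignoring an additional check (assumed)

[format]
quote-style = "double"
```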
parent 4773092dd1
commit 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
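The pre-commit wiring itself is not visible in this view. A typical hookup uses the astral-sh/ruff-pre-commit hooks, roughly as in the sketch below; the `rev` pin and the `--fix` argument are assumptions, not taken from this repository:

```yaml
# .pre-commit-config.yaml (sketch; rev pin is hypothetical)
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.2.0  # pin to whatever release matches your ruff version
    hooks:
      - id: ruff          # linter; applies fixes for fixable rules
        args: [--fix]
      - id: ruff-format   # formatter; produces the single-line rewrites shown below
```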
@@ -47,6 +47,4 @@ async def validate_input_dataset_schema(
     if dataset_type not in EXPECTED_DATASET_SCHEMA:
         raise ValueError(f"Dataset type {dataset_type} is not supported.")
-    validate_dataset_schema(
-        dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]
-    )
+    validate_dataset_schema(dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type])

@@ -42,9 +42,7 @@ class TorchtuneCheckpointer:
         self._model_type = ModelType[model_type]
         self._output_dir = output_dir
         # get ckpt paths
-        self._checkpoint_path = Path.joinpath(
-            self._checkpoint_dir, self._checkpoint_file
-        )
+        self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)

     def load_checkpoint(self) -> Dict[str, Any]:
         """

@@ -57,13 +55,9 @@ class TorchtuneCheckpointer:
                 llama3_vision_meta_to_tune,
             )
-            state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
-                model_state_dict
-            )
+            state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(model_state_dict)
         else:
-            state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
-                model_state_dict
-            )
+            state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(model_state_dict)

         # llama3_2 has tied weights, so we need to remove the output.weight key
         if self._model_type == ModelType.LLAMA3_2:

@@ -82,10 +76,7 @@ class TorchtuneCheckpointer:
         epoch: int,
         adapter_only: bool = False,
     ) -> str:
-        model_file_path = (
-            Path(self._output_dir)
-            / f"{self._model_id}-{self._training_algorithm}-{epoch}"
-        )
+        model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}"

         model_file_path.mkdir(parents=True, exist_ok=True)

@@ -116,22 +107,13 @@ class TorchtuneCheckpointer:
                     llama3_vision_tune_to_meta,
                 )
-                state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
-                    model_state_dict
-                )
+                state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(model_state_dict)
             else:
                 # llama3_2 has tied weights, so we need to add the output.weight key
-                if (
-                    self._model_type == ModelType.LLAMA3_2
-                    and "output.weight" not in model_state_dict
-                ):
-                    model_state_dict["output.weight"] = model_state_dict[
-                        "tok_embeddings.weight"
-                    ]
+                if self._model_type == ModelType.LLAMA3_2 and "output.weight" not in model_state_dict:
+                    model_state_dict["output.weight"] = model_state_dict["tok_embeddings.weight"]

-                state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
-                    model_state_dict
-                )
+                state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(model_state_dict)

             model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")

@@ -15,18 +15,13 @@ from typing import Any, Mapping
 from llama_stack.providers.utils.common.data_schema_validator import ColumnName


-def llama_stack_instruct_to_torchtune_instruct(
-    sample: Mapping[str, Any]
-) -> Mapping[str, Any]:
-    assert (
-        ColumnName.chat_completion_input.value in sample
-        and ColumnName.expected_answer.value in sample
-    ), "Invalid input row"
+def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Mapping[str, Any]:
+    assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
+        "Invalid input row"
+    )
     input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))

-    assert (
-        len(input_messages) == 1
-    ), "llama stack intruct dataset format only supports 1 user message"
+    assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
     input_message = input_messages[0]

     assert "content" in input_message, "content not found in input message"

@@ -48,13 +43,9 @@ def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
     roles = []
     conversations = []
     for message in dialog:
-        assert (
-            "role" in message and "content" in message
-        ), "role and content must in message"
+        assert "role" in message and "content" in message, "role and content must in message"
         roles.append(message["role"])
-        conversations.append(
-            {"from": role_map[message["role"]], "value": message["content"]}
-        )
+        conversations.append({"from": role_map[message["role"]], "value": message["content"]})

     assert roles[0] == "user", "first message must be from user"
     assert "assistant" in roles, "at least 1 message should be from assistant"

@@ -61,8 +61,7 @@ class SFTDataset(Dataset):
         if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
             keys_str = ", ".join(tokenized_dict.keys())
             error_message = (
-                "model_transform returned the following keys: "
-                f"{keys_str}. Must return 'tokens' and 'mask' as keys."
+                f"model_transform returned the following keys: {keys_str}. Must return 'tokens' and 'mask' as keys."
             )
             raise ValueError(error_message)

@@ -119,9 +119,7 @@ class TorchtunePostTrainingImpl:
         return ListPostTrainingJobsResponse(data=self.jobs_list)

     @webmethod(route="/post-training/job/status")
-    async def get_training_job_status(
-        self, job_uuid: str
-    ) -> Optional[PostTrainingJobStatusResponse]:
+    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
         if job_uuid in self.jobs_status:
             return self.jobs_status[job_uuid]
         return None

@@ -131,12 +129,8 @@ class TorchtunePostTrainingImpl:
         raise NotImplementedError("Job cancel is not implemented yet")

     @webmethod(route="/post-training/job/artifacts")
-    async def get_training_job_artifacts(
-        self, job_uuid: str
-    ) -> Optional[PostTrainingJobArtifactsResponse]:
+    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
         if job_uuid in self.checkpoints_dict:
             checkpoints = self.checkpoints_dict.get(job_uuid, [])
-            return PostTrainingJobArtifactsResponse(
-                job_uuid=job_uuid, checkpoints=checkpoints
-            )
+            return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
         return None

@@ -94,9 +94,7 @@ class LoraFinetuningSingleDevice:
         self.job_uuid = job_uuid
         self.training_config = training_config
         if not isinstance(algorithm_config, LoraFinetuningConfig):
-            raise ValueError(
-                "You need to speicifc LoraFinetuningConfig for LoRA finetuning"
-            )
+            raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
         self._dtype = training.get_dtype(training_config.dtype, device=self._device)

@@ -105,10 +103,7 @@ class LoraFinetuningSingleDevice:
         def model_checkpoint_dir(model) -> str:
             checkpoint_dir = Path(model_local_dir(model.descriptor()))

-            paths = [
-                Path(checkpoint_dir / f"consolidated.{ext}")
-                for ext in ["pth", "00.pth"]
-            ]
+            paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
             if not any(p.exists() for p in paths):
                 checkpoint_dir = checkpoint_dir / "original"

@@ -123,9 +118,7 @@ class LoraFinetuningSingleDevice:
         else:
             model = resolve_model(self.model_id)
             if model is None:
-                raise ValueError(
-                    f"{self.model_id} not found. Your model id should be in the llama models SKU list"
-                )
+                raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
             self.checkpoint_dir = model_checkpoint_dir(model)

         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)

@@ -196,9 +189,7 @@ class LoraFinetuningSingleDevice:
         self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized.")

-        self._optimizer = await self._setup_optimizer(
-            optimizer_config=self.training_config.optimizer_config
-        )
+        self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
         log.info("Optimizer is initialized.")

         self._loss_fn = CEWithChunkedOutputLoss()

@@ -226,13 +217,8 @@ class LoraFinetuningSingleDevice:
         # by the dataloader and the max_steps_per_epoch param set by the user and is used
         # for logging and tracking training state. This should be computed after the dataloader
         # has been setup
-        self._steps_per_epoch = (
-            len(self._training_dataloader) // self._gradient_accumulation_steps
-        )
-        if (
-            self.max_steps_per_epoch is not None
-            and self.max_steps_per_epoch < self._steps_per_epoch
-        ):
+        self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps
+        if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch:
             self._steps_per_epoch = self.max_steps_per_epoch
         self.global_step = self.epochs_run * self._steps_per_epoch

@@ -246,9 +232,7 @@ class LoraFinetuningSingleDevice:
         log.info("Learning rate scheduler is initialized.")

         # Used to ignore labels for loss computation
-        self.ignore_labels_cache = torch.full(
-            (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
-        )
+        self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)

     async def _setup_model(
         self,

@@ -282,13 +266,9 @@ class LoraFinetuningSingleDevice:
         set_trainable_params(model, self.adapter_params)

         if enable_activation_checkpointing:
-            training.set_activation_checkpointing(
-                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
-            )
+            training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer})

-        base_missing, base_unexpected = model.load_state_dict(
-            base_model_state_dict, strict=False
-        )
+        base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)

         # This is for any adapters that need to be initialized after base weights
         # have been loaded (e.g. DoRA).

@@ -297,9 +277,7 @@ class LoraFinetuningSingleDevice:
             if hasattr(m, "initialize_dora_magnitude"):
                 m.initialize_dora_magnitude()
         if lora_weights_state_dict:
-            lora_missing, lora_unexpected = model.load_state_dict(
-                lora_weights_state_dict, strict=False
-            )
+            lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
         else:
             lora_missing, lora_unexpected = None, None
         validate_missing_and_unexpected_for_lora(

@@ -313,14 +291,10 @@ class LoraFinetuningSingleDevice:
         )

         # Validate model adapter params were loaded in with the expected dtype
-        training.validate_expected_param_dtype(
-            self.adapter_params.items(), dtype=self._dtype
-        )
+        training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype)

         # activation offloading
-        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
-            model, enable_activation_offloading
-        )
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)

         memory_stats = training.get_memory_stats(device=self._device)
         training.log_memory_stats(memory_stats)

@@ -456,9 +430,7 @@ class LoraFinetuningSingleDevice:
         # Shift labels to compute loss
         # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
         # But this way we dont need to slice the logits. We just add an ignore index to labels.
-        labels = torch.hstack(
-            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
-        )
+        labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]))
         if not isinstance(logits, list):
             labels = labels.reshape(-1)
             logits = logits.reshape(-1, logits.size(-1))

@@ -487,9 +459,7 @@ class LoraFinetuningSingleDevice:
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
-            metric_logger = DiskLogger(
-                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
-            )
+            metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}")
             self._training_sampler.set_epoch(curr_epoch)
             loss_to_log = 0.0

@@ -497,8 +467,7 @@ class LoraFinetuningSingleDevice:
             for idx, batch in enumerate(self._training_dataloader):
                 if (
                     self.max_steps_per_epoch is not None
-                    and (idx // self._gradient_accumulation_steps)
-                    == self.max_steps_per_epoch
+                    and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch
                 ):
                     break

@@ -506,9 +475,7 @@ class LoraFinetuningSingleDevice:

                 # Calculate the number of unmasked tokens in the current batch
                 # and increment the total number of tokens seen in the step
-                current_num_tokens = (
-                    batch["labels"] != self._loss_fn.ignore_index
-                ).sum()
+                current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
                 num_tokens += current_num_tokens

                 # Loss is normalized by default so we multiply by the number of tokens

@@ -533,9 +500,7 @@ class LoraFinetuningSingleDevice:
                     loss_to_log = running_loss.item() / num_tokens

                     pbar.update(1)
-                    pbar.set_description(
-                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
-                    )
+                    pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")

                     time_per_step = time.perf_counter() - t0
                     log_dict = {