Mirror of https://github.com/meta-llama/llama-stack.git
Debugged impl errors for building the container and running data prep
Signed-off-by: James Kunstle <jkunstle@redhat.com>
parent 06465441f2
commit 68000499f7
6 changed files with 66 additions and 19 deletions
@@ -63,8 +63,10 @@ class HFilabPostTrainingImpl:
         if self.current_job is None:
             return True
 
-        finalized_job_states = [JobStatus.completed.value, JobStatus.failed.value]
-        if self.current_job.status in finalized_job_states:
+        finalized_job_states = [JobStatus.completed, JobStatus.failed]
+
+        # check most recent status of job.
+        if self.current_job.status[-1] in finalized_job_states:
             return True
 
         return False
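For reference, a minimal runnable sketch of the status check this hunk settles on; the `Job` and `JobStatus` definitions below are hypothetical stand-ins, not the actual llama-stack types:

    from enum import Enum


    class JobStatus(Enum):
        scheduled = "scheduled"
        in_progress = "in_progress"
        completed = "completed"
        failed = "failed"


    class Job:
        """Hypothetical stand-in: status is a history list, newest entry last."""

        def __init__(self) -> None:
            self.status: list[JobStatus] = [JobStatus.scheduled]


    def can_schedule_new_job(current_job: Job | None) -> bool:
        if current_job is None:
            return True
        # The fix compares enum members (not their .value strings) and only
        # the most recent entry in the status history.
        finalized_job_states = [JobStatus.completed, JobStatus.failed]
        return current_job.status[-1] in finalized_job_states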
@@ -87,7 +89,8 @@ class HFilabPostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig],
     ) -> JSONResponse:
-        if not self.can_schedule_new_job():
+        if not await self.can_schedule_new_job():
+            # TODO: this status code isn't making its way up to the user. User just getting 500 from SDK.
             raise fastapi.HTTPException(
                 status_code=503,  # service unavailable, try again later.
                 detail="A tuning job is currently running; this could take a while.",
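A self-contained sketch of the 503-on-busy pattern, assuming a bare FastAPI app and a `busy` flag in place of the real scheduler state:

    import fastapi

    app = fastapi.FastAPI()
    busy = True  # stand-in for "a tuning job is already running"


    @app.post("/post-training/jobs")
    async def schedule_job() -> dict[str, str]:
        if busy:
            # 503 tells the client to retry later. Per the TODO in the diff,
            # if middleware re-wraps the exception the SDK may still see a 500.
            raise fastapi.HTTPException(
                status_code=503,
                detail="A tuning job is currently running; this could take a while.",
            )
        return {"status": "scheduled"}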
@@ -7,6 +7,7 @@ from typing import Callable
 
 import datasets
 import transformers
+from termcolor import cprint
 from transformers.configuration_utils import PretrainedConfig
 from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
@@ -72,6 +73,10 @@ class FullPrecisionFineTuning:
     def logs_dir(self):
         return self.storage_dir / "logs"
 
+    @property
+    def hfcache_dir(self):
+        return self.storage_dir / "hf_cache"
+
     @staticmethod
     def check_model_arch_validated(model_config: PretrainedConfig) -> bool:
         """Check whether input model architecture from config is among the pre-validated architectures.
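The new property is plain `pathlib` composition; a hypothetical stand-in class to illustrate:

    import pathlib


    class JobStorage:
        """Hypothetical stand-in; the real class derives more paths than this."""

        def __init__(self, storage_dir: str) -> None:
            self.storage_dir = pathlib.Path(storage_dir)

        @property
        def hfcache_dir(self) -> pathlib.Path:
            # pathlib's "/" joins segments; nothing is created on disk
            # until a caller mkdir()s the result.
            return self.storage_dir / "hf_cache"


    print(JobStorage("/tmp/job").hfcache_dir)  # /tmp/job/hf_cache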
@@ -98,7 +103,9 @@ class FullPrecisionFineTuning:
             PretrainedConfig: model config associated with model.
         """
         try:
-            model_config: PretrainedConfig = transformers.AutoConfig.from_pretrained(self.model_name_or_path)
+            model_config: PretrainedConfig = transformers.AutoConfig.from_pretrained(
+                self.model_name_or_path, cache_dir=self.hfcache_dir
+            )
         except OSError:
             print(
                 f"Attempted to load model config for ({self.model_name_or_path}) but failed. Model config will be loaded by `AutoConfig.from_pretrained()`"
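A standalone sketch of loading a config with a pinned cache, assuming an arbitrary Hub model id and a throwaway path for illustration:

    import pathlib

    import transformers

    # Assumed path; the real code uses self.hfcache_dir from the new property.
    hfcache_dir = pathlib.Path("/tmp/post_training") / "hf_cache"

    try:
        # cache_dir pins downloads under the job's storage tree instead of
        # the default ~/.cache/huggingface location.
        model_config = transformers.AutoConfig.from_pretrained(
            "ibm-granite/granite-3.0-2b-instruct",  # example model id only
            cache_dir=hfcache_dir,
        )
        print(model_config.architectures)
    except OSError:
        print("Config not reachable by that name; check the model id or network.")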
@@ -115,7 +122,7 @@ class FullPrecisionFineTuning:
         """
         try:
             tokenizer: SomePretrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
-                self.model_name_or_path, use_fast=True
+                self.model_name_or_path, use_fast=True, cache_dir=self.hfcache_dir
             )
         except OSError:
             print(
@@ -150,6 +157,9 @@ class FullPrecisionFineTuning:
             dataset_id=self.training_config.data_config.dataset_id, rows_in_page=-1
         )
         self.loaded_dataset = dataset.rows
+        cprint(
+            f"Dataset loaded! len: ({len(self.loaded_dataset)}), example row: ({self.loaded_dataset[0]})", color="cyan"
+        )
 
     def preflight(self, set_status_callback: Callable[[JobStatus], None]):
         """Set of checks that should run before any heavier-weight preprocessing runs to validate starting state.
@@ -175,19 +185,22 @@ class FullPrecisionFineTuning:
             RuntimeError: If tokenizer doesn't have chat template available.
             OSError: Can be raised via this function if config or tokenizer not available via model's name.
         """
-
         model_config = self.__try_load_config()
+        cprint("Loaded model config", color="cyan")
         if not self.check_model_arch_validated(model_config=model_config):
             # could raise Error if we need a strong check against this.
             print(
                 f"Input model ({self.model_name_or_path}) architecture ({model_config.architectures}) is not among validated architectures."
             )
+        cprint("Validated model config", color="cyan")
 
         model_tokenizer = self.__try_load_tokenizer()
+        cprint("Loaded model tokenizer", color="cyan")
         if not self.check_tokenizer_has_chat_template(model_tokenizer):
             raise RuntimeError(
                 f"Input model ({self.model_name_or_path})'s tokenizer ({model_tokenizer.__name__}) has no chat template from associated `tokenizer_config.json`"
             )
+        cprint("Validated model tokenizer", color="cyan")
 
         try:
             _ = model_tokenizer.apply_chat_template(self.loaded_dataset[0]["messages"])
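The preflight contract this hunk logs around can be exercised standalone; a sketch assuming an example Hub model that ships a chat template:

    import transformers

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "HuggingFaceH4/zephyr-7b-beta",  # example model id only
        use_fast=True,
    )

    # chat_template is None when tokenizer_config.json ships no template;
    # that is the condition the RuntimeError above guards against.
    if tokenizer.chat_template is None:
        raise RuntimeError("tokenizer has no chat template")

    messages = [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi there."},
    ]
    token_ids = tokenizer.apply_chat_template(messages, tokenize=True)
    print(token_ids[:10])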
@@ -198,6 +211,8 @@ class FullPrecisionFineTuning:
             )
             raise
 
+        cprint("Model tokenizer applied template to row.", color="cyan")
+
         # Success! Preflight checks haven't caught any immediate problems.
         set_status_callback(JobStatus.scheduled)
 
@@ -221,7 +236,6 @@ class FullPrecisionFineTuning:
         Returns:
             dict[str, list[int]]: Of shape {input_ids, labels, attention_mask, loss_mask}
         """
-
         input_ids = tokenizer.apply_chat_template(conversation=sample, tokenize=True)
         input_ids = typing.cast(
             list[int], input_ids
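The full body of `__tokenize_and_generate_labels_and_mask` isn't visible in this hunk; a sketch of one plausible implementation of the documented return shape, with assumed (trivial) masking rules:

    import typing

    import transformers


    def tokenize_and_generate_labels_and_mask(
        tokenizer: transformers.PreTrainedTokenizerBase,
        sample: list[dict[str, str]],
    ) -> dict[str, list[int]]:
        """Sketch only; the real masking rules aren't visible in this hunk."""
        input_ids = typing.cast(
            list[int],
            tokenizer.apply_chat_template(conversation=sample, tokenize=True),
        )
        # Causal-LM convention: labels mirror input_ids, attend to all tokens.
        labels = list(input_ids)
        attention_mask = [1] * len(input_ids)
        # Trivial mask: train on every position. A real loss mask would zero
        # out prompt tokens so loss is computed on assistant tokens only.
        loss_mask = [1] * len(input_ids)
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "loss_mask": loss_mask,
        }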
@@ -243,10 +257,17 @@ class FullPrecisionFineTuning:
         """
 
         dataset = datasets.Dataset.from_list(self.loaded_dataset)
+        cprint(f"Dataset loaded. Example row: ({dataset[0]})", color="cyan")
         model_tok = self.__try_load_tokenizer()
+        cprint("Tokenizer loaded.", color="cyan")
 
         # NOTE: not implementing as batched for the moment; need to know how batching impacts memory usage on machine.
-        dataset = dataset.map(lambda x: self.__tokenize_and_generate_labels_and_mask(tokenizer=model_tok, sample=x))
+        dataset = dataset.map(
+            lambda x: self.__tokenize_and_generate_labels_and_mask(
+                tokenizer=model_tok, sample=x["messages"]
+            )  # TODO: get this key from input dataset schema
+        )
+        dataset = dataset.remove_columns(column_names=["messages"])
         return dataset
 
     def setup(self):
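A minimal sketch of the non-batched `datasets.map` pattern used here, with a toy row in place of the real dataset schema:

    import datasets

    rows = [
        {
            "messages": [
                {"role": "user", "content": "hi"},
                {"role": "assistant", "content": "hello"},
            ]
        }
    ]
    ds = datasets.Dataset.from_list(rows)

    # Non-batched map: the function sees one row dict at a time, and the
    # keys of its return dict become new columns; "messages" is dropped
    # afterwards, mirroring the diff.
    ds = ds.map(lambda x: {"n_turns": len(x["messages"])})
    ds = ds.remove_columns(column_names=["messages"])
    print(ds[0])  # {'n_turns': 2}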
@@ -267,8 +288,8 @@ class FullPrecisionFineTuning:
             set_subproc_ref_callback (Callable[[subprocess.Process], None]): Sets subprocess reference in 'Impl' class' ref to this job
         """
 
-        training_subproc = await asyncio.create_subprocess_exec(
-            "echo 'yay Im running in a subprocess: $$'; sleep 30; echo 'exiting process $$'"
+        training_subproc = await asyncio.create_subprocess_shell(
+            'echo "yay Im running in a subprocess: $$"; sleep 5; echo "exiting subprocess $$"'
        )
         set_subproc_ref_callback(training_subproc)
         await training_subproc.wait()
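A runnable sketch of why the switch from `create_subprocess_exec` to `create_subprocess_shell` matters; the echo/sleep one-liner is the placeholder command from the diff itself:

    import asyncio


    async def main() -> None:
        # create_subprocess_exec treats its first argument as a program path,
        # so passing a whole shell one-liner as a single string raises
        # FileNotFoundError; create_subprocess_shell hands it to /bin/sh.
        proc = await asyncio.create_subprocess_shell(
            'echo "yay Im running in a subprocess: $$"; sleep 5; echo "exiting subprocess $$"'
        )
        await proc.wait()
        print(f"exit code: {proc.returncode}")


    asyncio.run(main())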