feat: minor improvements to the fine-tune script
parent 925f5767ec
commit b0d2d80ac3
@@ -22,6 +22,7 @@ https://huggingface.co/models?filter=text-generation
 """
 # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
 
+import datetime
 import math
 import os
 import time
@@ -404,6 +405,8 @@ def main():
     logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
 
     if args.resume_from_checkpoint is not None:
+        # FIXME(11b): Implement this properly. Need to save/restore all the other
+        # state as well (optimizer, LR scheduler, dataloader position via step counter...)
         logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}", ranks=[0])
         colossalai.utils.load_checkpoint(args.resume_from_checkpoint, model)
 
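The FIXME above concerns full training-state resumption: restoring only model weights restarts the optimizer, LR schedule, and data position from scratch. A minimal sketch of the idea in plain PyTorch (the helper names here are hypothetical, not from the script, and the ColossalAI checkpoint utilities may already cover part of this):

import torch

def save_training_state(path, model, optimizer, lr_scheduler, completed_steps, epoch):
    # Bundle everything needed to resume mid-run, not just the model weights.
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "completed_steps": completed_steps,  # used to fast-forward the dataloader
        "epoch": epoch,
    }, path)

def load_training_state(path, model, optimizer, lr_scheduler):
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    lr_scheduler.load_state_dict(state["lr_scheduler"])
    # The caller would then skip `completed_steps % steps_per_epoch` batches
    # to restore the dataloader position the FIXME mentions.
    return state["completed_steps"], state["epoch"]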
@@ -538,8 +541,11 @@ def main():
     logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
     logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
 
+    now = datetime.datetime.now()
+    run_name = now.strftime("%Y-%m-%dT_%H-%M-%S%z")
+    writer = torch.utils.tensorboard.SummaryWriter(log_dir=f"{args.output_dir}/runs/{run_name}", comment=args.comment)
+
     # Only show the progress bar once on each machine.
-    writer = torch.utils.tensorboard.SummaryWriter(log_dir=f"{args.output_dir}/runs", comment=args.comment)
     progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process)
     completed_steps = 0
     starting_epoch = 0
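For context on the writer change: pointing every run at the same runs/ directory makes TensorBoard mix event files from different runs into one curve, while a timestamped subdirectory per run keeps them separate and comparable in the UI. A standalone sketch of the naming scheme (the output/ path is illustrative):

import datetime
from torch.utils.tensorboard import SummaryWriter

now = datetime.datetime.now()
# e.g. "2023-01-07T_14-03-22"; note %z expands to an empty string here,
# since datetime.now() returns a naive (timezone-unaware) datetime.
run_name = now.strftime("%Y-%m-%dT_%H-%M-%S%z")

# Each run gets its own event-file directory, e.g. output/runs/2023-01-07T_14-03-22
writer = SummaryWriter(log_dir=f"output/runs/{run_name}")
writer.add_scalar("Train/Loss", 0.5, global_step=0)
writer.close()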
@@ -603,8 +609,10 @@ def main():
             perplexity = float("inf")
 
         logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0])
         writer.add_scalar("Eval/Loss (Epoch)", eval_loss, epoch)
         writer.add_scalar("Eval/Perplexity (Epoch)", perplexity, epoch)
-        writer.add_scalar("Eval/Loss (Global Step)", eval_loss, completed_steps)
-        writer.add_scalar("Eval/Perplexity (Global Step)", perplexity, completed_steps)
+        # TODO(11b): This messes up the intra-epoch graphs. Apparently I need to
+        # read up on the Tensorboard docs to do this properly. Ignoring for now.
+        # writer.add_scalar("Eval/Loss (Global Step)", eval_loss, completed_steps)
+        # writer.add_scalar("Eval/Perplexity (Global Step)", perplexity, completed_steps)
 
         if args.output_dir is not None and args.checkpointing_steps == "epoch":
            checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
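The float("inf") fallback above is the usual guard around exponentiating the eval loss: perplexity is exp(mean cross-entropy), and math.exp raises OverflowError for large losses. A sketch of the standard pattern, with accumulator names assumed rather than taken from the script:

import math

try:
    # Mean cross-entropy over the eval set; names are hypothetical.
    eval_loss = total_eval_loss / num_eval_batches
    perplexity = math.exp(eval_loss)
except OverflowError:
    # exp() of a very large loss overflows a float; report infinity instead.
    perplexity = float("inf")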