feat: minor improvements to the fine-tune script

11b 2022-12-18 17:25:36 -03:00
parent 925f5767ec
commit b0d2d80ac3
1 changed file with 11 additions and 3 deletions


@@ -22,6 +22,7 @@ https://huggingface.co/models?filter=text-generation
 """
 # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
+import datetime
 import math
 import os
 import time
@@ -404,6 +405,8 @@ def main():
     logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
     if args.resume_from_checkpoint is not None:
+        # FIXME(11b): Implement this properly. Need to save/restore all the other
+        # state as well (optimizer, LR scheduler, dataloader position via step counter...)
         logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}", ranks=[0])
         colossalai.utils.load_checkpoint(args.resume_from_checkpoint, model)
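
For reference, a minimal sketch of the full save/restore the FIXME above is asking for, using plain torch.save/torch.load on state dicts rather than the colossalai.utils checkpoint helpers; the save_training_state/load_training_state names and the lr_scheduler/completed_steps arguments are assumptions borrowed from typical training loops, not part of this commit:

```python
import torch

def save_training_state(path, model, optimizer, lr_scheduler, completed_steps):
    # Bundle everything needed to resume mid-run, not just the model weights.
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "completed_steps": completed_steps,
    }, path)

def load_training_state(path, model, optimizer, lr_scheduler):
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    lr_scheduler.load_state_dict(state["lr_scheduler"])
    # The caller can fast-forward the dataloader by this many steps to
    # restore its position.
    return state["completed_steps"]
```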
@@ -538,8 +541,11 @@ def main():
     logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
     logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
+    now = datetime.datetime.now()
+    run_name = now.strftime("%Y-%m-%dT_%H-%M-%S%z")
+    writer = torch.utils.tensorboard.SummaryWriter(log_dir=f"{args.output_dir}/runs/{run_name}", comment=args.comment)
     # Only show the progress bar once on each machine.
-    writer = torch.utils.tensorboard.SummaryWriter(log_dir=f"{args.output_dir}/runs", comment=args.comment)
     progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process)
     completed_steps = 0
     starting_epoch = 0
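
The per-run subdirectory is the substantive change here: TensorBoard treats each child directory of log_dir as a separate run, so stamping the directory with the start time keeps successive runs from being mixed into one event stream. A quick sketch of what the naming produces, assuming args.output_dir is "output":

```python
import datetime

now = datetime.datetime.now()
# Note: %z expands to an empty string for the naive datetime returned by
# datetime.now(), so the run name carries no timezone suffix in practice.
run_name = now.strftime("%Y-%m-%dT_%H-%M-%S%z")
print(f"output/runs/{run_name}")  # e.g. output/runs/2022-12-18T_17-25-36
```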
@@ -603,8 +609,10 @@ def main():
             perplexity = float("inf")
         logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0])
+        writer.add_scalar("Eval/Loss (Epoch)", eval_loss, epoch)
+        writer.add_scalar("Eval/Perplexity (Epoch)", perplexity, epoch)
         # TODO(11b): This messes up the intra-epoch graphs. Apparently I need to
         # read up on the Tensorboard docs to do this properly. Ignoring for now.
         # writer.add_scalar("Eval/Loss (Global Step)", eval_loss, completed_steps)
         # writer.add_scalar("Eval/Perplexity (Global Step)", perplexity, completed_steps)
         if args.output_dir is not None and args.checkpointing_steps == "epoch":
             checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
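
One common way to resolve the TODO above (an assumption about a possible fix, not something this commit does) is to log the eval metrics against the global step, so every scalar shares one x-axis with the intra-epoch training curves, and to emit the epoch number as its own scalar instead of using it as an index:

```python
# Sketch only: writer, eval_loss, perplexity, epoch and completed_steps are
# the names already in scope in the surrounding training loop.
writer.add_scalar("Eval/Loss", eval_loss, completed_steps)
writer.add_scalar("Eval/Perplexity", perplexity, completed_steps)
writer.add_scalar("Stats/Epoch", epoch, completed_steps)
```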