feat: minor improvements to the fine-tune script

11b 2022-12-18 17:25:36 -03:00
parent 925f5767ec
commit b0d2d80ac3
1 changed file with 11 additions and 3 deletions


@@ -22,6 +22,7 @@ https://huggingface.co/models?filter=text-generation
""" """
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
import datetime
import math import math
import os import os
import time import time
@@ -404,6 +405,8 @@ def main():
 logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
 if args.resume_from_checkpoint is not None:
+    # FIXME(11b): Implement this properly. Need to save/restore all the other
+    # state as well (optimizer, LR scheduler, dataloader position via step counter...)
     logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}", ranks=[0])
     colossalai.utils.load_checkpoint(args.resume_from_checkpoint, model)
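
The FIXME above is about resuming more than weights: the script currently restores only the model via colossalai.utils.load_checkpoint. A minimal sketch of what a fuller save/restore round trip could look like with plain PyTorch state dicts (the helper names and single-file layout are hypothetical; ColossalAI's sharded optimizer may need its own checkpoint utilities instead):

import torch

def save_training_state(path, model, optimizer, lr_scheduler, completed_steps, epoch):
    # Bundle everything needed to resume mid-run, not just the weights.
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "completed_steps": completed_steps,
        "epoch": epoch,
    }, path)

def load_training_state(path, model, optimizer, lr_scheduler):
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    lr_scheduler.load_state_dict(state["lr_scheduler"])
    # Caller can skip ahead in the dataloader using the returned step counter.
    return state["completed_steps"], state["epoch"]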
@@ -538,8 +541,11 @@ def main():
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0]) logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0]) logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
now = datetime.datetime.now()
run_name = now.strftime("%Y-%m-%dT_%H-%M-%S%z")
writer = torch.utils.tensorboard.SummaryWriter(log_dir=f"{args.output_dir}/runs/{run_name}", comment=args.comment)
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
writer = torch.utils.tensorboard.SummaryWriter(log_dir=f"{args.output_dir}/runs", comment=args.comment)
progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0 starting_epoch = 0
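
The new writer setup gives every training run its own timestamped subdirectory under {args.output_dir}/runs, so TensorBoard lists runs separately instead of appending events to one shared folder. A standalone sketch of the same naming scheme (output_dir is a stand-in value); note that %z expands to an empty string for the naive datetime returned by datetime.datetime.now(), so the timezone suffix is effectively a no-op unless an aware now(tz=...) is used:

import datetime
from torch.utils.tensorboard import SummaryWriter

output_dir = "./output"  # stand-in for args.output_dir
run_name = datetime.datetime.now().strftime("%Y-%m-%dT_%H-%M-%S%z")
writer = SummaryWriter(log_dir=f"{output_dir}/runs/{run_name}")
writer.add_scalar("Train/Loss", 0.0, 0)  # dummy scalar just to create the event file
writer.close()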
@@ -603,8 +609,10 @@ def main():
 perplexity = float("inf")
 logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0])
-writer.add_scalar("Eval/Loss (Epoch)", eval_loss, epoch)
-writer.add_scalar("Eval/Perplexity (Epoch)", perplexity, epoch)
+# TODO(11b): This messes up the intra-epoch graphs. Apparently I need to
+# read up on the Tensorboard docs to do this properly. Ignoring for now.
+# writer.add_scalar("Eval/Loss (Global Step)", eval_loss, completed_steps)
+# writer.add_scalar("Eval/Perplexity (Global Step)", perplexity, completed_steps)
 if args.output_dir is not None and args.checkpointing_steps == "epoch":
     checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
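
On the TODO about mixing epoch-indexed and step-indexed scalars: each tag gets its own chart in TensorBoard, so one common approach is to keep the per-epoch and per-step series under distinct tags rather than reusing one tag with two different x-axes. A self-contained sketch with made-up values:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="./runs/example")
steps_per_epoch = 25
for step in range(100):
    # Per-step series: x-axis is the global step.
    writer.add_scalar("Train/Loss (Global Step)", 1.0 / (step + 1), step)
    if (step + 1) % steps_per_epoch == 0:
        epoch = (step + 1) // steps_per_epoch
        # Per-epoch series live under their own tags, so the per-step
        # charts are left untouched.
        writer.add_scalar("Eval/Loss (Epoch)", 1.0 / epoch, epoch)
        writer.add_scalar("Eval/Loss (Global Step)", 1.0 / epoch, step)
writer.close()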