feat: minor improvements to the fine-tune script
parent 925f5767ec
commit b0d2d80ac3
@@ -22,6 +22,7 @@ https://huggingface.co/models?filter=text-generation
 """
 # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
 
+import datetime
 import math
 import os
 import time
@@ -404,6 +405,8 @@ def main():
     logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
 
     if args.resume_from_checkpoint is not None:
+        # FIXME(11b): Implement this properly. Need to save/restore all the other
+        # state as well (optimizer, LR scheduler, dataloader position via step counter...)
         logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}", ranks=[0])
         colossalai.utils.load_checkpoint(args.resume_from_checkpoint, model)
 
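The FIXME above notes that the `colossalai.utils.load_checkpoint(...)` call as written only restores model weights. As a rough illustration of the extra state a full resume would need to round-trip, here is a minimal sketch using plain `torch.save`/`torch.load`; the helper names and the `completed_steps` counter are assumptions for illustration, not part of this commit, and the actual script would presumably route this through ColossalAI's own checkpoint utilities.

import torch

def save_training_state(path, model, optimizer, lr_scheduler, epoch, completed_steps):
    # Everything needed to resume mid-run, not just the model weights.
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "epoch": epoch,
        "completed_steps": completed_steps,  # lets the training loop fast-forward on resume
    }, path)

def load_training_state(path, model, optimizer, lr_scheduler):
    # Returns (epoch, completed_steps) so the caller can skip already-seen batches.
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    lr_scheduler.load_state_dict(state["lr_scheduler"])
    return state["epoch"], state["completed_steps"]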
@@ -538,8 +541,11 @@ def main():
     logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
     logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
 
+    now = datetime.datetime.now()
+    run_name = now.strftime("%Y-%m-%dT_%H-%M-%S%z")
+    writer = torch.utils.tensorboard.SummaryWriter(log_dir=f"{args.output_dir}/runs/{run_name}", comment=args.comment)
+
     # Only show the progress bar once on each machine.
-    writer = torch.utils.tensorboard.SummaryWriter(log_dir=f"{args.output_dir}/runs", comment=args.comment)
     progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process)
     completed_steps = 0
     starting_epoch = 0
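The hunk above replaces the shared `runs/` log directory with a per-run, timestamp-named directory, so each training run shows up in TensorBoard as its own entry instead of appending events into one folder. A standalone sketch of the same pattern, assuming a placeholder output path and a dummy scalar rather than the script's actual arguments:

import datetime
from torch.utils.tensorboard import SummaryWriter

output_dir = "./output"  # placeholder standing in for args.output_dir
run_name = datetime.datetime.now().strftime("%Y-%m-%dT_%H-%M-%S%z")

# Each invocation gets its own subdirectory under runs/, keyed by start time.
writer = SummaryWriter(log_dir=f"{output_dir}/runs/{run_name}")
writer.add_scalar("Train/Loss", 1.23, global_step=0)  # dummy value
writer.close()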
@@ -603,8 +609,10 @@ def main():
             perplexity = float("inf")
 
         logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0])
-        writer.add_scalar("Eval/Loss (Epoch)", eval_loss, epoch)
-        writer.add_scalar("Eval/Perplexity (Epoch)", perplexity, epoch)
+        # TODO(11b): This messes up the intra-epoch graphs. Apparently I need to
+        # read up on the Tensorboard docs to do this properly. Ignoring for now.
+        # writer.add_scalar("Eval/Loss (Global Step)", eval_loss, completed_steps)
+        # writer.add_scalar("Eval/Perplexity (Global Step)", perplexity, completed_steps)
 
         if args.output_dir is not None and args.checkpointing_steps == "epoch":
             checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
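On the TODO above: one common way to keep intra-epoch training curves and end-of-epoch eval points from fighting over the x-axis is to log every scalar against a single global step counter, which is what the commented-out "Global Step" lines gesture at. A self-contained sketch of that convention with dummy values; the tag names and the `completed_steps` variable mirror the commented-out lines, and nothing here is actually enabled by this commit:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="./output/runs/example")  # placeholder path
completed_steps = 0

for step in range(100):                       # stand-in for the training loop
    train_loss = 1.0 / (step + 1)             # dummy value
    writer.add_scalar("Train/Loss", train_loss, completed_steps)
    completed_steps += 1

    if completed_steps % 50 == 0:             # stand-in for an eval pass
        eval_loss, perplexity = 0.5, 1.65     # dummy values
        # Same x-axis as the training scalars, so eval points line up in TensorBoard.
        writer.add_scalar("Eval/Loss", eval_loss, completed_steps)
        writer.add_scalar("Eval/Perplexity", perplexity, completed_steps)

writer.close()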