diff --git a/training/colossalai/run_clm.py b/training/colossalai/run_clm.py
index f633974..4c76c4b 100644
--- a/training/colossalai/run_clm.py
+++ b/training/colossalai/run_clm.py
@@ -22,6 +22,7 @@ https://huggingface.co/models?filter=text-generation
 """
 # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
 
+import datetime
 import math
 import os
 import time
@@ -404,6 +405,8 @@ def main():
     logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
 
     if args.resume_from_checkpoint is not None:
+        # FIXME(11b): Implement this properly. Need to save/restore all the other
+        # state as well (optimizer, LR scheduler, dataloader position via step counter...)
         logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}", ranks=[0])
         colossalai.utils.load_checkpoint(args.resume_from_checkpoint, model)
 
@@ -538,8 +541,11 @@ def main():
     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
     logger.info(f"  Total optimization steps = {args.max_train_steps}", ranks=[0])
 
+    now = datetime.datetime.now()
+    run_name = now.strftime("%Y-%m-%dT_%H-%M-%S%z")
+    writer = torch.utils.tensorboard.SummaryWriter(log_dir=f"{args.output_dir}/runs/{run_name}", comment=args.comment)
+
     # Only show the progress bar once on each machine.
-    writer = torch.utils.tensorboard.SummaryWriter(log_dir=f"{args.output_dir}/runs", comment=args.comment)
     progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process)
     completed_steps = 0
     starting_epoch = 0
@@ -603,8 +609,10 @@ def main():
             perplexity = float("inf")
 
         logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0])
-        writer.add_scalar("Eval/Loss (Epoch)", eval_loss, epoch)
-        writer.add_scalar("Eval/Perplexity (Epoch)", perplexity, epoch)
+        # TODO(11b): This messes up the intra-epoch graphs. Apparently I need to
+        # read up on the Tensorboard docs to do this properly. Ignoring for now.
+        # writer.add_scalar("Eval/Loss (Global Step)", eval_loss, completed_steps)
+        # writer.add_scalar("Eval/Perplexity (Global Step)", perplexity, completed_steps)
 
         if args.output_dir is not None and args.checkpointing_steps == "epoch":
             checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
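
For context on the FIXME above: a full resume needs to persist the optimizer, LR scheduler, and step counters alongside the model weights. Below is a minimal single-process sketch using plain PyTorch state dicts; the helper names save_training_state/load_training_state are hypothetical and not part of this patch, and a ZeRO or tensor-parallel run would need ColossalAI's own checkpoint utilities rather than raw torch.save.

import torch


def save_training_state(path, model, optimizer, lr_scheduler, epoch, completed_steps):
    # Persist everything needed to resume mid-run, not just the model weights.
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "epoch": epoch,
            "completed_steps": completed_steps,
        },
        path,
    )


def load_training_state(path, model, optimizer, lr_scheduler):
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    lr_scheduler.load_state_dict(state["lr_scheduler"])
    # Returning the counters lets the caller fast-forward the dataloader,
    # e.g. by skipping completed_steps % steps_per_epoch batches.
    return state["epoch"], state["completed_steps"]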
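
The commented-out eval logging in the last hunk points at the usual fix for the TODO: key every scalar on a single x-axis (the global optimizer step) so eval points land on the same timeline as per-step training curves instead of resetting each epoch. A sketch of what re-enabling it might look like, assuming completed_steps is incremented once per optimizer step as in the loop above:

# Log eval metrics against the global step so they share an x-axis with the
# per-step training scalars; dropping "(Epoch)" from the tags keeps the two
# series from being mixed into one chart.
writer.add_scalar("Eval/Loss", eval_loss, global_step=completed_steps)
writer.add_scalar("Eval/Perplexity", perplexity, global_step=completed_steps)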