From aebd405bbd8728e12cb8e16eabc0a8dcbfc42a63 Mon Sep 17 00:00:00 2001 From: 0x000011b <0x000011b@waifu.club> Date: Tue, 27 Dec 2022 13:21:20 -0300 Subject: [PATCH] feat: proper checkpoint resume in CLM fine-tune script --- training/colossalai/run_clm.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/training/colossalai/run_clm.py b/training/colossalai/run_clm.py index 609cccc..735ed87 100644 --- a/training/colossalai/run_clm.py +++ b/training/colossalai/run_clm.py @@ -25,6 +25,7 @@ https://huggingface.co/models?filter=text-generation import datetime import math import os +import re import signal import time from itertools import chain @@ -219,7 +220,7 @@ def parse_args(): default=None, help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", ) - parser.add_argument( + parser.add_argument("-r", "--resume_from_checkpoint", type=str, default=None, @@ -563,9 +564,9 @@ def main(): starting_epoch = 0 global_step = 0 - # FIXME(11b): One needs to manually update this when resuming from a - # checkpoint. Not ideal. step_from_checkpoint = 0 + if args.resume_from_checkpoint is not None: + step_from_checkpoint = int(re.findall(r"epoch_\d+_step_(\d+).pt", args.resume_from_checkpoint)[0]) for epoch in range(starting_epoch, args.num_train_epochs): @@ -577,6 +578,13 @@ def main(): if step < step_from_checkpoint: completed_steps += 1 global_step += 1 + progress_bar.update(1) + progress_bar.refresh() + + # Apparently ColossalAI's checkpoint utilities don't work + # correctly for saving/restoring the LR scheduler? So we "step" it + # manually here. + lr_scheduler.step() continue batch = {k: v.cuda() for k, v in batch.items()}