feat: proper checkpoint resume in CLM fine-tune script

11b 2022-12-27 13:21:20 -03:00
parent e99277ec52
commit aebd405bbd
1 changed file with 11 additions and 3 deletions

@@ -25,6 +25,7 @@ https://huggingface.co/models?filter=text-generation
 import datetime
 import math
 import os
+import re
 import signal
 import time
 from itertools import chain
@@ -219,7 +220,7 @@ def parse_args():
         default=None,
         help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
     )
-    parser.add_argument(
+    parser.add_argument("-r",
         "--resume_from_checkpoint",
         type=str,
         default=None,
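
The new "-r" short flag is an alias for --resume_from_checkpoint. A minimal, runnable sketch of the resulting argparse behavior (the checkpoint path is a hypothetical example; only the flag names come from the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-r", "--resume_from_checkpoint", type=str, default=None)

# Hypothetical invocation: python train.py -r checkpoints/epoch_1_step_500.pt
args = parser.parse_args(["-r", "checkpoints/epoch_1_step_500.pt"])
assert args.resume_from_checkpoint == "checkpoints/epoch_1_step_500.pt"
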
@@ -563,9 +564,9 @@ def main():
     starting_epoch = 0
     global_step = 0
-    # FIXME(11b): One needs to manually update this when resuming from a
-    # checkpoint. Not ideal.
     step_from_checkpoint = 0
+    if args.resume_from_checkpoint is not None:
+        step_from_checkpoint = int(re.findall(r"epoch_\d+_step_(\d+).pt", args.resume_from_checkpoint)[0])
     for epoch in range(starting_epoch, args.num_train_epochs):
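
The step number is now recovered from the checkpoint filename itself, so the old FIXME (manually editing the step counter before resuming) goes away. A small sketch of the parsing, assuming the epoch_<N>_step_<N>.pt naming scheme the regex implies; the concrete path is made up:

import re

path = "checkpoints/epoch_2_step_1500.pt"  # hypothetical example path
step_from_checkpoint = int(re.findall(r"epoch_\d+_step_(\d+).pt", path)[0])
assert step_from_checkpoint == 1500

Note the dot before "pt" is unescaped, so it matches any character; r"epoch_\d+_step_(\d+)\.pt" would be stricter, though for these filenames the difference is harmless.
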
@@ -577,6 +578,13 @@ def main():
             if step < step_from_checkpoint:
                 completed_steps += 1
+                global_step += 1
+                progress_bar.update(1)
+                progress_bar.refresh()
+                # Apparently ColossalAI's checkpoint utilities don't
+                # correctly save/restore the LR scheduler state, so we
+                # "step" it manually here.
+                lr_scheduler.step()
                 continue
             batch = {k: v.cuda() for k, v in batch.items()}
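
Taken together, the resume path fast-forwards through already-completed steps: it replays the bookkeeping (step counters, progress bar, LR scheduler) without running forward/backward passes, which keeps the learning rate where it was at save time even though the scheduler state itself isn't restored from the checkpoint. A self-contained sketch of that pattern, with hypothetical stand-ins for the dataloader and scheduler:

class FakeScheduler:
    # Stand-in for the real LR scheduler; decays the LR once per step.
    def __init__(self):
        self.lr = 1.0

    def step(self):
        self.lr *= 0.9

step_from_checkpoint = 3
completed_steps = 0
global_step = 0
lr_scheduler = FakeScheduler()

for step, batch in enumerate(range(5)):  # pretend dataloader with 5 batches
    if step < step_from_checkpoint:
        # Replay bookkeeping only; no forward/backward pass happens here.
        completed_steps += 1
        global_step += 1
        lr_scheduler.step()
        continue
    # ...a normal training step would run here...

assert completed_steps == global_step == 3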