feat: proper checkpoint resume in CLM fine-tune script

This commit is contained in:
11b 2022-12-27 13:21:20 -03:00
parent e99277ec52
commit aebd405bbd
1 changed file with 11 additions and 3 deletions

View File

@ -25,6 +25,7 @@ https://huggingface.co/models?filter=text-generation
import datetime import datetime
import math import math
import os import os
import re
import signal import signal
import time import time
from itertools import chain from itertools import chain
@ -219,7 +220,7 @@ def parse_args():
default=None, default=None,
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
) )
parser.add_argument( parser.add_argument("-r",
"--resume_from_checkpoint", "--resume_from_checkpoint",
type=str, type=str,
default=None, default=None,
@ -563,9 +564,9 @@ def main():
starting_epoch = 0 starting_epoch = 0
global_step = 0 global_step = 0
# FIXME(11b): One needs to manually update this when resuming from a
# checkpoint. Not ideal.
step_from_checkpoint = 0 step_from_checkpoint = 0
if args.resume_from_checkpoint is not None:
step_from_checkpoint = int(re.findall(r"epoch_\d+_step_(\d+).pt", args.resume_from_checkpoint)[0])
for epoch in range(starting_epoch, args.num_train_epochs): for epoch in range(starting_epoch, args.num_train_epochs):
@ -577,6 +578,13 @@ def main():
if step < step_from_checkpoint: if step < step_from_checkpoint:
completed_steps += 1 completed_steps += 1
global_step += 1 global_step += 1
progress_bar.update(1)
progress_bar.refresh()
# ColossalAI's checkpoint utilities apparently do not save/restore the
# LR scheduler state correctly, so we step the scheduler manually here,
# once for every batch skipped while fast-forwarding to the checkpoint.
lr_scheduler.step()
continue continue
batch = {k: v.cuda() for k, v in batch.items()} batch = {k: v.cuda() for k, v in batch.items()}