feat: proper checkpoint resume in CLM fine-tune script
parent e99277ec52
commit aebd405bbd
@@ -25,6 +25,7 @@ https://huggingface.co/models?filter=text-generation
 import datetime
 import math
 import os
+import re
 import signal
 import time
 from itertools import chain
@@ -219,7 +220,7 @@ def parse_args():
         default=None,
         help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
     )
-    parser.add_argument(
+    parser.add_argument("-r",
         "--resume_from_checkpoint",
         type=str,
         default=None,
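Note: a minimal sketch of how the new "-r" shorthand behaves, assuming a standard argparse setup; only the flag names come from the diff, the checkpoint path is illustrative:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-r", "--resume_from_checkpoint", type=str, default=None)

# The short flag maps to the same destination as the long one.
args = parser.parse_args(["-r", "checkpoints/epoch_1_step_500.pt"])
print(args.resume_from_checkpoint)  # checkpoints/epoch_1_step_500.pt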
@@ -563,9 +564,9 @@ def main():
     starting_epoch = 0
     global_step = 0
 
-    # FIXME(11b): One needs to manually update this when resuming from a
-    # checkpoint. Not ideal.
     step_from_checkpoint = 0
+    if args.resume_from_checkpoint is not None:
+        step_from_checkpoint = int(re.findall(r"epoch_\d+_step_(\d+).pt", args.resume_from_checkpoint)[0])
 
     for epoch in range(starting_epoch, args.num_train_epochs):
 
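Note: a quick illustration of the filename parsing added above; the checkpoint path is made up but follows the epoch_<n>_step_<n>.pt naming the regex expects:

import re

# Illustrative path matching the script's checkpoint naming scheme.
path = "checkpoints/epoch_2_step_1500.pt"
step = int(re.findall(r"epoch_\d+_step_(\d+).pt", path)[0])
print(step)  # 1500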
@@ -577,6 +578,13 @@ def main():
             if step < step_from_checkpoint:
                 completed_steps += 1
                 global_step += 1
+                progress_bar.update(1)
+                progress_bar.refresh()
+
+                # Apparently ColossalAI's checkpoint utilities don't work
+                # correctly for saving/restoring the LR scheduler, so we
+                # "step" it manually here.
+                lr_scheduler.step()
                 continue
 
             batch = {k: v.cuda() for k, v in batch.items()}
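Note: in isolation, the skip-and-fast-forward pattern looks like this; a minimal, self-contained sketch with a dummy scheduler standing in for the real dataloader and LR scheduler, all names and numbers illustrative:

class DummyScheduler:
    def __init__(self):
        self.steps_taken = 0

    def step(self):
        self.steps_taken += 1

step_from_checkpoint = 3
completed_steps = 0
lr_scheduler = DummyScheduler()

for step in range(10):  # stand-in for enumerate(train_dataloader)
    if step < step_from_checkpoint:
        # Replay only the bookkeeping for steps the checkpoint already
        # covers, including the scheduler, whose state isn't restored.
        completed_steps += 1
        lr_scheduler.step()
        continue
    # ... the real forward/backward pass would run here ...
    completed_steps += 1
    lr_scheduler.step()

assert completed_steps == 10 and lr_scheduler.steps_taken == 10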