feat: proper checkpoint resume in CLM fine-tune script
This commit is contained in:
parent
e99277ec52
commit
aebd405bbd
|
@ -25,6 +25,7 @@ https://huggingface.co/models?filter=text-generation
|
||||||
import datetime
|
import datetime
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import signal
|
import signal
|
||||||
import time
|
import time
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
@ -219,7 +220,7 @@ def parse_args():
|
||||||
default=None,
|
default=None,
|
||||||
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
|
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("-r",
|
||||||
"--resume_from_checkpoint",
|
"--resume_from_checkpoint",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
|
@ -563,9 +564,9 @@ def main():
|
||||||
starting_epoch = 0
|
starting_epoch = 0
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
|
||||||
# FIXME(11b): One needs to manually update this when resuming from a
|
|
||||||
# checkpoint. Not ideal.
|
|
||||||
step_from_checkpoint = 0
|
step_from_checkpoint = 0
|
||||||
|
if args.resume_from_checkpoint is not None:
|
||||||
|
step_from_checkpoint = int(re.findall(r"epoch_\d+_step_(\d+).pt", args.resume_from_checkpoint)[0])
|
||||||
|
|
||||||
for epoch in range(starting_epoch, args.num_train_epochs):
|
for epoch in range(starting_epoch, args.num_train_epochs):
|
||||||
|
|
||||||
|
@ -577,6 +578,13 @@ def main():
|
||||||
if step < step_from_checkpoint:
|
if step < step_from_checkpoint:
|
||||||
completed_steps += 1
|
completed_steps += 1
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
progress_bar.update(1)
|
||||||
|
progress_bar.refresh()
|
||||||
|
|
||||||
|
# Apparently ColossalAI's checkpoint utilities don't work
|
||||||
|
# correctly for saving/restoring the LR scheduler? So we "step" it
|
||||||
|
# manually here.
|
||||||
|
lr_scheduler.step()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
batch = {k: v.cuda() for k, v in batch.items()}
|
batch = {k: v.cuda() for k, v in batch.items()}
|
||||||
|
|
Loading…
Reference in New Issue