feat: add support for fine-tuning GPT-NeoX-based models, save optimizer and LR scheduler to checkpoint

Author: 11b
Date:   2022-12-25 15:42:59 -03:00
Parent: 186df60691
Commit: 4f794489ac
1 changed file with 44 additions and 10 deletions
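The substance of the change: training checkpoints now carry optimizer and LR scheduler state alongside the model weights, and resuming restores all three in one call. Below is a minimal sketch of that round trip, using the call signatures visible in the diff; the stand-in objects are illustrative, and a real run assumes the ColossalAI runtime has already been initialized by the script (not shown in this diff).

    import torch
    import colossalai.utils

    # Stand-ins for the ColossalAI-wrapped model, optimizer and LR scheduler
    # that the script builds inside main().
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

    checkpoint_path = "output/epoch_0_step_500.pt"

    # Save: epoch number, weights, optimizer state and scheduler state in one file.
    colossalai.utils.save_checkpoint(checkpoint_path, 0, model, optimizer, lr_scheduler)

    # Resume: the same three objects are restored from that file.
    colossalai.utils.load_checkpoint(checkpoint_path, model, optimizer, lr_scheduler)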


@@ -25,6 +25,7 @@ https://huggingface.co/models?filter=text-generation
import datetime
import math
import os
import signal
import time
from itertools import chain
@@ -68,8 +69,9 @@ from transformers.utils.versions import require_version
# subclass - even if we don't use it - so ColossalAI properly patches the inner
# modules.
from transformers import (
OPTForCausalLM,
BloomForCausalLM,
OPTForCausalLM,
GPTNeoXForCausalLM,
)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
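The comment above this import block explains why these classes are imported even though the script never references them directly: ColossalAI's patching needs to see their inner modules. For context, a minimal sketch of how a GPT-NeoX-style config resolves to the newly imported class (the tiny config is only so the example instantiates quickly; the script's own model construction lies outside this hunk):

    from transformers import AutoModelForCausalLM, GPTNeoXConfig, GPTNeoXForCausalLM

    # A deliberately tiny GPT-NeoX config; a real run would load the config of an
    # actual GPT-NeoX-based checkpoint instead.
    config = GPTNeoXConfig(hidden_size=64, num_hidden_layers=2,
                           num_attention_heads=4, intermediate_size=128)
    model = AutoModelForCausalLM.from_config(config)  # resolves to GPTNeoXForCausalLM
    assert isinstance(model, GPTNeoXForCausalLM)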
@@ -414,12 +416,6 @@ def main():
logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
if args.resume_from_checkpoint is not None:
# FIXME(11b): Implement this properly. Need to save/restore all the other
# state as well (optimizer, LR scheduler, dataloader position via step counter...)
logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}", ranks=[0])
colossalai.utils.load_checkpoint(args.resume_from_checkpoint, model)
# Preprocessing the datasets.
# First we tokenize all the texts.
column_names = raw_datasets["train"].column_names
@@ -533,6 +529,12 @@ def main():
num_training_steps=args.max_train_steps,
)
if args.resume_from_checkpoint is not None:
# FIXME(11b): Implement this properly. Need to save/restore all the other
# state as well (optimizer, LR scheduler, dataloader position via step counter...)
logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}", ranks=[0])
colossalai.utils.load_checkpoint(args.resume_from_checkpoint, model, optimizer, lr_scheduler)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
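To make the recalculation above concrete (the numbers are invented, and the truncated overrode_max_train_steps branch is assumed to follow the usual pattern of re-deriving the step budget from the epoch count):

    import math

    # Hypothetical sizes: the resumed dataloader yields 10_000 batches per epoch,
    # gradients are accumulated over 4 micro-batches, and training runs 3 epochs.
    len_train_dataloader = 10_000
    gradient_accumulation_steps = 4
    num_train_epochs = 3

    num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)  # 2500
    max_train_steps = num_train_epochs * num_update_steps_per_epoch                             # 7500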
@@ -561,6 +563,10 @@ def main():
starting_epoch = 0
global_step = 0
# FIXME(11b): One needs to manually update this when resuming from a
# checkpoint. Not ideal.
step_from_checkpoint = 0
for epoch in range(starting_epoch, args.num_train_epochs):
if completed_steps >= args.max_train_steps:
@@ -568,6 +574,11 @@ def main():
model.train()
for step, batch in enumerate(train_dataloader):
if step < step_from_checkpoint:
completed_steps += 1
global_step += 1
continue
batch = {k: v.cuda() for k, v in batch.items()}
outputs = model(use_cache=False, **batch) # Caching is incompatible with gradient checkpointing.
loss = outputs['loss']
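The FIXME above leaves step_from_checkpoint to be set by hand before resuming. One possible follow-up, not part of this commit, would be to recover it from the checkpoint filename, which already encodes the step counter (epoch_{epoch}_step_{completed_steps}.pt); a hypothetical helper:

    import re

    def step_from_checkpoint_path(path: str) -> int:
        """Recover the step counter from a name like 'output/epoch_3_step_12000.pt'."""
        match = re.search(r"step_(\d+)\.pt$", path)
        return int(match.group(1)) if match else 0

    assert step_from_checkpoint_path("output/epoch_3_step_12000.pt") == 12000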
@@ -594,12 +605,35 @@ def main():
if args.checkpointing_steps != "epoch" and completed_steps % int(args.checkpointing_steps) == 0:
checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
logger.info(f" Saving iter checkpoint...", ranks=[0])
save_checkpoint(checkpoint_path, epoch, model)
save_checkpoint(checkpoint_path, epoch, model, optimizer, lr_scheduler)
logger.info(f" Saved checkpoint to {checkpoint_path}!", ranks=[0])
if completed_steps % (int(args.checkpointing_steps) * 8) == 0:
# Evaluate every X checkpoints.
model.eval()
losses = []
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
batch = {k: v.cuda() for k, v in batch.items()}
outputs = model(**batch)
loss = outputs['loss'].unsqueeze(0)
losses.append(loss)
losses = torch.cat(losses)
losses = losses[:len(eval_dataset)]
try:
eval_loss = torch.mean(losses)
perplexity = math.exp(eval_loss)
except OverflowError:
perplexity = float("inf")
logger.info(f"Step {global_step}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0])
model.train()
if completed_steps >= args.max_train_steps:
break
# Evaluate per epoch.
model.eval()
losses = []
for step, batch in enumerate(eval_dataloader):
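A toy illustration of the metric computed by the evaluation code added in this hunk: with --checkpointing_steps 500, a checkpoint is written every 500 completed steps and the in-training evaluation fires on every 8th checkpoint, i.e. every 4000 steps, reporting perplexity as the exponential of the mean eval loss (the values below are made up):

    import math
    import torch

    losses = torch.tensor([2.31, 2.27, 2.40])    # per-batch eval losses (invented)
    eval_loss = torch.mean(losses).item()        # ~2.33
    perplexity = math.exp(eval_loss)             # ~10.2
    print(f"Step 4000: perplexity: {perplexity:.1f} eval_loss: {eval_loss:.2f}")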
@@ -627,13 +661,13 @@ def main():
if args.output_dir is not None and args.checkpointing_steps == "epoch":
checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
logger.info(f" Saving epoch checkpoint...", ranks=[0])
save_checkpoint(checkpoint_path, epoch, model)
save_checkpoint(checkpoint_path, epoch, model, optimizer, lr_scheduler)
logger.info(f" Saved checkpoint to {checkpoint_path}!", ranks=[0])
if args.output_dir is not None:
checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
logger.info(f" Saving final checkpoint...", ranks=[0])
save_checkpoint(checkpoint_path, epoch, model)
save_checkpoint(checkpoint_path, epoch, model, optimizer, lr_scheduler)
logger.info(f" Saved checkpoint to {checkpoint_path}!", ranks=[0])
logger.info("Training finished", ranks=[0])