feat: add support for fine-tuning GPT-NeoX-based models, save optimizer and LR scheduler to checkpoint
This commit is contained in:
parent 186df60691
commit 4f794489ac
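What the change buys: optimizer and LR scheduler state now travel inside the same checkpoint file as the model weights, so a resumed run keeps its optimizer moments and its position in the LR schedule instead of starting them cold. A plain-PyTorch sketch of the round-trip that the colossalai.utils save_checkpoint/load_checkpoint calls below perform (toy module; ColossalAI's distributed gather/shard handling omitted):

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# Toy stand-ins for the script's GPT-NeoX model, optimizer and scheduler.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=1000)

# Roughly what save_checkpoint(path, epoch, model, optimizer, lr_scheduler)
# writes: one .pt file bundling all three state dicts plus the epoch counter.
torch.save({
    'epoch': 0,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'lr_scheduler': lr_scheduler.state_dict(),
}, 'epoch_0_step_100.pt')

# The matching restore on resume: rebuild the objects, then load their state.
state = torch.load('epoch_0_step_100.pt')
model.load_state_dict(state['model'])
optimizer.load_state_dict(state['optimizer'])
lr_scheduler.load_state_dict(state['lr_scheduler'])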
@@ -25,6 +25,7 @@ https://huggingface.co/models?filter=text-generation
 import datetime
 import math
 import os
+import signal
 import time
 from itertools import chain
 
@@ -68,8 +69,9 @@ from transformers.utils.versions import require_version
 # subclass - even if we don't use it - so ColossalAI properly patches the inner
 # modules.
 from transformers import (
-    OPTForCausalLM,
     BloomForCausalLM,
+    OPTForCausalLM,
+    GPTNeoXForCausalLM,
 )
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -414,12 +416,6 @@ def main():
     logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
 
-    if args.resume_from_checkpoint is not None:
-        # FIXME(11b): Implement this properly. Need to save/restore all the other
-        # state as well (optimizer, LR scheduler, dataloader position via step counter...)
-        logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}", ranks=[0])
-        colossalai.utils.load_checkpoint(args.resume_from_checkpoint, model)
-
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     column_names = raw_datasets["train"].column_names
@@ -533,6 +529,12 @@ def main():
         num_training_steps=args.max_train_steps,
     )
 
+    if args.resume_from_checkpoint is not None:
+        # FIXME(11b): Implement this properly. Need to save/restore all the other
+        # state as well (optimizer, LR scheduler, dataloader position via step counter...)
+        logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}", ranks=[0])
+        colossalai.utils.load_checkpoint(args.resume_from_checkpoint, model, optimizer, lr_scheduler)
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
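The resume block did not just grow; it moved below the optimizer and scheduler construction, because load_checkpoint can only restore state into objects that already exist. A toy illustration of that ordering constraint (plain PyTorch, hypothetical build() helper):

import torch

def build():
    # Hypothetical stand-in for the script's model/optimizer construction.
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    return model, optimizer

model, optimizer = build()
saved = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}

# On resume: construct first, restore second. The restore cannot come first,
# which is why the load_checkpoint call now sits after the scheduler is built.
model, optimizer = build()
model.load_state_dict(saved['model'])
optimizer.load_state_dict(saved['optimizer'])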
@@ -561,6 +563,10 @@ def main():
     starting_epoch = 0
     global_step = 0
 
+    # FIXME(11b): One needs to manually update this when resuming from a
+    # checkpoint. Not ideal.
+    step_from_checkpoint = 0
+
     for epoch in range(starting_epoch, args.num_train_epochs):
         if completed_steps >= args.max_train_steps:
@@ -568,6 +574,11 @@ def main():
 
         model.train()
         for step, batch in enumerate(train_dataloader):
+            if step < step_from_checkpoint:
+                completed_steps += 1
+                global_step += 1
+                continue
+
             batch = {k: v.cuda() for k, v in batch.items()}
             outputs = model(use_cache=False, **batch)  # Caching is incompatible with gradient checkpointing.
             loss = outputs['loss']
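This skip-loop is the commit's stopgap for resume position, per the FIXME above: batches the interrupted run already consumed are fast-forwarded past without a forward/backward pass, while the counters still advance so checkpoint naming and logging stay aligned. A self-contained sketch of the pattern, ignoring gradient accumulation, with range() standing in for the DataLoader:

train_dataloader = range(1000)   # stand-in for the real DataLoader
step_from_checkpoint = 300       # hand-edited when resuming, per the FIXME
completed_steps = 0
global_step = 0

for step, batch in enumerate(train_dataloader):
    if step < step_from_checkpoint:
        # Already-seen batch: skip the work but keep the counters moving.
        completed_steps += 1
        global_step += 1
        continue
    # ...forward/backward and optimizer step would run here...
    completed_steps += 1
    global_step += 1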
@@ -594,12 +605,35 @@ def main():
             if args.checkpointing_steps != "epoch" and completed_steps % int(args.checkpointing_steps) == 0:
                 checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
                 logger.info(f" Saving iter checkpoint...", ranks=[0])
-                save_checkpoint(checkpoint_path, epoch, model)
+                save_checkpoint(checkpoint_path, epoch, model, optimizer, lr_scheduler)
                 logger.info(f" Saved checkpoint to {checkpoint_path}!", ranks=[0])
+
+                if completed_steps % (int(args.checkpointing_steps) * 8) == 0:
+                    # Evaluate every X checkpoints.
+                    model.eval()
+                    losses = []
+                    for step, batch in enumerate(eval_dataloader):
+                        with torch.no_grad():
+                            batch = {k: v.cuda() for k, v in batch.items()}
+                            outputs = model(**batch)
+
+                            loss = outputs['loss'].unsqueeze(0)
+                            losses.append(loss)
+
+                    losses = torch.cat(losses)
+                    losses = losses[:len(eval_dataset)]
+                    try:
+                        eval_loss = torch.mean(losses)
+                        perplexity = math.exp(eval_loss)
+                    except OverflowError:
+                        perplexity = float("inf")
+                    logger.info(f"Step {global_step}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0])
+                    model.train()
+
             if completed_steps >= args.max_train_steps:
                 break
 
         # Evaluate per epoch.
         model.eval()
         losses = []
         for step, batch in enumerate(eval_dataloader):
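On the eval math in the new block: each batch contributes its mean cross-entropy as one scalar (note the inner loop reuses the name step, shadowing the training step counter), the scalars are averaged, and perplexity is exp of that average. The OverflowError guard matters because math.exp raises rather than returning inf once its argument exceeds roughly 709.8. A worked miniature with hypothetical per-batch losses:

import math
import torch

losses = [torch.tensor(2.1).unsqueeze(0),  # hypothetical per-batch eval losses
          torch.tensor(2.3).unsqueeze(0),
          torch.tensor(2.2).unsqueeze(0)]
eval_loss = torch.mean(torch.cat(losses)).item()  # 2.2
try:
    perplexity = math.exp(eval_loss)              # exp(2.2) is about 9.03
except OverflowError:
    perplexity = float("inf")                     # math.exp(710) would raise
print(f"eval_loss={eval_loss:.2f} perplexity={perplexity:.2f}")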
@@ -627,13 +661,13 @@ def main():
         if args.output_dir is not None and args.checkpointing_steps == "epoch":
             checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
             logger.info(f" Saving epoch checkpoint...", ranks=[0])
-            save_checkpoint(checkpoint_path, epoch, model)
+            save_checkpoint(checkpoint_path, epoch, model, optimizer, lr_scheduler)
             logger.info(f" Saved checkpoint to {checkpoint_path}!", ranks=[0])
 
     if args.output_dir is not None:
         checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
         logger.info(f" Saving final checkpoint...", ranks=[0])
-        save_checkpoint(checkpoint_path, epoch, model)
+        save_checkpoint(checkpoint_path, epoch, model, optimizer, lr_scheduler)
         logger.info(f" Saved checkpoint to {checkpoint_path}!", ranks=[0])
 
     logger.info("Training finished", ranks=[0])