diff --git a/training/colossalai/run_clm.py b/training/colossalai/run_clm.py
index bfa774b..72f5c8a 100644
--- a/training/colossalai/run_clm.py
+++ b/training/colossalai/run_clm.py
@@ -25,6 +25,7 @@ https://huggingface.co/models?filter=text-generation
 import datetime
 import math
 import os
+import signal
 import time
 from itertools import chain
 
@@ -68,8 +69,9 @@ from transformers.utils.versions import require_version
 # subclass - even if we don't use it - so ColossalAI properly patches the inner
 # modules.
 from transformers import (
-    OPTForCausalLM,
     BloomForCausalLM,
+    OPTForCausalLM,
+    GPTNeoXForCausalLM,
 )
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -414,12 +416,6 @@ def main():
 
     logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
 
-    if args.resume_from_checkpoint is not None:
-        # FIXME(11b): Implement this properly. Need to save/restore all the other
-        # state as well (optimizer, LR scheduler, dataloader position via step counter...)
-        logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}", ranks=[0])
-        colossalai.utils.load_checkpoint(args.resume_from_checkpoint, model)
-
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     column_names = raw_datasets["train"].column_names
@@ -533,6 +529,12 @@ def main():
         num_training_steps=args.max_train_steps,
     )
 
+    if args.resume_from_checkpoint is not None:
+        # FIXME(11b): Implement this properly. Need to save/restore all the other
+        # state as well (optimizer, LR scheduler, dataloader position via step counter...)
+        logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}", ranks=[0])
+        colossalai.utils.load_checkpoint(args.resume_from_checkpoint, model, optimizer, lr_scheduler)
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
@@ -561,6 +563,10 @@ def main():
     starting_epoch = 0
     global_step = 0
 
+    # FIXME(11b): One needs to manually update this when resuming from a
+    # checkpoint. Not ideal.
+    step_from_checkpoint = 0
+
     for epoch in range(starting_epoch, args.num_train_epochs):
 
         if completed_steps >= args.max_train_steps:
@@ -568,6 +574,11 @@ def main():
 
         model.train()
         for step, batch in enumerate(train_dataloader):
+            if step < step_from_checkpoint:
+                completed_steps += 1
+                global_step += 1
+                continue
+
             batch = {k: v.cuda() for k, v in batch.items()}
             outputs = model(use_cache=False, **batch)  # Caching is incompatible with gradient checkpointing.
             loss = outputs['loss']
@@ -594,12 +605,35 @@ def main():
             if args.checkpointing_steps != "epoch" and completed_steps % int(args.checkpointing_steps) == 0:
                 checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
                 logger.info(f" Saving iter checkpoint...", ranks=[0])
-                save_checkpoint(checkpoint_path, epoch, model)
+                save_checkpoint(checkpoint_path, epoch, model, optimizer, lr_scheduler)
                 logger.info(f" Saved checkpoint to {checkpoint_path}!", ranks=[0])
 
+                if completed_steps % (int(args.checkpointing_steps) * 8) == 0:
+                    # Evaluate every X checkpoints.
+                    model.eval()
+                    losses = []
+                    for step, batch in enumerate(eval_dataloader):
+                        with torch.no_grad():
+                            batch = {k: v.cuda() for k, v in batch.items()}
+                            outputs = model(**batch)
+
+                        loss = outputs['loss'].unsqueeze(0)
+                        losses.append(loss)
+
+                    losses = torch.cat(losses)
+                    losses = losses[:len(eval_dataset)]
+                    try:
+                        eval_loss = torch.mean(losses)
+                        perplexity = math.exp(eval_loss)
+                    except OverflowError:
+                        perplexity = float("inf")
+                    logger.info(f"Step {global_step}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0])
+                    model.train()
+
             if completed_steps >= args.max_train_steps:
                 break
 
+        # Evaluate per epoch.
         model.eval()
         losses = []
         for step, batch in enumerate(eval_dataloader):
@@ -627,13 +661,13 @@ def main():
         if args.output_dir is not None and args.checkpointing_steps == "epoch":
             checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
             logger.info(f" Saving epoch checkpoint...", ranks=[0])
-            save_checkpoint(checkpoint_path, epoch, model)
+            save_checkpoint(checkpoint_path, epoch, model, optimizer, lr_scheduler)
             logger.info(f" Saved checkpoint to {checkpoint_path}!", ranks=[0])
 
     if args.output_dir is not None:
         checkpoint_path = f'{args.output_dir}/epoch_{epoch}_step_{completed_steps}.pt'
         logger.info(f" Saving final checkpoint...", ranks=[0])
-        save_checkpoint(checkpoint_path, epoch, model)
+        save_checkpoint(checkpoint_path, epoch, model, optimizer, lr_scheduler)
         logger.info(f" Saved checkpoint to {checkpoint_path}!", ranks=[0])
 
     logger.info("Training finished", ranks=[0])
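
Two FIXME comments in the patch note that `step_from_checkpoint` has to be edited by hand when resuming. Since the checkpoint filename already encodes both values (`epoch_{epoch}_step_{completed_steps}.pt`), a resume could recover them automatically instead. A minimal sketch under that assumption; `parse_checkpoint_name` is a hypothetical helper, not part of the patch, and with gradient accumulation the per-epoch batch index `step` and the optimizer-step counter `completed_steps` differ, so the skip loop would still need to account for that.

    import os
    import re

    # Matches the naming scheme used by the save_checkpoint() calls in this patch.
    _CHECKPOINT_RE = re.compile(r"epoch_(\d+)_step_(\d+)\.pt$")

    def parse_checkpoint_name(path):
        """Recover (epoch, completed_steps) from an 'epoch_{E}_step_{S}.pt' path."""
        match = _CHECKPOINT_RE.search(os.path.basename(path))
        if match is None:
            raise ValueError(f"Unrecognized checkpoint name: {path}")
        return int(match.group(1)), int(match.group(2))

    # At resume time, instead of hand-editing step_from_checkpoint = 0:
    # starting_epoch, step_from_checkpoint = parse_checkpoint_name(args.resume_from_checkpoint)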
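The evaluation loop is now duplicated: once every eight iteration checkpoints (the hard-coded `* 8` factor) and once per epoch. A sketch of how it could be factored into one helper, assuming it can reuse the script's existing `eval_dataloader`, `eval_dataset`, and ColossalAI `logger`; the `evaluate` name is hypothetical:

    import math

    import torch

    def evaluate(model, eval_dataloader, eval_dataset, logger, global_step):
        """Compute mean eval loss and perplexity, then restore train mode."""
        model.eval()
        losses = []
        for batch in eval_dataloader:
            with torch.no_grad():
                batch = {k: v.cuda() for k, v in batch.items()}
                outputs = model(**batch)
            losses.append(outputs['loss'].unsqueeze(0))

        losses = torch.cat(losses)[:len(eval_dataset)]
        eval_loss = torch.mean(losses)
        try:
            perplexity = math.exp(eval_loss)
        except OverflowError:
            # exp() overflows for very large losses early in training.
            perplexity = float("inf")

        logger.info(f"Step {global_step}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0])
        model.train()
        return eval_loss, perplexity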
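`import signal` is added at the top of the file, but none of the hunks shown here use it. Purely as a labeled assumption about intent (not something this patch shows), the usual reason to pull it into a training script is to flush one last checkpoint when the job is killed:

    import signal
    import sys

    def install_save_on_exit(save_fn):
        """Hypothetical helper: run save_fn() once on SIGTERM/SIGINT, then exit."""
        def _handler(signum, frame):
            save_fn()  # e.g. a closure around the script's save_checkpoint(...) call
            sys.exit(0)
        signal.signal(signal.SIGTERM, _handler)
        signal.signal(signal.SIGINT, _handler)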