From 30ff3751de4e9a2d0c582745c10ab61c78c090f6 Mon Sep 17 00:00:00 2001
From: 0x000011b <0x000011b@waifu.club>
Date: Sun, 18 Dec 2022 22:24:43 -0300
Subject: [PATCH] feat: allow fine-tuning of non-OPT LMs

---
 training/colossalai/run_clm.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/training/colossalai/run_clm.py b/training/colossalai/run_clm.py
index b2e1bfe..bfa774b 100644
--- a/training/colossalai/run_clm.py
+++ b/training/colossalai/run_clm.py
@@ -55,13 +55,23 @@ from transformers import (
     AutoConfig,
     AutoTokenizer,
     GPT2Tokenizer,
-    OPTForCausalLM,
+    AutoModelForCausalLM,
     SchedulerType,
     default_data_collator,
     get_scheduler,
 )
 from transformers.utils.versions import require_version
 
+# Explanation: AutoModelForCausalLM instantiates the proper subclass lazily,
+# after ColossalAI has already attempted its meta-programming trickery, so
+# instantiation crashes due to missing attributes. To work around that, we
+# import the subclasses explicitly - even though we don't use them directly -
+# so ColossalAI properly patches their inner modules.
+from transformers import (
+    OPTForCausalLM,
+    BloomForCausalLM,
+)
+
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
 
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
@@ -373,11 +383,11 @@ def main():
             # we can not import it until huggingface fix it
             logger.info("Train a new model from scratch", ranks=[0])
             with ColoInitContext(device=init_dev):
-                model = OPTForCausalLM(config)
+                model = AutoModelForCausalLM.from_config(config)
         else:
             logger.info("Finetune a pre-trained model", ranks=[0])
             with ColoInitContext(device=init_dev):
-                model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
+                model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
                                                        from_tf=bool(".ckpt" in args.model_name_or_path),
                                                        config=config,
                                                        local_files_only=False)
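
Note for reviewers (not part of the patch): AutoModelForCausalLM resolves the
concrete model class from config.model_type at call time, which is why the
explicit subclass imports above must execute before ColossalAI's patching runs.
A minimal sketch of that dispatch, assuming a stock Hugging Face transformers
install; the OPT checkpoint name is used purely for illustration:

    from transformers import AutoConfig, AutoModelForCausalLM

    # AutoModelForCausalLM is never instantiated directly; from_config() looks
    # up the concrete class registered for config.model_type and returns it,
    # so the resulting object is a plain OPTForCausalLM / BloomForCausalLM.
    config = AutoConfig.from_pretrained("facebook/opt-125m")
    model = AutoModelForCausalLM.from_config(config)
    print(type(model).__name__)  # prints "OPTForCausalLM"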