feat: allow fine-tuning of non-OPT LMs

11b 2022-12-18 22:24:43 -03:00
parent 25ae9da046
commit 30ff3751de
1 changed file with 13 additions and 3 deletions

@@ -55,13 +55,23 @@ from transformers import (
     AutoConfig,
     AutoTokenizer,
     GPT2Tokenizer,
-    OPTForCausalLM,
+    AutoModelForCausalLM,
     SchedulerType,
     default_data_collator,
     get_scheduler,
 )
 from transformers.utils.versions import require_version
+
+# Explanation: "AutoModelForCausalLM" only instantiates the proper subclass at
+# call time, after ColossalAI has already attempted its meta-programming
+# trickery, so it crashes due to missing attributes. To work around that, we
+# import each subclass - even if we don't use it directly - so ColossalAI
+# properly patches the inner modules.
+from transformers import (
+    OPTForCausalLM,
+    BloomForCausalLM,
+)
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
 
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
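
The workaround comment hinges on the auto class deferring its choice of concrete model class until call time. As a minimal sketch of that dispatch outside of ColossalAI (the checkpoint name is an illustrative assumption, not taken from the commit):

# Sketch: AutoModelForCausalLM resolves the concrete subclass from the config
# at call time, not at import time - hence the explicit subclass imports above.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("bigscience/bloom-560m")  # illustrative checkpoint
model = AutoModelForCausalLM.from_config(config)              # randomly initialized weights
print(type(model).__name__)  # -> BloomForCausalLM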
@@ -373,11 +383,11 @@ def main():
         # we cannot import it until huggingface fixes it
         logger.info("Train a new model from scratch", ranks=[0])
         with ColoInitContext(device=init_dev):
-            model = OPTForCausalLM(config)
+            model = AutoModelForCausalLM.from_config(config)
     else:
         logger.info("Finetune a pre-trained model", ranks=[0])
         with ColoInitContext(device=init_dev):
-            model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
+            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
                                                    from_tf=bool(".ckpt" in args.model_name_or_path),
                                                    config=config,
                                                    local_files_only=False)
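
Two notes on the branches above. The from-scratch branch has to go through from_config, since transformers' auto classes raise an error when constructed directly. And with the auto class in place, the fine-tune branch now accepts any causal-LM checkpoint; a hedged sketch of what the new from_pretrained call does for a non-OPT model (the checkpoint name stands in for args.model_name_or_path and is only an example; ColoInitContext is omitted):

# Sketch of the updated fine-tune path with a non-OPT checkpoint.
# "bigscience/bloom-560m" is an illustrative stand-in for args.model_name_or_path.
from transformers import AutoConfig, AutoModelForCausalLM

model_name_or_path = "bigscience/bloom-560m"
config = AutoConfig.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    from_tf=bool(".ckpt" in model_name_or_path),  # same TF-checkpoint guard as the script
    config=config,
    local_files_only=False,
)
# -> BloomForCausalLM, ready for fine-tuning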