feat: allow fine-tuning of non-OPT LMs
This commit is contained in:
parent
25ae9da046
commit
30ff3751de
|
@ -55,13 +55,23 @@ from transformers import (
|
|||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
GPT2Tokenizer,
|
||||
OPTForCausalLM,
|
||||
AutoModelForCausalLM,
|
||||
SchedulerType,
|
||||
default_data_collator,
|
||||
get_scheduler,
|
||||
)
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
# Explanation: "AutoModelForCausalLM" will instantiate the proper subclass after
|
||||
# ColossalAI has attempted to do a bunch of meta-programming trickery, so it
|
||||
# crashes due to missing attributes. To work around that, we need to import the
|
||||
# subclass - even if we don't use it - so ColossalAI properly patches the inner
|
||||
# modules.
|
||||
from transformers import (
|
||||
OPTForCausalLM,
|
||||
BloomForCausalLM,
|
||||
)
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
|
||||
|
@ -373,11 +383,11 @@ def main():
|
|||
# we can not import it until huggingface fix it
|
||||
logger.info("Train a new model from scratch", ranks=[0])
|
||||
with ColoInitContext(device=init_dev):
|
||||
model = OPTForCausalLM(config)
|
||||
model = AutoModelForCausalLM(config)
|
||||
else:
|
||||
logger.info("Finetune a pre-trained model", ranks=[0])
|
||||
with ColoInitContext(device=init_dev):
|
||||
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
local_files_only=False)
|
||||
|
|
Loading…
Reference in New Issue