feat: allow fine-tuning of non-OPT LMs
This commit is contained in:
parent
25ae9da046
commit
30ff3751de
|
@ -55,13 +55,23 @@ from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
GPT2Tokenizer,
|
GPT2Tokenizer,
|
||||||
OPTForCausalLM,
|
AutoModelForCausalLM,
|
||||||
SchedulerType,
|
SchedulerType,
|
||||||
default_data_collator,
|
default_data_collator,
|
||||||
get_scheduler,
|
get_scheduler,
|
||||||
)
|
)
|
||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
# Explanation: "AutoModelForCausalLM" will instantiate the proper subclass after
|
||||||
|
# ColossalAI has attempted to do a bunch of meta-programming trickery, so it
|
||||||
|
# crashes due to missing attributes. To work around that, we need to import the
|
||||||
|
# subclass - even if we don't use it - so ColossalAI properly patches the inner
|
||||||
|
# modules.
|
||||||
|
from transformers import (
|
||||||
|
OPTForCausalLM,
|
||||||
|
BloomForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||||
|
|
||||||
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
|
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
|
||||||
|
@ -373,11 +383,11 @@ def main():
|
||||||
# we can not import it until huggingface fix it
|
# we can not import it until huggingface fix it
|
||||||
logger.info("Train a new model from scratch", ranks=[0])
|
logger.info("Train a new model from scratch", ranks=[0])
|
||||||
with ColoInitContext(device=init_dev):
|
with ColoInitContext(device=init_dev):
|
||||||
model = OPTForCausalLM(config)
|
model = AutoModelForCausalLM(config)
|
||||||
else:
|
else:
|
||||||
logger.info("Finetune a pre-trained model", ranks=[0])
|
logger.info("Finetune a pre-trained model", ranks=[0])
|
||||||
with ColoInitContext(device=init_dev):
|
with ColoInitContext(device=init_dev):
|
||||||
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
|
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
local_files_only=False)
|
local_files_only=False)
|
||||||
|
|
Loading…
Reference in New Issue