From 93e283daeeb65ef0b64763ad13e6644bc019287c Mon Sep 17 00:00:00 2001 From: 0x000011b <0x000011b@waifu.club> Date: Mon, 26 Dec 2022 20:43:01 -0300 Subject: [PATCH] feat: implement utility to convert ColossalAI checkpoints to HF pre-trained model --- training/convert_to_hf.py | 61 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 training/convert_to_hf.py diff --git a/training/convert_to_hf.py b/training/convert_to_hf.py new file mode 100644 index 0000000..c9984a0 --- /dev/null +++ b/training/convert_to_hf.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +# Utility to convert ColossalAI checkpoints to a HuggingFace pre-trained model. + +import argparse +import logging + +import transformers +import torch + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def main() -> None: + args = _parse_args_from_argv() + model = _build_model(args) + + output_dir = args.output_dir + logger.info("Saving pre-trained HF model to `%s`...", output_dir) + model.save_pretrained(output_dir) + + +def _parse_args_from_argv() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model-name", + default="EleutherAI/pythia-1.3b-deduped", + help="HuggingFace Transformers base model name.", + ) + parser.add_argument( + "-c", + "--checkpoint", + help="Fine-tune checkpoint to load into the base model.", + required=True, + ) + parser.add_argument( + "-o", + "--output-dir", + help="Name of the output folder to save the pre-trained HF model to.", + required=True, + ) + + return parser.parse_args() + + +def _build_model(args: argparse.Namespace) -> transformers.AutoModelForCausalLM: + logger.info(f"Loading checkpoint from `{args.checkpoint}`") + state_dict = torch.load(args.checkpoint, map_location="cuda").pop("model") + + logger.info(f"Loading the `{args.model_name}` model") + model = transformers.AutoModelForCausalLM.from_pretrained( + args.model_name, state_dict=state_dict) + model.eval().half() # .to("cuda") + + return model + + +if __name__ == "__main__": + main()