feat: implement utility to convert ColossalAI checkpoints to HF pre-trained model
parent b79ac657a4
commit 93e283daee
@@ -0,0 +1,61 @@
#!/usr/bin/env python3
# Utility to convert ColossalAI checkpoints to a HuggingFace pre-trained model.

import argparse
import logging

import torch
import transformers

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def main() -> None:
    args = _parse_args_from_argv()
    model = _build_model(args)

    output_dir = args.output_dir
    logger.info("Saving pre-trained HF model to `%s`...", output_dir)
    model.save_pretrained(output_dir)


def _parse_args_from_argv() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model-name",
        default="EleutherAI/pythia-1.3b-deduped",
        help="HuggingFace Transformers base model name.",
    )
    parser.add_argument(
        "-c",
        "--checkpoint",
        help="Fine-tune checkpoint to load into the base model.",
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        help="Name of the output folder to save the pre-trained HF model to.",
        required=True,
    )

    return parser.parse_args()


def _build_model(args: argparse.Namespace) -> transformers.PreTrainedModel:
    # Load the ColossalAI checkpoint and pull out its model weights.
    logger.info("Loading checkpoint from `%s`", args.checkpoint)
    state_dict = torch.load(args.checkpoint, map_location="cuda").pop("model")

    # Instantiate the base architecture, overriding its weights with the
    # fine-tuned state dict from the checkpoint.
    logger.info("Loading the `%s` model", args.model_name)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        args.model_name, state_dict=state_dict)

    # Switch to inference mode and fp16 before saving.
    model.eval().half()  # .to("cuda")

    return model


if __name__ == "__main__":
    main()
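After conversion, the output directory can be loaded back with the standard Transformers API to sanity-check the result. Below is a minimal sketch, assuming a hypothetical output directory `./pythia-sft-hf` passed to `--output-dir`; since the script saves only the model weights, the tokenizer is loaded from the base model name (the script's default is shown here).

import transformers

# Hypothetical output directory produced by the conversion script above.
output_dir = "./pythia-sft-hf"

# The script does not save a tokenizer, so load it from the base model.
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/pythia-1.3b-deduped")
model = transformers.AutoModelForCausalLM.from_pretrained(output_dir)

# Run a short generation to confirm the converted weights load and produce text.
inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))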