|
''' |
|
|
|
Converts a transformers model to a format compatible with flexgen. |
|
|
|
''' |
|
|
|
import argparse |
|
import os |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import torch |
|
from tqdm import tqdm |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54)) |
|
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") |
|
args = parser.parse_args() |
|
|
|
|
|
def _skip_reset_parameters(self):
    """No-op stand-in installed in place of a module's ``reset_parameters``."""


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.

    Replaces ``reset_parameters`` on ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` with a no-op and stores the previous methods in the
    module-level globals ``torch_linear_init_backup`` and
    ``torch_layer_norm_init_backup`` so that ``restore_torch_init`` can put
    them back.

    Calling this function again while the no-op is already installed leaves
    the saved backups untouched. Without that guard a second call would
    overwrite the backups with the no-op itself and the real initializers
    could no longer be restored.
    """
    import torch

    global torch_linear_init_backup
    global torch_layer_norm_init_backup

    # Only back up a method that is not already our own no-op.
    if torch.nn.Linear.reset_parameters is not _skip_reset_parameters:
        torch_linear_init_backup = torch.nn.Linear.reset_parameters
        torch.nn.Linear.reset_parameters = _skip_reset_parameters

    if torch.nn.LayerNorm.reset_parameters is not _skip_reset_parameters:
        torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
        torch.nn.LayerNorm.reset_parameters = _skip_reset_parameters
|
|
|
|
|
def restore_torch_init():
    """
    Rollback the change made by disable_torch_init.

    Reinstalls the ``reset_parameters`` methods saved in the module-level
    globals ``torch_linear_init_backup`` and ``torch_layer_norm_init_backup``
    on ``torch.nn.Linear`` and ``torch.nn.LayerNorm``.

    If a backup global has not been set (``disable_torch_init`` was never
    called), the corresponding class is left untouched instead of raising
    ``NameError``: there is nothing to roll back.
    """
    import torch

    module_globals = globals()

    linear_init = module_globals.get("torch_linear_init_backup")
    if linear_init is not None:
        torch.nn.Linear.reset_parameters = linear_init

    layer_norm_init = module_globals.get("torch_layer_norm_init_backup")
    if layer_norm_init is not None:
        torch.nn.LayerNorm.reset_parameters = layer_norm_init
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: load the model named by args.MODEL and write each
    # parameter of model.model as a separate NumPy file under
    # models/<model name>-np/.

    # MODEL is declared optional on the command line, so reject a missing
    # value here instead of letting Path(None) fail with a TypeError.
    if args.MODEL is None:
        raise SystemExit("error: the MODEL argument (path to the input model) is required")

    path = Path(args.MODEL)
    model_name = path.name

    print(f"Loading {model_name}...")
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    # NOTE(review): the tokenizer is loaded but not used in the code visible
    # here; kept so the script still fails early on a missing tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(path)

    out_folder = Path(f"models/{model_name}-np")
    # Create missing parent directories too (e.g. "models/") and tolerate an
    # already existing output folder.
    out_folder.mkdir(parents=True, exist_ok=True)

    print(f"Saving the converted model to {out_folder}...")
    # Materialize the parameters into a list so tqdm knows the total count.
    for name, param in tqdm(list(model.model.named_parameters())):
        # Rename the final layer norm parameters before using the name as the file name.
        name = name.replace("decoder.final_layer_norm", "decoder.layer_norm")
        param_path = out_folder / name
        # Write through an open file handle: when given a file name instead,
        # np.save appends ".npy", which would change the output file names.
        with open(param_path, "wb") as f:
            np.save(f, param.cpu().detach().numpy())
|
|