# Inference/simplify.py
# Provenance: committed by nekomiro as "Duplicate from DIFF-SVCModel/Inference"
# (commit 79f7f06).
import os
from argparse import ArgumentParser

import torch
def simplify_pth(pth_name, project_name, checkpoint_dir='./checkpoints',
                 output_path=None, map_location='cpu'):
    """Write a slimmed copy of a training checkpoint that keeps only the weights.

    Loads ``<checkpoint_dir>/<project_name>/<pth_name>`` and saves a new
    checkpoint that carries over its ``epoch`` and ``state_dict`` entries and
    sets the training-only entries (``global_step``,
    ``checkpoint_callback_best``, ``optimizer_states``, ``lr_schedulers``)
    to ``None``.

    Args:
        pth_name: File name of the checkpoint inside the project directory.
        project_name: Sub-directory of ``checkpoint_dir`` holding the checkpoint.
        checkpoint_dir: Root directory containing per-project checkpoint
            folders. Defaults to ``'./checkpoints'``.
        output_path: Where to write the slimmed checkpoint. Defaults to
            ``./clean_<pth_name>`` in the current working directory.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            a checkpoint saved from GPU training can be processed on a
            machine without CUDA.

    Returns:
        The path the slimmed checkpoint was written to.

    Raises:
        KeyError: If the loaded checkpoint has no ``'state_dict'`` entry.
    """
    src_path = os.path.join(checkpoint_dir, project_name, pth_name)
    if output_path is None:
        output_path = f'./clean_{pth_name}'
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only run this on checkpoints from a trusted source.
    checkpoint_dict = torch.load(src_path, map_location=map_location)
    torch.save({
        # .get(): tolerate checkpoints saved without an epoch counter.
        'epoch': checkpoint_dict.get('epoch'),
        # The weights are the only payload that must be present.
        'state_dict': checkpoint_dict['state_dict'],
        # Training-resume state is dropped to shrink the file; the keys are
        # kept (as None) so loaders expecting them still find them.
        'global_step': None,
        'checkpoint_callback_best': None,
        'optimizer_states': None,
        'lr_schedulers': None,
    }, output_path)
    return output_path
def main(argv=None):
    """Command-line entry point: slim down one step's checkpoint for a project.

    Builds the file name ``model_ckpt_steps_<steps>.ckpt`` and hands it,
    together with the project name, to ``simplify_pth``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (useful for
            testing). ``None`` keeps the normal command-line behaviour.
    """
    parser = ArgumentParser(
        description='Keep only the epoch and weights of a training checkpoint.')
    # Both options are required: without them the checkpoint path would be
    # built from the string "None" and fail later with a confusing error.
    parser.add_argument('--proj', type=str, required=True,
                        help='project name (sub-directory of the checkpoint root)')
    parser.add_argument('--steps', type=str, required=True,
                        help='training step in the file name '
                             'model_ckpt_steps_<steps>.ckpt')
    args = parser.parse_args(argv)
    model_name = f"model_ckpt_steps_{args.steps}.ckpt"
    simplify_pth(model_name, args.proj)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()