DylanJHJ commited on
Commit
0b289a7
1 Parent(s): d1f77b7

add modifying codes

Browse files
Files changed (1) hide show
  1. remake_fidt5.py +29 -0
remake_fidt5.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ # from models import FiDT5_meta
3
+ from transformers import T5ForConditionalGeneration
4
+
5
+ # load fid model checkppints
6
+ model_old = torch.load('models/ckpt/fidt5-base-nq/pytorch_model.bin', map_location='cpu')
7
+
8
+ model_new = T5ForConditionalGeneration.from_pretrained('t5-base')
9
+
10
+ # compare state dict
11
+ model_new_keys = sorted(list(model_new.state_dict().keys()))
12
+ model_old_keys = sorted(list(model_old.keys()))
13
+
14
+ # change key map
15
+ for k in model_old_keys:
16
+ k_prime = k.replace('encoder.encoder', 'encoder')
17
+ k_prime = k_prime.replace('module.layer', 'layer')
18
+ model_old[k_prime] = model_old.pop(k)
19
+
20
+ # validate if the old keys align the new one
21
+ # model_old_keys = sorted(list(model_old.keys()))
22
+ #
23
+ # for i, k in enumerate(model_new_keys):
24
+ # if k not in model_old_keys:
25
+ # print(model_old_keys[i])
26
+ # print(k)
27
+
28
+ # save as the new checkpoint
29
+ torch.save(model_old, '/home/jhju/models/fidt5-base-nq/pytorch_model.bin')