versae commited on
Commit
346a10a
1 Parent(s): 3f4b8d4

Improved version of conversion script Flax → PyTorch

Browse files
Files changed (1) hide show
  1. convert.py +22 -6
convert.py CHANGED
@@ -1,13 +1,29 @@
 
 
 
1
  import jax
2
  from jax import numpy as jnp
3
- from transformers import FlaxRobertaForMaskedLM, RobertaForMaskedLM
 
4
 
5
  def to_f32(t):
6
  return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
7
 
8
- flax_model = FlaxRobertaForMaskedLM.from_pretrained("./")
9
- flax_model.params = to_f32(flax_model.params)
10
- flax_model.save_pretrained("./")
11
 
12
- model = RobertaForMaskedLM.from_pretrained("./", from_flax=True)
13
- model.save_pretrained("./", save_config=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ import tempfile
3
+
4
  import jax
5
  from jax import numpy as jnp
6
+ from transformers import AutoTokenizer, FlaxRobertaForMaskedLM, RobertaForMaskedLM
7
+
8
 
9
  def to_f32(t):
10
  return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
11
 
 
 
 
12
 
13
+ def main():
14
+ # Saving extra files from config.json and tokenizer.json files
15
+ tokenizer = AutoTokenizer.from_pretrained("./")
16
+ tokenizer.save_pretrained("./")
17
+
18
+ # Temporary saving bfloat16 Flax model into float32
19
+ tmp = tempfile.mkdtemp()
20
+ flax_model = FlaxRobertaForMaskedLM.from_pretrained("./")
21
+ flax_model.params = to_f32(flax_model.params)
22
+ flax_model.save_pretrained(tmp)
23
+ # Converting float32 Flax to PyTorch
24
+ model = RobertaForMaskedLM.from_pretrained(tmp, from_flax=True)
25
+ model.save_pretrained("./", save_config=False)
26
+
27
+
28
+ if __name__ == "__main__":
29
+ main()