vigpt2medium / flax_to_pt.py
imthanhlv's picture
add torch model
56d26ba
raw
history blame contribute delete
147 Bytes
# Convert a Flax causal-LM checkpoint into PyTorch weights.
#
# Loads the Flax weights from a model directory with
# AutoModelForCausalLM.from_pretrained(..., from_flax=True) and writes the
# PyTorch weights with save_pretrained.
#
# Usage:
#     python flax_to_pt.py                       # convert "." in place (original behavior)
#     python flax_to_pt.py MODEL_DIR             # convert MODEL_DIR in place
#     python flax_to_pt.py MODEL_DIR --output-dir OUT
import argparse


def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``model_dir`` (default ``"."``) and
        ``output_dir`` (default ``None``, meaning "same as model_dir").
    """
    parser = argparse.ArgumentParser(
        description="Convert a Flax causal-LM checkpoint to PyTorch weights."
    )
    parser.add_argument(
        "model_dir",
        nargs="?",
        default=".",
        help="directory containing the Flax checkpoint (default: current directory)",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="directory to write the PyTorch weights to (default: model_dir)",
    )
    return parser.parse_args(argv)


def convert(model_dir=".", output_dir=None):
    """Load Flax weights from *model_dir* and save them as PyTorch weights.

    Args:
        model_dir: Directory passed to ``from_pretrained``; it must hold the
            Flax checkpoint and its config.
        output_dir: Directory passed to ``save_pretrained``. ``None`` writes
            back into *model_dir*, which is what the original script did.

    Returns:
        The loaded PyTorch model.
    """
    # Imported lazily so the CLI (e.g. --help) and plain module imports do
    # not require transformers (or Flax) to be installed.
    from transformers import AutoModelForCausalLM

    pt_model = AutoModelForCausalLM.from_pretrained(model_dir, from_flax=True)
    pt_model.save_pretrained(model_dir if output_dir is None else output_dir)
    return pt_model


def main(argv=None):
    """Command-line entry point: parse *argv* and run the conversion."""
    args = parse_args(argv)
    convert(args.model_dir, args.output_dir)


# Guard the entry point so importing this module does not rewrite weights.
if __name__ == "__main__":
    main()