File size: 515 Bytes
8810cfa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import torch
from safetensors.torch import load_file, save_file
# Shard files to merge, in order; names follow the "-0000K-of-0000N" pattern
# used for multi-file safetensors checkpoints.
MODEL_FILES = (
    'model-00001-of-00003.safetensors',
    'model-00002-of-00003.safetensors',
    'model-00003-of-00003.safetensors',
)
OUTPUT_PATH = 'pytorch_model.bin'


def merge_state_dicts(state_dicts):
    """Merge several shard state dicts into a single new dict.

    A sharded checkpoint splits one model's tensors across files, so every
    key is expected to appear in exactly one shard. A repeated key therefore
    signals a bad input set. It is reported rather than combined; the old
    behaviour of summing the tensors silently corrupted the weights and
    mutated the first shard's tensor in place.

    Args:
        state_dicts: Iterable of ``{name: tensor}`` mappings, one per shard.

    Returns:
        A new dict containing every entry of every input. The inputs
        themselves are not modified.

    Raises:
        ValueError: If a key appears in more than one input.
    """
    merged = {}
    for index, state_dict in enumerate(state_dicts):
        for key, value in state_dict.items():
            if key in merged:
                raise ValueError(
                    f"duplicate tensor {key!r} found again in shard {index}"
                )
            merged[key] = value
    return merged


def main(files=MODEL_FILES, output_path=OUTPUT_PATH):
    """Load each safetensors shard, merge them, and save one torch checkpoint.

    Args:
        files: Paths of the safetensors shard files to merge.
        output_path: Destination path passed to ``torch.save``.
    """
    # Generator: each shard is loaded only when the merge reaches it.
    merged_state_dict = merge_state_dicts(load_file(path) for path in files)
    torch.save(merged_state_dict, output_path)


if __name__ == '__main__':
    main()
|