"""Merge sharded ``*.safetensors`` checkpoint files into one ``pytorch_model.bin``."""

import torch

# Shard files to merge, in load order.
MODEL_FILES = (
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors",
)
# Destination of the merged state dict (written with ``torch.save``).
OUTPUT_PATH = "pytorch_model.bin"


def merge_state_dicts(state_dicts):
    """Combine several state dicts into one, requiring disjoint keys.

    Each tensor name is expected to live in exactly one shard. A key seen
    twice means the shard files overlap or are mismatched. It is rejected
    rather than silently combined: the previous ``+=`` summed the tensors,
    which corrupts the weights.

    Args:
        state_dicts: Iterable of ``{name: tensor}`` mappings.

    Returns:
        A new dict containing every entry from all inputs.

    Raises:
        ValueError: If the same key appears in more than one input.
    """
    merged = {}
    for state_dict in state_dicts:
        for key, value in state_dict.items():
            if key in merged:
                raise ValueError(f"duplicate tensor key across shards: {key!r}")
            merged[key] = value
    return merged


def merge_safetensors_shards(model_files=MODEL_FILES, output_path=OUTPUT_PATH):
    """Load every shard in *model_files*, merge them, and save to *output_path*.

    Args:
        model_files: Paths of the ``.safetensors`` shard files.
        output_path: Path the merged state dict is written to via ``torch.save``.

    Returns:
        The merged state dict.

    Raises:
        ValueError: If two shards define the same tensor key.
    """
    # Local import: safetensors is needed only for loading, so the pure merge
    # helper above stays importable without it.
    from safetensors.torch import load_file

    merged = merge_state_dicts(load_file(path) for path in model_files)
    torch.save(merged, output_path)
    return merged


if __name__ == "__main__":
    merge_safetensors_shards()