from transformers import SamConfig, SamModel, SamProcessor, SamImageProcessor
from transformers.models.sam.convert_sam_original_to_hf_format import replace_keys
from segment_anything import sam_model_registry  # pip install git+https://github.com/facebookresearch/segment-anything.git
import torch

# load the MedSAM ViT-B model
checkpoint = 'medsam_vit_b.pth'  # https://drive.google.com/file/d/1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_/view?usp=drive_link
pt_model = sam_model_registry['vit_b'](checkpoint=checkpoint)
pt_state_dict = pt_model.state_dict()

# rename the checkpoint's keys to match the transformers implementation
hf_state_dict = replace_keys(pt_state_dict)
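# (optional) spot-check the renaming -- the exact key names depend on the
# installed transformers version, but e.g. the original 'image_encoder.*'
# prefix is expected to come out as 'vision_encoder.*' in the converted dict
assert any(k.startswith('vision_encoder.') for k in hf_state_dict)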

# load the converted weights into a fresh transformers SamModel and save it
hf_model = SamModel(config=SamConfig())
hf_model.load_state_dict(hf_state_dict)
hf_model.save_pretrained('./')
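
# (optional) equivalence check -- a minimal sketch, assuming both image
# encoders see the exact same raw 1024x1024 tensor (neither applies pixel
# preprocessing itself); the tolerance is a guess, not a guarantee
pt_model.eval()
hf_model.eval()
x = torch.randn(1, 3, 1024, 1024)
with torch.no_grad():
    embeddings_pt = pt_model.image_encoder(x)
    embeddings_hf = hf_model.get_image_embeddings(x)
print(torch.allclose(embeddings_pt, embeddings_hf, atol=1e-4))  # expected: True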

# configure the processor: MedSAM inputs are min-max scaled to [0, 1] rather
# than ImageNet mean/std normalized, so normalization is disabled
hf_processor = SamProcessor(
    image_processor=SamImageProcessor(
        do_normalize=False,
        image_mean=[0, 0, 0],  # identity values; unused while do_normalize=False
        image_std=[1, 1, 1],
    )
)

# save the processor alongside the model
hf_processor.save_pretrained('./')
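
# --- usage sketch: prompt the converted model with a bounding box ---
# (the image path and box coordinates below are placeholders, not part of the
# conversion; MedSAM is typically prompted with a box around the target region)
from PIL import Image

model = SamModel.from_pretrained('./')
processor = SamProcessor.from_pretrained('./')

image = Image.open('example.png').convert('RGB')  # hypothetical input image
input_boxes = [[[75.0, 275.0, 190.0, 350.0]]]     # one (x0, y0, x1, y1) box per image

inputs = processor(image, input_boxes=input_boxes, return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)

# upscale the predicted masks back to the original image size and binarize
masks = processor.post_process_masks(
    outputs.pred_masks, inputs['original_sizes'], inputs['reshaped_input_sizes']
)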