# medsam-vit-base/scripts/pt_model.py
# Author: flaviagiammarino — "Create scripts/pt_model.py" (commit 5f9e8f9, 1.05 kB)
import torch
from segment_anything import sam_model_registry # pip install git+https://github.com/facebookresearch/segment-anything.git
from transformers import SamConfig, SamModel, SamProcessor, SamImageProcessor
from transformers.models.sam.convert_sam_original_to_hf_format import replace_keys
# Load the MedSAM ViT-B weights into segment-anything's reference architecture.
# Checkpoint download: https://drive.google.com/file/d/1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_/view?usp=drive_link
checkpoint = 'medsam_vit_b.pth'
# Build the ViT-B architecture without weights, then load the checkpoint here
# with map_location='cpu' so tensors land on the CPU regardless of the device
# they were saved from. NOTE(review): the upstream segment-anything builder
# appears to call torch.load() without a map_location, which fails on CPU-only
# machines if the checkpoint holds CUDA tensors — verify against the installed
# version. The resulting weights are the same either way.
pt_model = sam_model_registry['vit_b']()
pt_model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
pt_state_dict = pt_model.state_dict()
# Rename the original parameter keys to the transformers SamModel layout.
hf_state_dict = replace_keys(pt_state_dict)
# Instantiate the transformers model from the default SamConfig and load the
# converted weights. load_state_dict is strict by default, so any key or shape
# mismatch between the converted ViT-B checkpoint and this config raises here
# instead of leaving a partially initialized model.
hf_model = SamModel(config=SamConfig())
hf_model.load_state_dict(hf_state_dict)
# Write the model config and weights to the current working directory.
hf_model.save_pretrained('./')
# Build the processor that matches MedSAM's input convention: intensities in
# [0, 1] instead of mean/std-normalized pixels, so normalization is disabled.
# This processor does NOT min-max scale each image. It divides by a fixed 255
# (do_rescale / rescale_factor below, which are also the SamImageProcessor
# defaults). That matches min-max scaling only when an image already spans the
# full 0-255 range. NOTE(review): for other inputs (e.g. 16-bit CT/MR),
# presumably callers should min-max scale to [0, 255] before calling the
# processor — confirm against MedSAM's own preprocessing.
hf_processor = SamProcessor(
    image_processor=SamImageProcessor(
        do_rescale=True,
        rescale_factor=1 / 255,
        do_normalize=False,
        # Identity mean/std keep normalization a no-op even if do_normalize is
        # later switched back on in the saved config.
        image_mean=[0, 0, 0],
        image_std=[1, 1, 1],
    )
)
# Save the processor configuration to the current working directory.
hf_processor.save_pretrained('./')