File size: 2,418 Bytes
6f49966
 
b5baf02
 
 
 
 
2b6c2bd
 
 
 
 
 
b5baf02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f49966
 
b5baf02
 
 
 
 
2b6c2bd
 
b5baf02
 
 
 
 
 
 
6f49966
b5baf02
 
 
 
 
 
6f49966
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt
import torch
import tempfile
import os
from sam_segment import predict_masks_with_sam
from lama_inpaint import inpaint_img_with_lama
from utils import load_img_to_array, save_array_to_img, dilate_mask, \
    show_mask, show_points


def mkstemp(suffix, dir=None):
    """Create a fresh, empty temporary file and return its location.

    Thin wrapper over ``tempfile.mkstemp`` that closes the OS-level file
    descriptor right away, so the caller receives only a ``pathlib.Path``.
    The file is NOT deleted automatically; the caller owns its cleanup.

    Args:
        suffix: Text appended to the generated file name (formatted through an
            f-string, so non-string values are converted to text).
        dir: Directory in which to create the file; ``None`` lets ``tempfile``
            pick its default temporary directory.

    Returns:
        Path to the newly created, empty file.
    """
    descriptor, location = tempfile.mkstemp(dir=dir, suffix=f"{suffix}")
    # Only the path is needed; release the descriptor so it is not leaked.
    os.close(descriptor)
    return Path(location)


def get_masked_img(
    img,
    point_coords,
    *,
    point_labels=None,
    dilate_kernel_size=15,
    model_type="vit_h",
    ckpt_p="pretrained_models/sam_vit_h_4b8939.pth",
):
    """Segment ``img`` with SAM from one prompt point and render each mask.

    Args:
        img: The input image as an array (as delivered by ``gr.Image``). Only
            ``img.shape[:2]`` is read here, so it is assumed to be laid out as
            (height, width[, channels]) — TODO confirm against callers.
        point_coords: The prompt point, e.g. ``[x, y]``. It is wrapped in a
            one-element list before being handed to ``predict_masks_with_sam``.
        point_labels: Labels for the prompt point(s); defaults to ``[1]``
            (in SAM's convention 1 marks a foreground point — confirm against
            ``predict_masks_with_sam``).
        dilate_kernel_size: Kernel size forwarded to ``dilate_mask``; ``None``
            skips dilation.
        model_type: SAM model type forwarded to ``predict_masks_with_sam``.
        ckpt_p: SAM checkpoint path forwarded to ``predict_masks_with_sam``.

    Returns:
        A list of matplotlib figures, one per predicted mask, each showing
        ``img`` with that mask overlaid. The figures are released from
        pyplot's figure manager, but the ``Figure`` objects remain usable
        for rendering (e.g. by ``gr.Plot``).
    """
    if point_labels is None:
        point_labels = [1]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    masks, _, _ = predict_masks_with_sam(
        img,
        [point_coords],
        point_labels,
        model_type=model_type,
        ckpt_p=ckpt_p,
        device=device,
    )
    masks = masks.astype(np.uint8) * 255

    # dilate mask to avoid unmasked edge effect
    if dilate_kernel_size is not None:
        masks = [dilate_mask(mask, dilate_kernel_size) for mask in masks]

    # Figure size is the same for every mask, so compute it once.
    height, width = img.shape[:2]
    dpi = plt.rcParams['figure.dpi']
    # NOTE(review): dividing by 0.77 makes the figure ~1.3x larger than the
    # image's native pixel size; the reason is not visible here — confirm.
    figsize = (width / dpi / 0.77, height / dpi / 0.77)

    figs = []
    for mask in masks:
        fig = plt.figure(figsize=figsize)
        ax = fig.gca()
        ax.imshow(img)
        ax.axis('off')
        show_mask(ax, mask, random_color=False)
        figs.append(fig)
        # Unregister from pyplot so repeated calls don't pile up open figures;
        # the Figure itself is still returned for the caller to render.
        plt.close(fig)
    return figs



with gr.Blocks() as demo:
    with gr.Row():
        img = gr.Image(label="Image")
        with gr.Row(label="Image with Segmentation Mask"):
            img_with_mask_0 = gr.Plot()
            img_with_mask_1 = gr.Plot()
            img_with_mask_2 = gr.Plot()
    with gr.Row():
        # Coordinates of the clicked point: `w` holds x (column), `h` holds y (row).
        w = gr.Number(label="Point X")
        h = gr.Number(label="Point Y")

    predict_mask = gr.Button("Predict Mask Using SAM")

    def get_select_coords(evt: gr.SelectData):
        """Return the pixel index of a click on the image as two values."""
        return evt.index[0], evt.index[1]

    def predict_from_point(image, x, y):
        """Adapt the flat (image, x, y) event inputs to ``get_masked_img``.

        Gradio event inputs must be a flat list of components, so the point
        is reassembled into ``[x, y]`` here rather than in the wiring.

        Raises:
            gr.Error: if no image is loaded or no point has been selected yet,
                so the user sees a readable message instead of a crash.
        """
        if image is None:
            raise gr.Error("Please upload an image first.")
        if x is None or y is None:
            raise gr.Error("Please click on the image to select a point first.")
        return get_masked_img(image, [x, y])

    img.select(get_select_coords, [], [w, h])
    predict_mask.click(
        predict_from_point,
        [img, w, h],
        [img_with_mask_0, img_with_mask_1, img_with_mask_2],
    )


if __name__ == "__main__":
    demo.launch()