С Чичерин committed on
Commit
ad54d7a
1 Parent(s): e248ffb

added gradio app

Browse files
Files changed (4) hide show
  1. app.py +20 -3
  2. logo.jpeg +0 -0
  3. models.py +481 -0
  4. test.py +1004 -0
app.py CHANGED
@@ -1,7 +1,24 @@
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  iface.launch()
 
1
  import gradio as gr
2
from test import inference_img

import torch

from models import MaskForm

# Fall back to the CPU so the demo can still start on a host without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "stylematte_synth.pth"

model = MaskForm()
model = model.to(device)
# map_location lets a checkpoint saved from a GPU load on whichever device
# was selected above.
state_dict = torch.load(checkpoint, map_location=device)

model.load_state_dict(state_dict)
model.eval()


def predict(inp):
    """Run the loaded model on one image coming from the Gradio input widget.

    Args:
        inp: the value Gradio passes for ``gr.Image(type="numpy")``.

    Returns:
        Whatever ``inference_img(model, inp)`` returns; it is shown in the
        output image widget.
    """
    # Inference only: no autograd graph is needed for a demo request.
    with torch.no_grad():
        return inference_img(model, inp)


# NOTE: the stale trailing ``iface.launch()`` was dropped; ``iface`` is no
# longer defined in this script and the call would raise NameError.
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
    examples=["./logo.jpeg"],
).launch(share=True)
logo.jpeg ADDED
models.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import random
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from typing import List
8
+ from itertools import chain
9
+
10
+ from transformers import SegformerForSemanticSegmentation,Mask2FormerForUniversalSegmentation
11
+
12
# Marker for "build the default prefix layer". ``None`` cannot be used because
# it already has a meaning for callers: it disables the prefix step.
_DEFAULT_PREFIX = object()


class EncoderDecoder(nn.Module):
    """Chain an optional prefix layer, an encoder and a decoder.

    Args:
        encoder: callable whose result is indexed with ``"hidden_states"``.
        decoder: callable applied to those hidden states.
        prefix: layer applied to the input before the encoder. When omitted,
            a fresh ``nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=True)``
            is created for this instance. Pass ``None`` to skip the step.
    """

    def __init__(
        self,
        encoder,
        decoder,
        prefix=_DEFAULT_PREFIX,
    ):
        super().__init__()
        if prefix is _DEFAULT_PREFIX:
            # Created per instance: a module written directly in the signature
            # is evaluated once, so every EncoderDecoder would share (and
            # jointly train) the very same convolution weights.
            prefix = nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=True)
        self.encoder = encoder
        self.decoder = decoder
        self.prefix = prefix

    def forward(self, x):
        """Return ``decoder(encoder(prefix(x))["hidden_states"])``."""
        if self.prefix is not None:
            x = self.prefix(x)
        # The encoder is expected to return a mapping-like output
        # (transformers style) exposing "hidden_states".
        x = self.encoder(x)["hidden_states"]
        return self.decoder(x)
29
+
30
+
31
def conv2d_relu(input_filters,output_filters,kernel_size=3, bias=True):
    """Build a Conv2d -> LeakyReLU(0.2) -> BatchNorm2d stack.

    The convolution is padded by ``kernel_size // 2`` on each side; despite
    the function name the activation is a leaky ReLU, and normalisation comes
    after it.
    """
    pad = kernel_size // 2
    convolution = nn.Conv2d(
        input_filters,
        output_filters,
        kernel_size=kernel_size,
        padding=pad,
        bias=bias,
    )
    activation = nn.LeakyReLU(0.2, inplace=True)
    normalisation = nn.BatchNorm2d(output_filters)
    return nn.Sequential(convolution, activation, normalisation)
37
+
38
def up_and_add(x, y):
    """Bilinearly resize ``x`` to the height/width of ``y`` and add ``y``."""
    target_hw = (y.size(2), y.size(3))
    resized = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
    return resized + y
40
+
41
class FPN_fuse(nn.Module):
    """Merge a sequence of feature maps into a single ``fpn_out``-channel map.

    Every map except the last is projected with a 1x1 convolution. The maps
    are then folded first-to-last: the running result is smoothed by a 3x3
    convolution, resized to the next map and added to it. Finally the folded
    map and the last input map are concatenated and passed through
    ``conv_fusion``.

    Args:
        feature_channels: channel counts; ``feature_channels[0]`` must equal
            ``fpn_out``.
        fpn_out: channel count of the output.
    """

    def __init__(self, feature_channels=(256, 512, 1024, 2048), fpn_out=256):
        super(FPN_fuse, self).__init__()
        assert feature_channels[0] == fpn_out
        n_steps = len(feature_channels) - 1
        self.conv1x1 = nn.ModuleList(
            [nn.Conv2d(ft_size, fpn_out, kernel_size=1) for ft_size in feature_channels[1:]]
        )
        # One independent convolution per merge step. Multiplying a one-element
        # list would register the *same* module n times, silently tying the
        # weights of every smoothing step. State-dict key names are unchanged
        # (smooth_conv.0, .1, ...), so existing checkpoints still load.
        self.smooth_conv = nn.ModuleList(
            [nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1) for _ in range(n_steps)]
        )
        self.conv_fusion = nn.Sequential(
            nn.Conv2d(2*fpn_out, fpn_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(fpn_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, features):
        """Fuse ``features`` (a sequence of 4-D tensors) into one tensor."""
        # Work on a copy so the caller's list is not rewritten in place
        # (this also accepts tuples).
        features = list(features)
        features[:-1] = [conv1x1(feature) for feature, conv1x1 in zip(features[:-1], self.conv1x1)]

        # Fold the maps first-to-last; with four inputs this is the same three
        # smooth/resize/add steps that used to be written out by hand.
        feature = features[0]
        for smooth, nxt in zip(self.smooth_conv, features[1:]):
            feature = up_and_add(smooth(feature), nxt)

        H, W = features[-1].size(2), features[-1].size(3)
        x = [feature, features[-1]]
        x = [F.interpolate(x_el, size=(H, W), mode='bilinear', align_corners=True) for x_el in x]

        x = self.conv_fusion(torch.cat(x, dim=1))
        return x
70
+
71
class PSPModule(nn.Module):
    """Pool the input at several bin sizes, resize back and fuse.

    Each bin size gets its own branch (adaptive average pool, 1x1 conv, batch
    norm, ReLU). Branch outputs are resized to the input's height/width,
    concatenated with the input and reduced back to ``in_channels`` by the
    bottleneck.
    """
    # The original implementation uses precise RoI pooling
    # rather than adaptive average pooling.

    def __init__(self, in_channels, bin_sizes=[1, 2, 4, 6]):
        super(PSPModule, self).__init__()
        n_bins = len(bin_sizes)
        branch_channels = in_channels // n_bins
        branches = []
        for bin_size in bin_sizes:
            branches.append(self._make_stages(in_channels, branch_channels, bin_size))
        self.stages = nn.ModuleList(branches)
        fused_channels = in_channels + (branch_channels * n_bins)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(fused_channels, in_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
        )

    def _make_stages(self, in_channels, out_channels, bin_sz):
        """Build one pooling branch for a ``bin_sz`` x ``bin_sz`` grid."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=bin_sz),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, features):
        """Return the fused map; it has the same shape as ``features``."""
        target_hw = (features.size()[2], features.size()[3])
        pyramid = [features]
        for stage in self.stages:
            pooled = stage(features)
            pyramid.append(
                F.interpolate(pooled, size=target_hw, mode='bilinear', align_corners=True)
            )
        return self.bottleneck(torch.cat(pyramid, dim=1))
101
class UperNet_swin(nn.Module):
    """Backbone + PSP module + FPN fusion + single-channel 3x3 head.

    Only the object path is implemented.

    Args:
        backbone: module whose output is indexed with ``"hidden_states"``.
            NOTE(review): the hard-coded channel list below suggests four 4-D
            maps with 192/384/768/768 channels are expected; confirm against
            the backbone actually used.
        pretrained: accepted for compatibility; not used by this class.
    """

    def __init__(self, backbone,pretrained=True):
        super(UperNet_swin, self).__init__()

        self.backbone = backbone
        feature_channels = [192,384,768,768]
        self.PPN = PSPModule(feature_channels[-1])
        self.FPN = FPN_fuse(feature_channels, fpn_out=feature_channels[0])
        self.head = nn.Conv2d(feature_channels[0], 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Return a one-channel map resized to the input's height/width."""
        input_size = (x.size()[2], x.size()[3])
        # Copy into a list: model outputs commonly expose hidden states as a
        # tuple, which does not support the item assignment below, and a copy
        # avoids editing the backbone's own output object.
        features = list(self.backbone(x)["hidden_states"])
        features[-1] = self.PPN(features[-1])
        x = self.head(self.FPN(features))

        # align_corners=False is the implicit default; spelled out for clarity.
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        return x

    def get_backbone_params(self):
        """Return an iterator over the backbone parameters."""
        return self.backbone.parameters()

    def get_decoder_params(self):
        """Return an iterator over the PPN, FPN and head parameters."""
        return chain(self.PPN.parameters(), self.FPN.parameters(), self.head.parameters())
129
+
130
class UnetDecoder(nn.Module):
    """Decoder that walks encoder features from last to first with skips.

    Args:
        encoder_channels: channel count of every encoder feature; the first
            entry is dropped (same spatial resolution as the input).
        decoder_channels: output channels of each decoder block.
        n_blocks: must equal ``len(decoder_channels)``.
        use_batchnorm: forwarded to every ``DecoderBlock``.
        attention_type: forwarded to every ``DecoderBlock``.
        center: when true a ``CenterBlock`` processes the head feature,
            otherwise an identity is used.

    Raises:
        ValueError: if ``n_blocks`` and ``decoder_channels`` disagree.
    """

    def __init__(
        self,
        encoder_channels= (3,192,384,768,768),
        decoder_channels=(512,256,128,64),
        n_blocks=4,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

        if center:
            # CenterBlock in this file only takes (in_channels, out_channels);
            # passing use_batchnorm raised TypeError as soon as center=True.
            self.center = CenterBlock(head_channels, head_channels)
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
        blocks = [
            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        ]
        self.blocks = nn.ModuleList(blocks)
        upscale_factor = 4
        self.matting_head = nn.Sequential(
            # Follows the last decoder block's width (64 with the defaults)
            # instead of a literal 64, so custom decoder_channels also work.
            nn.Conv2d(decoder_channels[-1], 1, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=upscale_factor),
        )

    def preprocess_features(self,x):
        """Turn ``(batch, tokens, channels)`` tensors into square 4-D maps.

        Each tensor is reshaped to ``(batch, channels, h, h)`` with
        ``h * h == tokens``.

        Raises:
            ValueError: if a token count is not a perfect square; reshaping
                anyway would either fail obscurely or mix batch entries.
        """
        features = []
        for out_tensor in x:
            _, n, f = out_tensor.size()
            h = int(round(n ** 0.5))
            if h * h != n:
                raise ValueError(
                    "expected a square token grid, got {} tokens".format(n)
                )
            feature = out_tensor.view(-1, h, h, f).permute(0, 3, 1, 2).contiguous()
            features.append(feature)
        return features

    def forward(self, features):
        """Decode ``features`` into a single-channel map clamped to [0, 1]."""
        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        features = self.preprocess_features(features)

        head = features[0]
        skips = features[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)
        x = self.matting_head(x)
        # The head's ReLU keeps values >= 0; this caps them at 1. The
        # functional form avoids building a new nn.ReLU on every call.
        x = 1 - F.relu(1 - x)
        return x
207
+
208
+
209
class SegmentationHead(nn.Sequential):
    """Convolution followed by optional bilinear upsampling.

    The convolution is padded by ``kernel_size // 2``. Upsampling is applied
    only when ``upsampling`` is greater than 1; otherwise the second stage is
    an identity.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        projection = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()
        super().__init__(projection, resize)
214
+
215
+
216
class DecoderBlock(nn.Module):
    """Upsample, optionally concatenate a skip tensor, then apply two convs.

    ``use_batchnorm`` and ``attention_type`` are accepted but not used by this
    implementation.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        merged_channels = in_channels + skip_channels
        self.conv1 = conv2d_relu(merged_channels, out_channels, kernel_size=3)
        self.conv2 = conv2d_relu(out_channels, out_channels, kernel_size=3)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_channels = skip_channels

    def forward(self, x, skip=None):
        """Return the block output for ``x`` and an optional ``skip``.

        ``x`` is doubled in size (nearest neighbour) when there is no skip, or
        when its last dimension differs from the skip's last dimension.
        """
        has_skip = skip is not None
        if not has_skip or x.shape[-1] != skip.shape[-1]:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if has_skip:
            x = torch.cat([x, skip], dim=1)
        return self.conv2(self.conv1(x))
251
+
252
+
253
class CenterBlock(nn.Sequential):
    """Two stacked ``conv2d_relu`` layers with 3x3 kernels.

    The first maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        widths = (in_channels, out_channels, out_channels)
        stages = [
            conv2d_relu(c_in, c_out, kernel_size=3)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        super().__init__(*stages)
266
+
267
+
268
+
269
class SegForm(nn.Module):
    """Wrap a single-label SegFormer and upsample its logits by 4."""

    def __init__(self):
        super(SegForm, self).__init__()
        # Loaded from the "nvidia/mit-b0" checkpoint with one output label;
        # mismatched weight sizes are tolerated by the loader.
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/mit-b0",
            num_labels=1,
            ignore_mismatched_sizes=True,
        )

    def forward(self, image):
        """Return the model logits enlarged 4x with nearest-neighbour sampling."""
        logits = self.model(image).logits
        return nn.functional.interpolate(logits, scale_factor=4, mode='nearest')
285
+
286
+
287
class MaskForm(nn.Module):
    """Mask2Former pixel-level module + FPN fusion + guided-filter upsampling.

    The pixel-level module is taken from the pretrained
    "facebook/mask2former-swin-tiny-coco-instance" checkpoint. Its decoder
    hidden states are fused by ``FPN_fuse``, reduced to one channel and then
    refined against the grayscale input by ``FastGuidedFilter``.
    """

    def __init__(self):
        super(MaskForm, self).__init__()
        self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256],fpn_out=256)
        self.pixel_decoder = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-tiny-coco-instance").base_model.pixel_level_module
        self.fgf = FastGuidedFilter()
        self.conv = nn.Conv2d(256,1,kernel_size=3,padding=1)

    def forward(self, image, normalize=False):
        """Return a sigmoid-activated one-channel map for ``image``.

        Args:
            image: batched image tensor; channels are averaged on dim 1 to
                build the filter guide.
            normalize: accepted for compatibility; currently ignored.
        """
        decoder_out = self.pixel_decoder(image)
        decoder_states = list(decoder_out.decoder_hidden_states)
        decoder_states.append(decoder_out.decoder_last_hidden_state)
        out_pure = self.fpn(decoder_states)
        out = self.conv(out_pure)

        # Grayscale guide, computed once and reused at both resolutions.
        guide_hr = image.mean(1, keepdim=True)
        # Size the low-res guide from the prediction itself: the guided filter
        # asserts that both share height and width, and with
        # align_corners=True this samples exactly like scale_factor=0.25
        # whenever the two sizes already agree.
        guide_lr = nn.functional.interpolate(
            guide_hr,
            size=out.shape[-2:],
            mode='bicubic',
            align_corners=True,
        )
        out = self.fgf(guide_lr, out, guide_hr)

        return torch.sigmoid(out)

    def get_training_params(self):
        """Return the FPN and final-conv parameters as a list."""
        return list(self.fpn.parameters())+list(self.conv.parameters())
327
+
328
class GuidedFilter(nn.Module):
    """Guided filter: fit ``y ~ A * x + b`` in local windows and apply it.

    Args:
        r: box-filter radius.
        eps: regulariser added to the local variance of ``x``.
    """

    def __init__(self, r, eps=1e-8):
        super(GuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, x, y):
        """Filter ``y`` using ``x`` as guide; both are 4-D with equal N, H, W.

        ``x`` must have one channel or as many channels as ``y``, and both
        spatial dimensions must exceed ``2 * r + 1``.
        """
        n_x, c_x, h_x, w_x = x.size()
        n_y, c_y, h_y, w_y = y.size()

        assert n_x == n_y
        assert c_x == 1 or c_x == c_y
        assert h_x == h_y and w_x == w_y
        assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1

        # Per-pixel window size. new_ones keeps x's dtype and device and
        # replaces the legacy x.data.new().resize_().fill_() chain.
        N = self.boxfilter(x.new_ones((1, 1, h_x, w_x)))

        mean_x = self.boxfilter(x) / N
        mean_y = self.boxfilter(y) / N
        cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
        var_x = self.boxfilter(x * x) / N - mean_x * mean_x

        # Local linear coefficients.
        A = cov_xy / (var_x + self.eps)
        b = mean_y - A * mean_x

        # Average the coefficients over each window before applying them.
        mean_A = self.boxfilter(A) / N
        mean_b = self.boxfilter(b) / N

        return mean_A * x + mean_b
368
class FastGuidedFilter(nn.Module):
    """Guided filter fitted at low resolution and applied at high resolution.

    Args:
        r: box-filter radius used on the low-resolution inputs.
        eps: regulariser added to the local variance of the guide.
    """

    def __init__(self, r=1, eps=1e-8):
        super(FastGuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, lr_x, lr_y, hr_x):
        """Return ``A * hr_x + b`` with ``A``/``b`` fitted on ``lr_x``/``lr_y``.

        ``lr_x`` and ``lr_y`` must share N, H, W; ``lr_x`` and ``hr_x`` must
        share N and C; ``lr_x`` has one channel or as many as ``lr_y``; the
        low-resolution height and width must exceed ``2 * r + 1``.
        """
        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
        n_lry, c_lry, h_lry, w_lry = lr_y.size()
        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()

        assert n_lrx == n_lry and n_lry == n_hrx
        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
        assert h_lrx == h_lry and w_lrx == w_lry
        assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1

        # Per-pixel window size. new_ones keeps lr_x's dtype and device and
        # replaces the legacy new().resize_().fill_() chain.
        N = self.boxfilter(lr_x.new_ones((1, 1, h_lrx, w_lrx)))

        mean_x = self.boxfilter(lr_x) / N
        mean_y = self.boxfilter(lr_y) / N
        cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
        var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x

        # Local linear coefficients at low resolution.
        A = cov_xy / (var_x + self.eps)
        b = mean_y - A * mean_x

        # The coefficients are bilinearly resized to the high-res grid (no
        # extra box filtering is applied to them here).
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)

        return mean_A*hr_x+mean_b
409
class DeepGuidedFilterRefiner(nn.Module):
    """Guided-filter style refiner whose ``A`` coefficient is predicted.

    A fixed-initialised 3x3 depthwise averaging convolution provides the local
    statistics; a small 1x1-conv network maps covariance, variance and the
    hidden features to ``A``.
    """

    def __init__(self, hid_channels=16):
        super().__init__()
        self.box_filter = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False, groups=4)
        # Start as a plain 3x3 mean filter.
        self.box_filter.weight.data.fill_(1 / 9)
        stat_channels = 4 * 2
        self.conv = nn.Sequential(
            nn.Conv2d(stat_channels + hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True),
        )

    def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        """Return ``(fgr, pha)`` at the resolution of ``fine_src``.

        The sources get their channel mean appended as an extra channel; the
        result is split into 3 and 1 channels along dim 1.
        """
        def append_mean(src):
            return torch.cat([src, src.mean(1, keepdim=True)], dim=1)

        fine_x = append_mean(fine_src)
        base_x = append_mean(base_src)
        base_y = torch.cat([base_fgr, base_pha], dim=1)

        blur = self.box_filter
        mean_x = blur(base_x)
        mean_y = blur(base_y)
        cov_xy = blur(base_x * base_y) - mean_x * mean_y
        var_x = blur(base_x * base_x) - mean_x * mean_x

        A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
        b = mean_y - A * mean_x

        target_hw = tuple(fine_src.shape[2:])
        A = F.interpolate(A, target_hw, mode='bilinear', align_corners=False)
        b = F.interpolate(b, target_hw, mode='bilinear', align_corners=False)

        refined = A * fine_x + b
        fgr, pha = refined.split([3, 1], dim=1)
        return fgr, pha
444
+
445
def diff_x(input, r):
    """Turn a cumulative sum along dim 2 into radius-``r`` window sums.

    ``input`` must be 4-D. The output has the same length along dim 2; near
    the two ends the window is truncated at the border.
    """
    assert input.dim() == 4

    window = 2 * r + 1
    head = input[:, :, r:window]
    body = input[:, :, window:] - input[:, :, :-window]
    tail = input[:, :, -1:] - input[:, :, -window:-r - 1]

    return torch.cat([head, body, tail], dim=2)
455
+
456
def diff_y(input, r):
    """Turn a cumulative sum along dim 3 into radius-``r`` window sums.

    ``input`` must be 4-D. The output has the same length along dim 3; near
    the two ends the window is truncated at the border.
    """
    assert input.dim() == 4

    window = 2 * r + 1
    head = input[:, :, :, r:window]
    body = input[:, :, :, window:] - input[:, :, :, :-window]
    tail = input[:, :, :, -1:] - input[:, :, :, -window:-r - 1]

    return torch.cat([head, body, tail], dim=3)
466
+
467
class BoxFilter(nn.Module):
    """Sum 4-D input over square windows of radius ``r`` via cumulative sums."""

    def __init__(self, r):
        super(BoxFilter, self).__init__()

        self.r = r

    def forward(self, x):
        """Return window sums of ``x``; the output shape equals the input's."""
        assert x.dim() == 4

        # Window sums along dim 2 first, then along dim 3.
        row_sums = diff_x(x.cumsum(dim=2), self.r)
        return diff_y(row_sums.cumsum(dim=3), self.r)
477
+
478
if __name__ == '__main__':
    # Smoke test: push one random 640x480 batch through MaskForm and print
    # the output shape. Falls back to the CPU when CUDA is unavailable instead
    # of failing on the unconditional .cuda() call.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = MaskForm().to(device)
    out = model(torch.randn(1, 3, 640, 480).to(device))
    print(out.shape)
test.py ADDED
@@ -0,0 +1,1004 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #modified from Github repo: https://github.com/JizhiziLi/P3M
2
+ #added inference code for other networks
3
+
4
+
5
+ import torch
6
+ import cv2
7
+ import argparse
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+ from PIL import Image
11
+ from skimage.transform import resize
12
+ from torchvision import transforms,models
13
+ import os
14
+ from models import *
15
+ import torch.nn.functional as F
16
+ import torch
17
+ import torch.nn as nn
18
+ import math
19
+ from torch.autograd import Variable
20
+ import torch.nn.functional as fnn
21
+ import glob
22
+ import tqdm
23
+ from torch.autograd import Variable
24
+ from typing import Type, Any, Callable, Union, List, Optional
25
+ import logging
26
+ import time
27
+ from omegaconf import OmegaConf
28
# Load config/base.yaml from the directory that contains this file, so the
# lookup does not depend on the current working directory.
_module_dir = os.path.dirname(os.path.abspath(__file__))
config = OmegaConf.load(os.path.join(_module_dir, "config/base.yaml"))
# Device name used by the code in this module.
device = "cuda"
31
+
32
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1.

    NOTE: a later ``conv3x3`` definition in this file rebinds the name, so
    callers running after import use that later version.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
36
class TFI(nn.Module):
    """Fuse three same-width inputs into one ``planes``-channel tensor.

    All three inputs go through the *same* 1x1 convolution (halving the
    channels), are concatenated on dim 1, and then pass through a 3x3
    convolution, batch norm and ReLU.
    """
    expansion = 1

    def __init__(self, planes,stride=1):
        super(TFI, self).__init__()
        middle_planes = int(planes/2)
        self.transform = conv1x1(planes, middle_planes)
        self.conv1 = conv3x3(middle_planes*3, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, input_s_guidance, input_m_decoder, input_m_encoder):
        """Return the fused tensor for the three inputs."""
        sources = (input_s_guidance, input_m_decoder, input_m_encoder)
        # Same projection, applied in the order guidance, decoder, encoder.
        projected = tuple(self.transform(source) for source in sources)
        fused = self.conv1(torch.cat(projected, 1))
        return self.relu(self.bn1(fused))
55
class SBFI(nn.Module):
    """Mix a decoder tensor with a max-pooled 64-channel tensor, residually.

    Each input is projected to ``planes / 2`` channels by its own 1x1
    convolution; the concatenation goes through conv/BN/ReLU and the decoder
    input is added back.
    """

    def __init__(self, planes,stride=1):
        super(SBFI, self).__init__()
        self.stride = stride
        half_planes = int(planes/2)
        self.transform1 = conv1x1(planes, half_planes)
        # The second input is expected to have 64 channels.
        self.transform2 = conv1x1(64, int(planes/2))
        self.maxpool = nn.MaxPool2d(2, stride=stride)
        self.conv1 = conv3x3(planes, planes, 1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input_m_decoder,e0):
        """Return ``input_m_decoder`` plus the fused correction term."""
        decoder_branch = self.transform1(input_m_decoder)
        pooled_branch = self.transform2(self.maxpool(e0))
        merged = torch.cat((decoder_branch, pooled_branch), 1)
        correction = self.relu(self.bn1(self.conv1(merged)))
        return correction + input_m_decoder
75
class DBFI(nn.Module):
    """Mix a decoder tensor with an upsampled 512-channel tensor, residually.

    Besides the fused tensor, a 3-channel side output is produced by a 3x3
    convolution followed by upsampling with factor ``int(32 / stride)``.
    """

    def __init__(self, planes,stride=1):
        super(DBFI, self).__init__()
        self.stride = stride
        half_planes = int(planes/2)
        self.transform1 = conv1x1(planes, half_planes)
        # The second input is expected to have 512 channels.
        self.transform2 = conv1x1(512, int(planes/2))
        self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear')
        self.conv1 = conv3x3(planes, planes, 1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, 3, 1)
        self.upsample2 = nn.Upsample(scale_factor=int(32/stride), mode='bilinear')

    def forward(self, input_s_decoder,e4):
        """Return ``(fused, side)`` for the decoder tensor and ``e4``."""
        decoder_branch = self.transform1(input_s_decoder)
        deep_branch = self.upsample(self.transform2(e4))
        merged = torch.cat((decoder_branch, deep_branch), 1)
        fused = self.relu(self.bn1(self.conv1(merged)))
        fused = fused + input_s_decoder
        side = self.upsample2(self.conv2(fused))
        return fused, side
99
def _conv_bn_relu_layers(widths):
    """Return Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU triples.

    One triple is produced for each consecutive pair in ``widths``; layers are
    returned in order so ``nn.Sequential`` indices (and therefore state-dict
    keys) match a hand-written stack.
    """
    layers = []
    for c_in, c_out in zip(widths[:-1], widths[1:]):
        layers.extend([
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        ])
    return layers


def _upsample_x2():
    """Return the 2x bilinear upsampling layer that ends each global stage."""
    return nn.Upsample(scale_factor=2, mode='bilinear')


class P3mNet(nn.Module):
    """Shared encoder with a global decoder, a local decoder and a fusion step.

    The encoder comes from ``resnet34_mp()``; its pooling layers also return
    indices, which the local decoder uses for max-unpooling. TFI/SBFI/DBFI
    modules exchange features between the encoder and the two decoders.
    """

    def __init__(self):
        super().__init__()
        self.resnet = resnet34_mp()
        ############################
        ### Encoder part - RESNETMP
        ############################
        self.encoder0 = nn.Sequential(
            self.resnet.conv1,
            self.resnet.bn1,
            self.resnet.relu,
        )
        self.mp0 = self.resnet.maxpool1
        self.encoder1 = nn.Sequential(
            self.resnet.layer1)
        self.mp1 = self.resnet.maxpool2
        self.encoder2 = self.resnet.layer2
        self.mp2 = self.resnet.maxpool3
        self.encoder3 = self.resnet.layer3
        self.mp3 = self.resnet.maxpool4
        self.encoder4 = self.resnet.layer4
        self.mp4 = self.resnet.maxpool5

        self.tfi_3 = TFI(256)
        self.tfi_2 = TFI(128)
        self.tfi_1 = TFI(64)
        self.tfi_0 = TFI(64)

        self.sbfi_2 = SBFI(128, 8)
        self.sbfi_1 = SBFI(64, 4)
        self.sbfi_0 = SBFI(64, 2)

        self.dbfi_2 = DBFI(128, 4)
        self.dbfi_1 = DBFI(64, 8)
        self.dbfi_0 = DBFI(64, 16)

        ##########################
        ### Decoder part - GLOBAL
        ##########################
        # Each global stage: three conv/BN/ReLU triples, then 2x upsampling.
        self.decoder4_g = nn.Sequential(
            *_conv_bn_relu_layers((512, 512, 512, 256)), _upsample_x2())
        self.decoder3_g = nn.Sequential(
            *_conv_bn_relu_layers((256, 256, 256, 128)), _upsample_x2())
        self.decoder2_g = nn.Sequential(
            *_conv_bn_relu_layers((128, 128, 128, 64)), _upsample_x2())
        self.decoder1_g = nn.Sequential(
            *_conv_bn_relu_layers((64, 64, 64, 64)), _upsample_x2())
        # Last global stage ends in a bare 3-channel convolution (no BN/ReLU).
        self.decoder0_g = nn.Sequential(
            *_conv_bn_relu_layers((64, 64, 64)),
            nn.Conv2d(64, 3, 3, padding=1),
            _upsample_x2())

        ##########################
        ### Decoder part - LOCAL
        ##########################
        # Local stages have no upsampling layer; forward() max-unpools instead.
        self.decoder4_l = nn.Sequential(*_conv_bn_relu_layers((512, 512, 512, 256)))
        self.decoder3_l = nn.Sequential(*_conv_bn_relu_layers((256, 256, 256, 128)))
        self.decoder2_l = nn.Sequential(*_conv_bn_relu_layers((128, 128, 128, 64)))
        self.decoder1_l = nn.Sequential(*_conv_bn_relu_layers((64, 64, 64, 64)))
        self.decoder0_l = nn.Sequential(*_conv_bn_relu_layers((64, 64, 64)))
        self.decoder_final_l = nn.Conv2d(64, 1, 3, padding=1)

    def forward(self, input):
        """Run the network on ``input``.

        Returns:
            A 6-tuple ``(global_sigmoid, local_sigmoid, fusion_sigmoid,
            global_sigmoid_side2, global_sigmoid_side1,
            global_sigmoid_side0)``. ``global_sigmoid`` is the raw output of
            the last global decoder stage (no sigmoid is applied to it here);
            ``local_sigmoid`` is sigmoid-activated; ``fusion_sigmoid`` comes
            from ``get_masked_local_from_global``; the three side outputs are
            the second results of the DBFI modules.
        """
        ##########################
        ### Encoder part - RESNET
        ##########################
        e0 = self.encoder0(input)
        e0p, id0 = self.mp0(e0)
        e1p, id1 = self.mp1(e0p)
        e1 = self.encoder1(e1p)
        e2p, id2 = self.mp2(e1)
        e2 = self.encoder2(e2p)
        e3p, id3 = self.mp3(e2)
        e3 = self.encoder3(e3p)
        e4p, id4 = self.mp4(e3)
        e4 = self.encoder4(e4p)
        ###########################
        ### Decoder part - Global
        ###########################
        d4_g = self.decoder4_g(e4)
        d3_g = self.decoder3_g(d4_g)
        d2_g, global_sigmoid_side2 = self.dbfi_2(d3_g, e4)
        d2_g = self.decoder2_g(d2_g)
        d1_g, global_sigmoid_side1 = self.dbfi_1(d2_g, e4)
        d1_g = self.decoder1_g(d1_g)
        d0_g, global_sigmoid_side0 = self.dbfi_0(d1_g, e4)
        d0_g = self.decoder0_g(d0_g)
        global_sigmoid = d0_g
        ###########################
        ### Decoder part - Local
        ###########################
        d4_l = self.decoder4_l(e4)
        d4_l = F.max_unpool2d(d4_l, id4, kernel_size=2, stride=2)
        d3_l = self.tfi_3(d4_g, d4_l, e3)
        d3_l = self.decoder3_l(d3_l)
        d3_l = F.max_unpool2d(d3_l, id3, kernel_size=2, stride=2)
        d2_l = self.tfi_2(d3_g, d3_l, e2)
        d2_l = self.sbfi_2(d2_l, e0)
        d2_l = self.decoder2_l(d2_l)
        d2_l = F.max_unpool2d(d2_l, id2, kernel_size=2, stride=2)
        d1_l = self.tfi_1(d2_g, d2_l, e1)
        d1_l = self.sbfi_1(d1_l, e0)
        d1_l = self.decoder1_l(d1_l)
        d1_l = F.max_unpool2d(d1_l, id1, kernel_size=2, stride=2)
        d0_l = self.tfi_0(d1_g, d1_l, e0p)
        d0_l = self.sbfi_0(d0_l, e0)
        d0_l = self.decoder0_l(d0_l)
        d0_l = F.max_unpool2d(d0_l, id0, kernel_size=2, stride=2)
        d0_l = self.decoder_final_l(d0_l)
        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        local_sigmoid = torch.sigmoid(d0_l)
        ##########################
        ### Fusion net - G/L
        ##########################
        fusion_sigmoid = get_masked_local_from_global(global_sigmoid, local_sigmoid)
        return global_sigmoid, local_sigmoid, fusion_sigmoid, global_sigmoid_side2, global_sigmoid_side1, global_sigmoid_side0
298
+
299
+
300
+
301
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 convolution whose padding equals ``dilation``."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
305
+
306
+
307
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 convolution with the given stride."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise
310
+
311
+
312
class BasicBlock(nn.Module):
    """Two 3x3 conv/norm layers with a residual connection.

    Raises:
        ValueError: if ``groups`` is not 1 or ``base_width`` is not 64.
        NotImplementedError: if ``dilation`` is greater than 1.
    """
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(BasicBlock, self).__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Only the first convolution carries the stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The shortcut is the input itself unless a downsample module is set.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
358
+
359
+
360
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block.

    The output has ``planes * expansion`` channels. ``forward`` passes the
    3x3 features through ``self.attention`` before the final 1x1 projection.
    """

    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        # ``forward`` calls ``self.attention`` but nothing ever defined it, so
        # every forward pass raised AttributeError. Default to a parameter-free
        # identity (adds no state_dict keys); a subclass or caller may replace
        # it with a real attention module after construction.
        self.attention = nn.Identity()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the bottleneck to ``x`` and return the activated residual sum."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.attention(out)

        out = self.conv3(out)
        out = self.bn3(out)
        # Project the shortcut when a downsample module was supplied.
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)

        return out
400
+
401
+
402
class ResNet(nn.Module):
    """ResNet variant whose spatial reduction is done by max-pool layers.

    The stem convolution and all four residual stages are built with
    ``stride=1``; five ``MaxPool2d`` layers (created with
    ``return_indices=True``) do the downsampling. ``forward`` returns only the
    1000-way classification logits.
    """

    def __init__(self, block, layers, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        """Build the network.

        Args:
            block: residual block class (``BasicBlock`` or ``Bottleneck``).
            layers: number of blocks in each of the four stages.
            zero_init_residual: if true, zero the last norm weight of every
                residual block after the default initialisation.
            groups: forwarded to each block.
            width_per_group: forwarded to each block as ``base_width``.
            replace_stride_with_dilation: ``None`` or a 3-element sequence of
                booleans, one per stage 2-4.
            norm_layer: normalisation layer factory; defaults to
                ``nn.BatchNorm2d``.

        Raises:
            ValueError: if ``replace_stride_with_dilation`` does not have
                exactly three elements.
        """
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem keeps full resolution (stride=1); pooling layers downsample.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=1, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # Each pool returns (output, indices) because return_indices=True.
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
        self.maxpool5 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=1,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                       dilate=replace_stride_with_dilation[2])

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, 1000)

        # Default initialisation: Kaiming-normal convs, unit-scale norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Optionally zero the final norm scale of each residual branch.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Stack ``blocks`` instances of ``block`` into an ``nn.Sequential``.

        Updates ``self.inplanes`` to the stage's output channel count.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # With stride=1 (as passed by __init__) this leaves self.dilation unchanged.
            self.dilation *= stride
            stride = 1
        # A 1x1 projection is needed whenever the shortcut shape would differ.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes,stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes,groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Run the classifier path and return the logits from ``self.fc``.

        The pooling indices (``idx1`` .. ``idx5``) are computed but not used
        or returned here.
        """
        x1 = self.conv1(x)
        x1 = self.bn1(x1)
        x1 = self.relu(x1)
        x1, idx1 = self.maxpool1(x1)

        x2, idx2 = self.maxpool2(x1)
        x2 = self.layer1(x2)

        x3, idx3 = self.maxpool3(x2)
        x3 = self.layer2(x3)

        x4, idx4 = self.maxpool4(x3)
        x4 = self.layer3(x4)

        x5, idx5 = self.maxpool5(x4)
        x5 = self.layer4(x5)

        x_cls = self.avgpool(x5)
        x_cls = torch.flatten(x_cls, 1)
        x_cls = self.fc(x_cls)

        return x_cls

    def forward(self, x):
        """Return the classification logits for ``x``."""
        return self._forward_impl(x)
504
+
505
+
506
def resnet34_mp(**kwargs):
    r"""Build the max-pool ResNet-34 and load its pretrained weights.

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`.

    Args:
        **kwargs: forwarded to the ``ResNet`` constructor.

    Returns:
        The ``ResNet`` instance with the state dict from
        ``checkpoints/r34mp_pretrained_imagenet.pth.tar`` loaded.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    # Deserialize onto the CPU: the freshly built model lives there, and
    # without map_location a checkpoint saved from a GPU fails on CPU-only hosts.
    checkpoint = torch.load("checkpoints/r34mp_pretrained_imagenet.pth.tar",
                            map_location="cpu")
    model.load_state_dict(checkpoint)
    return model
514
+
515
+ ##############################
516
+ ### Training losses for P3M-NET
517
+ ##############################
518
def get_crossentropy_loss(gt,pre):
    """Cross-entropy between predicted class scores and a trimap-style target.

    ``gt`` values are converted to class ids on a copy: 255 becomes 2, any
    other value greater than 2 becomes 1, and the rest are kept as they are.
    Only channel 0 of ``gt`` is used as the target.
    """
    labels = gt.clone()
    labels.masked_fill_(labels == 255, 2)
    labels.masked_fill_(labels > 2, 1)
    target = labels.long()[:, 0, :, :]
    return nn.CrossEntropyLoss()(pre, target)
528
+
529
def get_alpha_loss(predict, alpha, trimap):
    """Charbonnier-style alpha loss restricted to the unknown trimap region.

    Args:
        predict: predicted alpha tensor.
        alpha: ground-truth alpha; divided by 255 before comparison.
        trimap: tensor in which the value 128 marks the unknown region.

    Returns:
        Sum of ``sqrt(diff**2 + 1e-12)`` over all pixels (``diff`` is zeroed
        outside the unknown region) divided by the unknown-pixel count plus 1.
    """
    # Follow the prediction's device instead of forcing CUDA, so the loss
    # also works on CPU-only machines.
    device = predict.device
    weighted = (trimap == 128).to(device=device, dtype=torch.float32)
    alpha_f = (alpha / 255.).to(device)
    diff = (predict - alpha_f) * weighted
    alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
    # The +1 keeps the division finite when no pixel is marked unknown.
    return alpha_loss.sum() / (weighted.sum() + 1.)
539
+
540
def get_alpha_loss_whole_img(predict, alpha):
    """Charbonnier-style alpha loss averaged over every pixel.

    Args:
        predict: predicted alpha tensor.
        alpha: ground-truth alpha; divided by 255 before comparison.

    Returns:
        Sum of ``sqrt(diff**2 + 1e-12)`` divided by ``alpha``'s element count.
    """
    # Follow the prediction's device instead of forcing CUDA.
    alpha_f = (alpha / 255.).to(predict.device)
    diff = predict - alpha_f
    alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
    # The original summed a ones-tensor shaped like alpha; that is its numel.
    return alpha_loss.sum() / alpha.numel()
548
+
549
+ ## The Laplacian loss is adapted from
550
+ ## https://gist.github.com/MarcoForte/a07c40a2b721739bb5c5987671aa5270
551
def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=False):
    """Build a normalised Gaussian-like kernel for depthwise convolution.

    Args:
        size: odd kernel side length.
        sigma: standard deviation used in the Gaussian expression.
        n_channels: number of identical kernels to stack.
        cuda: when true, return the kernel on the current CUDA device;
            otherwise it stays on the CPU. (Previously this flag was ignored
            and the kernel was always moved to CUDA.)

    Returns:
        Float tensor of shape ``(n_channels, 1, size, size)`` whose per-channel
        entries sum to 1; it does not require grad.

    Raises:
        ValueError: if ``size`` is even.
    """
    if size % 2 != 1:
        raise ValueError("kernel size must be uneven")
    grid = np.float32(np.mgrid[0:size, 0:size].T)

    def gaussian(x):
        # Squared Gaussian of the per-axis offset from the kernel centre.
        return np.exp((x - size//2)**2/(-2*sigma**2))**2

    # Sum the two per-axis responses, then normalise to unit mass.
    kernel = np.sum(gaussian(grid), axis=2)
    kernel /= np.sum(kernel)
    kernel = np.tile(kernel, (n_channels, 1, 1))
    kernel = torch.FloatTensor(kernel[:, None, :, :])
    if cuda:
        kernel = kernel.cuda()
    return kernel
561
+
562
def conv_gauss(img, kernel):
    """Blur ``img`` with a kernel produced by ``build_gauss_kernel``.

    The image is replicate-padded by half the kernel extent first, then
    convolved with one group per kernel channel.
    """
    channels, _, k_w, k_h = kernel.shape
    half_w, half_h = k_w // 2, k_h // 2
    padded = fnn.pad(img, (half_w, half_h, half_w, half_h), mode='replicate')
    return fnn.conv2d(padded, kernel, groups=channels)
567
+
568
def laplacian_pyramid(img, kernel, max_levels=5):
    """Return a Laplacian pyramid of ``img`` as a list of tensors.

    Each of the ``max_levels`` entries is the difference between the current
    level and its blurred version; the final entry is the remaining low-pass
    image after the last 2x average pooling.
    """
    levels = []
    residual = img
    for _ in range(max_levels):
        blurred = conv_gauss(residual, kernel)
        levels.append(residual - blurred)
        residual = fnn.avg_pool2d(blurred, 2)
    levels.append(residual)
    return levels
578
+
579
def get_laplacian_loss(predict, alpha, trimap):
    """Laplacian-pyramid L1 loss restricted to the unknown trimap region.

    Args:
        predict: predicted alpha tensor.
        alpha: ground-truth alpha; divided by 255 before comparison.
        trimap: tensor in which the value 128 marks the unknown region.

    Returns:
        Sum over pyramid levels of the L1 loss between the masked
        ground-truth and masked prediction pyramids.
    """
    # Follow the prediction's device instead of forcing CUDA.
    device = predict.device
    weighted = (trimap == 128).to(device=device, dtype=torch.float32)
    alpha_f = (alpha / 255.).to(device) * weighted
    predict = predict * weighted
    # Ask for a CUDA kernel only when the data is on CUDA, then pin it to the
    # exact device so multi-GPU setups do not mix devices.
    gauss_kernel = build_gauss_kernel(
        size=5, sigma=1.0, n_channels=1, cuda=predict.is_cuda).to(device)
    pyr_alpha = laplacian_pyramid(alpha_f, gauss_kernel, 5)
    pyr_predict = laplacian_pyramid(predict, gauss_kernel, 5)
    return sum(fnn.l1_loss(a, b) for a, b in zip(pyr_alpha, pyr_predict))
591
+
592
def get_laplacian_loss_whole_img(predict, alpha):
    """Laplacian-pyramid L1 loss over the whole image.

    Args:
        predict: predicted alpha tensor.
        alpha: ground-truth alpha; divided by 255 before comparison.

    Returns:
        Sum over pyramid levels of the L1 loss between the ground-truth and
        prediction pyramids.
    """
    # Follow the prediction's device instead of forcing CUDA.
    device = predict.device
    alpha_f = (alpha / 255.).to(device)
    # Ask for a CUDA kernel only when the data is on CUDA, then pin it to the
    # exact device so multi-GPU setups do not mix devices.
    gauss_kernel = build_gauss_kernel(
        size=5, sigma=1.0, n_channels=1, cuda=predict.is_cuda).to(device)
    pyr_alpha = laplacian_pyramid(alpha_f, gauss_kernel, 5)
    pyr_predict = laplacian_pyramid(predict, gauss_kernel, 5)
    return sum(fnn.l1_loss(a, b) for a, b in zip(pyr_alpha, pyr_predict))
600
+
601
def get_composition_loss_whole_img(img, alpha, fg, bg, predict):
    """Composition loss over the whole image.

    The single-channel ``predict`` is repeated three times along dim 1, used
    to blend ``fg`` over ``bg``, and the result is compared with ``img``.

    Args:
        img: reference composite image.
        alpha: ground-truth alpha; only its element count is used.
        fg: foreground image.
        bg: background image.
        predict: predicted alpha tensor.

    Returns:
        Sum of ``sqrt((comp - img)**2 + 1e-12)`` divided by ``alpha``'s
        element count.
    """
    predict_3 = torch.cat((predict, predict, predict), 1)
    comp = predict_3 * fg + (1. - predict_3) * bg
    comp_loss = torch.sqrt((comp - img) ** 2 + 1e-12)
    # The original built a CUDA ones-tensor shaped like alpha just to sum it;
    # numel() gives the same count without requiring a GPU.
    return comp_loss.sum() / alpha.numel()
608
+
609
+ ##############################
610
+ ### Test loss for matting
611
+ ##############################
612
def calculate_sad_mse_mad(predict_old,alpha,trimap):
    """Return ``(SAD, MSE, MAD)`` with known trimap regions forced to 0 or 1.

    A copy of the prediction is set to 1 where ``trimap == 255`` and to 0
    where ``trimap == 0``. SAD is the absolute error sum divided by 1000; MSE
    and MAD are normalised by the count of ``trimap == 128`` pixels, or, when
    that count is zero, by the count of pixels that are neither 0 nor 255.
    """
    refined = np.copy(predict_old)
    refined[trimap == 255] = 1.
    refined[trimap == 0] = 0.

    abs_total = np.sum(np.abs(refined - alpha))
    sq_total = np.sum((refined - alpha) ** 2)

    n_unknown = float((trimap == 128).sum())
    if n_unknown == 0:
        n_known = float((trimap == 255).sum()) + float((trimap == 0).sum())
        n_unknown = trimap.shape[0] * trimap.shape[1] - n_known

    return abs_total / 1000, sq_total / n_unknown, abs_total / n_unknown
623
+
624
def calculate_sad_mse_mad_whole_img(predict, alpha):
    """Return ``(SAD, MSE, MAD)`` over the whole image.

    SAD is the absolute error sum divided by 1000; MSE and MAD are normalised
    by the product of the first two dimensions of ``predict``.
    """
    n_pixels = predict.shape[0] * predict.shape[1]
    abs_total = np.sum(np.abs(predict - alpha))
    sq_total = np.sum((predict - alpha) ** 2)
    return abs_total / 1000, sq_total / n_pixels, abs_total / n_pixels
630
+
631
def calculate_sad_fgbg(predict, alpha, trimap):
    """Return ``(sad_fg, sad_bg)``: absolute error summed per trimap region.

    Foreground is ``trimap == 255`` and background is ``trimap == 0``; each
    sum is divided by 1000.
    """
    abs_err = np.abs(predict - alpha)

    def _region_sad(trimap_value):
        # 0/1 weight map that selects the pixels of one trimap region.
        region = np.zeros(predict.shape)
        region[trimap == trimap_value] = 1.
        return np.sum(abs_err * region) / 1000

    return _region_sad(255), _region_sad(0)
643
+
644
def compute_gradient_whole_image(pd, gt):
    """Gradient error between prediction ``pd`` and ground truth ``gt``.

    Gradient magnitudes come from first-order Gaussian derivative filters
    (sigma 1.4) along each axis; the result is the summed squared magnitude
    difference divided by 10.
    """
    from scipy.ndimage import gaussian_filter

    def _gradient_magnitude(image):
        d_axis0 = gaussian_filter(image, sigma=1.4, order=[1, 0], output=np.float32)
        d_axis1 = gaussian_filter(image, sigma=1.4, order=[0, 1], output=np.float32)
        return np.sqrt(d_axis0**2 + d_axis1**2)

    pd_mag = _gradient_magnitude(pd)
    gt_mag = _gradient_magnitude(gt)
    return np.sum(np.square(pd_mag - gt_mag)) / 10
656
+
657
def compute_connectivity_loss_whole_image(pd, gt, step=0.1):
    """Connectivity error between prediction ``pd`` and ground truth ``gt``.

    For each threshold in ``arange(0, 1.1, step)`` (skipping the first), the
    largest 4-connected region shared by the thresholded maps is found. Every
    pixel records the previous threshold the first time it falls outside that
    region. The error compares how far ``pd`` and ``gt`` exceed that recorded
    level.

    Args:
        pd: 2-D predicted alpha array.
        gt: 2-D ground-truth alpha array.
        step: threshold increment.

    Returns:
        Sum of absolute differences of the two connectivity maps, divided by
        1000.
    """
    from skimage.measure import label, regionprops

    h, w = pd.shape
    thresh_steps = np.arange(0, 1.1, step)
    # -1 marks pixels that have not yet dropped out of the shared region.
    l_map = -1 * np.ones((h, w), dtype=np.float32)
    for i in range(1, thresh_steps.size):
        pd_th = pd >= thresh_steps[i]
        gt_th = gt >= thresh_steps[i]
        label_image = label(pd_th & gt_th, connectivity=1)
        cc = regionprops(label_image)
        size_vec = np.array([c.area for c in cc])
        if len(size_vec) == 0:
            continue
        max_id = np.argmax(size_vec)
        coords = cc[max_id].coords
        omega = np.zeros((h, w), dtype=np.float32)
        omega[coords[:, 0], coords[:, 1]] = 1
        flag = (l_map == -1) & (omega == 0)
        l_map[flag] = thresh_steps[i-1]
        # A Euclidean distance transform of ``omega == 0`` used to be computed
        # here and discarded. It never affected the result, could divide 0/0,
        # and needed the deprecated ``scipy.ndimage.morphology`` namespace.
    # Pixels that stayed connected at every threshold get level 1.
    l_map[l_map == -1] = 1
    d_pd = pd - l_map
    d_gt = gt - l_map
    phi_pd = 1 - d_pd * (d_pd >= 0.15).astype(np.float32)
    phi_gt = 1 - d_gt * (d_gt >= 0.15).astype(np.float32)
    loss = np.sum(np.abs(phi_pd - phi_gt)) / 1000
    return loss
688
+
689
+
690
+
691
def gen_trimap_from_segmap_e2e(segmap):
    """Convert class scores into a uint8 trimap for the first batch item.

    The arg-max over axis 1 is taken; class 1 maps to 128 and class 2 to 255,
    while any other index is kept unchanged before the cast to ``uint8``.
    """
    classes = np.argmax(segmap, axis=1)[0].astype(np.int64)
    trimap = np.where(classes == 1, 128, np.where(classes == 2, 255, classes))
    return trimap.astype(np.uint8)
697
+
698
def get_masked_local_from_global(global_sigmoid, local_sigmoid):
    """Fuse the local alpha with the arg-max class of the global prediction.

    Where the winning global class is 1 the local value is kept, where it is
    2 the output is 1, and where it is 0 the output is 0.
    """
    _, winner = torch.max(global_sigmoid, 1)
    winner = winner[:, None, :, :].float()
    # Winning class [0, 1, 2] -> transition mask [0, 1, 0]
    transition_mask = winner.masked_fill(winner == 2, 0)
    # Winning class [0, 1, 2] -> foreground mask [0, 0, 1]
    foreground_mask = winner.masked_fill(winner == 1, 0)
    foreground_mask = foreground_mask.masked_fill(foreground_mask == 2, 1)
    return local_sigmoid * transition_mask + foreground_mask
715
+
716
def get_masked_local_from_global_test(global_result, local_result):
    """NumPy fusion of a trimap-valued global result with the local alpha.

    Pixels whose global value is 0 or 255 take ``global_result / 255``; all
    other pixels take ``local_result``.
    """
    use_local = np.ones(global_result.shape)
    use_local[(global_result == 255) | (global_result == 0)] = 0
    from_global = global_result * (1. - use_local) / 255
    return from_global + local_result * use_local
722
def inference_once( model, scale_img, scale_trimap=None):
    """Run ``model`` once on an HxWxC image array and return NumPy results.

    Args:
        model: callable network; its output is sliced with ``[:3]`` and that
            one value is reused as the global, local and fusion prediction.
            NOTE(review): the commented-out line below suggests three separate
            outputs were once unpacked here; confirm what the current model
            returns.
        scale_img: image array indexed as ``[h, w, c]`` with values on a
            0-255 scale (it is divided by 255 before normalisation).
        scale_trimap: unused; kept for interface compatibility.

    Returns:
        ``(pred_global, pred_local, pred_fusion)``: a uint8 trimap from
        ``gen_trimap_from_segmap_e2e`` and two ``[0, 0, :, :]`` slices of the
        output.
    """
    # Run on the model's own device rather than assuming CUDA. Fall back to
    # the previous behaviour if the model exposes no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda')
    tensor_img = torch.from_numpy(scale_img[:, :, :]).permute(2, 0, 1).to(device)
    input_t = tensor_img
    input_t = input_t/255.0
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_t = normalize(input_t)
    input_t = input_t.unsqueeze(0).float()
    # pred_global, pred_local, pred_fusion = model(input_t)[:3]
    pred_fusion = model(input_t)[:3]
    pred_global = pred_fusion
    pred_local = pred_fusion

    pred_global = pred_global.data.cpu().numpy()
    pred_global = gen_trimap_from_segmap_e2e(pred_global)
    pred_local = pred_local.data.cpu().numpy()[0,0,:,:]
    pred_fusion = pred_fusion.data.cpu().numpy()[0,0,:,:]
    return pred_global, pred_local, pred_fusion
741
+
742
+ # def inference_img( test_choice,model, img):
743
+ # h, w, c = img.shape
744
+ # new_h = min(config['datasets'].MAX_SIZE_H, h - (h % 32))
745
+ # new_w = min(config['datasets'].MAX_SIZE_W, w - (w % 32))
746
+ # if test_choice=='HYBRID':
747
+ # global_ratio = 1/2
748
+ # local_ratio = 1
749
+ # resize_h = int(h*global_ratio)
750
+ # resize_w = int(w*global_ratio)
751
+ # new_h = min(config['datasets'].MAX_SIZE_H, resize_h - (resize_h % 32))
752
+ # new_w = min(config['datasets'].MAX_SIZE_W, resize_w - (resize_w % 32))
753
+ # scale_img = resize(img,(new_h,new_w))*255.0
754
+ # pred_coutour_1, pred_retouching_1, pred_fusion_1 = inference_once( model, scale_img)
755
+ # pred_coutour_1 = resize(pred_coutour_1,(h,w))*255.0
756
+ # resize_h = int(h*local_ratio)
757
+ # resize_w = int(w*local_ratio)
758
+ # new_h = min(config['datasets'].MAX_SIZE_H, resize_h - (resize_h % 32))
759
+ # new_w = min(config['datasets'].MAX_SIZE_W, resize_w - (resize_w % 32))
760
+ # scale_img = resize(img,(new_h,new_w))*255.0
761
+ # pred_coutour_2, pred_retouching_2, pred_fusion_2 = inference_once( model, scale_img)
762
+ # pred_retouching_2 = resize(pred_retouching_2,(h,w))
763
+ # pred_fusion = get_masked_local_from_global_test(pred_coutour_1, pred_retouching_2)
764
+ # return pred_fusion
765
+ # else:
766
+ # resize_h = int(h/2)
767
+ # resize_w = int(w/2)
768
+ # new_h = min(config['datasets'].MAX_SIZE_H, resize_h - (resize_h % 32))
769
+ # new_w = min(config['datasets'].MAX_SIZE_W, resize_w - (resize_w % 32))
770
+ # scale_img = resize(img,(new_h,new_w))*255.0
771
+ # pred_global, pred_local, pred_fusion = inference_once( model, scale_img)
772
+ # pred_local = resize(pred_local,(h,w))
773
+ # pred_global = resize(pred_global,(h,w))*255.0
774
+ # pred_fusion = resize(pred_fusion,(h,w))
775
+ # return pred_fusion
776
+
777
+
778
def inference_img(model, img):
    """Predict a single-channel map for an HxWxC image array.

    The image is reflect-padded on the top and left so both sides are
    multiples of 8, normalised with the ImageNet mean/std, and passed through
    ``model`` under ``torch.no_grad()``. The first batch item is cropped back
    to the original ``h`` x ``w`` from the bottom-right.

    Args:
        model: network called as ``model(input)``; ``out[0]`` is indexed as
            ``[channel, row, col]``.
        img: NumPy image indexed as ``[h, w, c]`` on a 0-255 scale.

    Returns:
        2-D NumPy array of shape ``(h, w)``: channel 0 of the cropped output.
    """
    h, w, _ = img.shape
    # Pad only by what is missing to reach the next multiple of 8. The former
    # ``8 - h % 8`` added a full 8 pixels to a dimension that was already
    # aligned whenever the other one was not.
    pad_top = (-h) % 8
    pad_left = (-w) % 8
    if pad_top or pad_left:
        img = cv2.copyMakeBorder(img, pad_top, 0, pad_left, 0, cv2.BORDER_REFLECT)

    # Run on the model's own device rather than assuming CUDA. Fall back to
    # the previous behaviour if the model exposes no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda')

    input_t = torch.from_numpy(img).permute(2, 0, 1).to(device)
    input_t = input_t/255.0
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_t = normalize(input_t)
    input_t = input_t.unsqueeze(0).float()
    with torch.no_grad():
        out = model(input_t)
    # Padding was added at the top/left, so the original image is the
    # bottom-right h x w window.
    result = out[0][:, -h:, -w:].cpu().numpy()

    return result[0]
799
+
800
+
801
+
802
def test_am2k(model):
    """Evaluate ``model`` on the AM-2k validation images and log the metrics.

    Paths are read from the module-level ``config``
    (``config['datasets']['am2k']``). For every ``*.jpg`` in the image folder
    the matching ``.png`` alpha and trimap are loaded, a prediction is made
    with ``inference_img``, per-image metrics are logged, and the prediction
    is saved to ``test/<name>.png``.

    Args:
        model: network passed to ``inference_img``; switched to eval mode.

    Returns:
        ``(mean SAD, mean MSE, mean GRAD)`` over all images.
    """
    ############################
    # Some initial setting for paths
    ############################
    ORIGINAL_PATH = config['datasets']['am2k']['validation_original']
    MASK_PATH = config['datasets']['am2k']['validation_mask']
    TRIMAP_PATH = config['datasets']['am2k']['validation_trimap']
    img_paths = glob.glob(ORIGINAL_PATH+"/*.jpg")

    ############################
    # Start testing
    ############################
    sad_diffs = 0.
    mse_diffs = 0.
    mad_diffs = 0.
    grad_diffs = 0.
    conn_diffs = 0.
    sad_trimap_diffs = 0.
    mse_trimap_diffs = 0.
    mad_trimap_diffs = 0.
    sad_fg_diffs = 0.
    sad_bg_diffs = 0.

    # Match test_p3m10k: make sure evaluation runs in inference mode.
    model.eval()
    # Predictions are written to ./test; create it up front so the first
    # save cannot fail on a missing directory.
    os.makedirs("test", exist_ok=True)

    total_number = len(img_paths)
    log("===============================")
    log(f'====> Start Testing\n\t--Dataset: AM2k\n\t-\n\t--Number: {total_number}')

    # enumerate replaces img_paths.index(img_path), which rescanned the list
    # for every image.
    for img_index, img_path in enumerate(tqdm.tqdm(img_paths)):
        # File stem without directory or extension, for any path separator.
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        # NOTE: plain concatenation, so the configured paths are expected to
        # end with a separator.
        alpha_path = MASK_PATH+img_name+'.png'
        trimap_path = TRIMAP_PATH+img_name+'.png'
        pil_img = Image.open(img_path)
        img = np.array(pil_img)
        trimap = np.array(Image.open(trimap_path))
        alpha = np.array(Image.open(alpha_path))/255.
        # Keep RGB only and collapse multi-channel trimap/alpha to channel 0.
        img = img[:,:,:3] if img.ndim>2 else img
        trimap = trimap[:,:,0] if trimap.ndim>2 else trimap
        alpha = alpha[:,:,0] if alpha.ndim>2 else alpha

        with torch.no_grad():
            torch.cuda.empty_cache()
            predict = inference_img( model, img)

        sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = calculate_sad_mse_mad(predict, alpha, trimap)
        sad_diff, mse_diff, mad_diff = calculate_sad_mse_mad_whole_img(predict, alpha)
        sad_fg_diff, sad_bg_diff = calculate_sad_fgbg(predict, alpha, trimap)
        conn_diff = compute_connectivity_loss_whole_image(predict, alpha)
        grad_diff = compute_gradient_whole_image(predict, alpha)

        log(f"[{img_index}/{total_number}]\nImage:{img_name}\nsad:{sad_diff}\nmse:{mse_diff}\nmad:{mad_diff}\nsad_trimap:{sad_trimap_diff}\nmse_trimap:{mse_trimap_diff}\nmad_trimap:{mad_trimap_diff}\nsad_fg:{sad_fg_diff}\nsad_bg:{sad_bg_diff}\nconn:{conn_diff}\ngrad:{grad_diff}\n-----------")

        sad_diffs += sad_diff
        mse_diffs += mse_diff
        mad_diffs += mad_diff
        mse_trimap_diffs += mse_trimap_diff
        sad_trimap_diffs += sad_trimap_diff
        mad_trimap_diffs += mad_trimap_diff
        sad_fg_diffs += sad_fg_diff
        sad_bg_diffs += sad_bg_diff
        conn_diffs += conn_diff
        grad_diffs += grad_diff
        Image.fromarray(np.uint8(predict*255)).save(f"test/{img_name}.png")

    log("===============================")
    log(f"Testing numbers: {total_number}")

    log("SAD: {}".format(sad_diffs / total_number))
    log("MSE: {}".format(mse_diffs / total_number))
    log("MAD: {}".format(mad_diffs / total_number))
    log("GRAD: {}".format(grad_diffs / total_number))
    log("CONN: {}".format(conn_diffs / total_number))
    log("SAD TRIMAP: {}".format(sad_trimap_diffs / total_number))
    log("MSE TRIMAP: {}".format(mse_trimap_diffs / total_number))
    log("MAD TRIMAP: {}".format(mad_trimap_diffs / total_number))
    log("SAD FG: {}".format(sad_fg_diffs / total_number))
    log("SAD BG: {}".format(sad_bg_diffs / total_number))
    return sad_diffs/total_number,mse_diffs/total_number,grad_diffs/total_number
883
+
884
+
885
def test_p3m10k(model,dataset_choice, max_image=-1):
    """Evaluate ``model`` on a P3M-10k validation split and log the metrics.

    Paths are read from the module-level ``config``. For every ``*.jpg`` in
    the image folder the matching ``.png`` alpha and trimap are loaded, a
    prediction is made with ``inference_img``, per-image metrics are logged,
    and the prediction is saved to ``test/<name>.png``.

    Args:
        model: network passed to ``inference_img``; switched to eval mode.
        dataset_choice: ``'P3M_500_P'`` selects the ``VAL500P`` split; any
            other value selects ``VAL500NP``.
        max_image: when positive, evaluate only the first ``max_image``
            images; otherwise evaluate all of them.

    Returns:
        ``(mean SAD, mean MSE, mean GRAD)`` over the evaluated images.
    """
    ############################
    # Some initial setting for paths
    ############################
    if dataset_choice == 'P3M_500_P':
        val_option = 'VAL500P'
    else:
        val_option = 'VAL500NP'
    ORIGINAL_PATH = config['datasets']['p3m10k']+"/validation/"+config['datasets']['p3m10k_test'][val_option]['ORIGINAL_PATH']
    MASK_PATH = config['datasets']['p3m10k']+"/validation/"+config['datasets']['p3m10k_test'][val_option]['MASK_PATH']
    TRIMAP_PATH = config['datasets']['p3m10k']+"/validation/"+config['datasets']['p3m10k_test'][val_option]['TRIMAP_PATH']
    ############################
    # Start testing
    ############################
    sad_diffs = 0.
    mse_diffs = 0.
    mad_diffs = 0.
    sad_trimap_diffs = 0.
    mse_trimap_diffs = 0.
    mad_trimap_diffs = 0.
    sad_fg_diffs = 0.
    sad_bg_diffs = 0.
    conn_diffs = 0.
    grad_diffs = 0.
    model.eval()
    # Predictions are written to ./test; create it up front so the first
    # save cannot fail on a missing directory.
    os.makedirs("test", exist_ok=True)
    img_paths = glob.glob(ORIGINAL_PATH+"/*.jpg")
    # Was ``max_image > 1``, which silently ignored a request for one image.
    if max_image > 0:
        img_paths = img_paths[:max_image]
    total_number = len(img_paths)
    log("===============================")
    log(f'====> Start Testing\n\t----Test: {dataset_choice}\n\t--Number: {total_number}')

    # enumerate replaces img_paths.index(img_path), which rescanned the list
    # for every image.
    for img_index, img_path in enumerate(tqdm.tqdm(img_paths)):
        # File stem without directory or extension, for any path separator.
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        # NOTE: plain concatenation, so the configured paths are expected to
        # end with a separator.
        alpha_path = MASK_PATH+img_name+'.png'
        trimap_path = TRIMAP_PATH+img_name+'.png'
        pil_img = Image.open(img_path)
        img = np.array(pil_img)

        trimap = np.array(Image.open(trimap_path))
        alpha = np.array(Image.open(alpha_path))/255.
        # Keep RGB only and collapse multi-channel trimap/alpha to channel 0.
        img = img[:,:,:3] if img.ndim>2 else img
        trimap = trimap[:,:,0] if trimap.ndim>2 else trimap
        alpha = alpha[:,:,0] if alpha.ndim>2 else alpha
        with torch.no_grad():
            torch.cuda.empty_cache()
            predict = inference_img( model, img) #HYBRID show less accuracy

        sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = calculate_sad_mse_mad(predict, alpha, trimap)
        sad_diff, mse_diff, mad_diff = calculate_sad_mse_mad_whole_img(predict, alpha)

        sad_fg_diff, sad_bg_diff = calculate_sad_fgbg(predict, alpha, trimap)
        conn_diff = compute_connectivity_loss_whole_image(predict, alpha)
        grad_diff = compute_gradient_whole_image(predict, alpha)
        log(f"[{img_index}/{total_number}]\nImage:{img_name}\nsad:{sad_diff}\nmse:{mse_diff}\nmad:{mad_diff}\nconn:{conn_diff}\ngrad:{grad_diff}\n-----------")
        sad_diffs += sad_diff
        mse_diffs += mse_diff
        mad_diffs += mad_diff
        mse_trimap_diffs += mse_trimap_diff
        sad_trimap_diffs += sad_trimap_diff
        mad_trimap_diffs += mad_trimap_diff
        sad_fg_diffs += sad_fg_diff
        sad_bg_diffs += sad_bg_diff
        conn_diffs += conn_diff
        grad_diffs += grad_diff

        Image.fromarray(np.uint8(predict*255)).save(f"test/{img_name}.png")

    log("===============================")
    log(f"Testing numbers: {total_number}")
    log("SAD: {}".format(sad_diffs / total_number))
    log("MSE: {}".format(mse_diffs / total_number))
    log("MAD: {}".format(mad_diffs / total_number))
    log("SAD TRIMAP: {}".format(sad_trimap_diffs / total_number))
    log("MSE TRIMAP: {}".format(mse_trimap_diffs / total_number))
    log("MAD TRIMAP: {}".format(mad_trimap_diffs / total_number))
    log("SAD FG: {}".format(sad_fg_diffs / total_number))
    log("SAD BG: {}".format(sad_bg_diffs / total_number))
    log("CONN: {}".format(conn_diffs / total_number))
    log("GRAD: {}".format(grad_diffs / total_number))

    return sad_diffs/total_number,mse_diffs/total_number,grad_diffs/total_number
977
+
978
def log(str):
    """Echo a message to stdout and record it with ``logging`` at INFO level."""
    # The parameter name shadows the builtin but is kept for compatibility.
    message = str
    print(message)
    logging.info(message)
981
+
982
if __name__ == '__main__':
    print('*********************************')
    # Base configuration lives next to this script; CLI arguments override it.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    config = OmegaConf.merge(
        OmegaConf.load(os.path.join(script_dir, "config/base.yaml")),
        OmegaConf.from_cli(),
    )
    print(config)

    # Build the network and restore the configured checkpoint.
    model = MaskForm().to(device)
    checkpoint = f"{config.checkpoint_dir}/{config.checkpoint}"
    state_dict = torch.load(checkpoint, map_location=f'{device}')
    print("loaded",checkpoint)
    model.load_state_dict(state_dict)
    model.eval()

    # One report file per checkpoint; "/" in the name is replaced by "--".
    report_file = f'report/{config.checkpoint.replace("/","--")}.report'
    logging.basicConfig(filename=report_file, encoding='utf-8',
                        filemode='w', level=logging.INFO)
    # ckpt = torch.load("checkpoints/p3mnet_pretrained_on_p3m10k.pth")
    # model.load_state_dict(ckpt['state_dict'], strict=True)
    # model = model.cuda()
    if config.dataset_to_use != "AM2K":
        for dataset_choice in ('P3M_500_P', 'P3M_500_NP'):
            test_p3m10k(model, dataset_choice)
    else:
        test_am2k(model)
1004
+