Spaces: Runtime error
С Чичерин committed · Commit ad54d7a
Parent(s): e248ffb
added gradio app
app.py
CHANGED
@@ -1,7 +1,24 @@
 import gradio as gr
+from test import inference_img
 
-
-
+device='cuda'
+model = MaskForm()
+model = model.to(device)
+checkpoint = f"stylematte_synth.pth"
+state_dict = torch.load(checkpoint, map_location=f'{device}')
 
+model.load_state_dict(state_dict)
+model.eval()
+
+def predict(inp):
+    res = inference_img(model, inp)
+
+    return res
+
+
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+gr.Interface(fn=predict,
+             inputs=gr.Image(type="numpy"),
+             outputs=gr.Image(type="numpy"),
+             examples=["./logo.jpeg"]).launch(share=True)
+
 iface.launch()
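
As committed, the new app.py references torch and MaskForm without importing them and keeps the template's trailing iface.launch() call. A minimal self-contained sketch of the same app, assuming models.py and the stylematte_synth.pth checkpoint sit next to it, could look like this:

    import gradio as gr
    import torch

    from models import MaskForm
    from test import inference_img

    device = "cuda"
    model = MaskForm().to(device)
    state_dict = torch.load("stylematte_synth.pth", map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    def predict(inp):
        # inp arrives as a numpy array (H, W, 3); inference_img returns the alpha matte
        return inference_img(model, inp)

    gr.Interface(fn=predict,
                 inputs=gr.Image(type="numpy"),
                 outputs=gr.Image(type="numpy"),
                 examples=["./logo.jpeg"]).launch(share=True)
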
logo.jpeg
ADDED
models.py
ADDED
@@ -0,0 +1,481 @@
import cv2
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List
from itertools import chain

from transformers import SegformerForSemanticSegmentation,Mask2FormerForUniversalSegmentation

class EncoderDecoder(nn.Module):
    def __init__(
        self,
        encoder,
        decoder,
        prefix=nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=True),
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.prefix = prefix

    def forward(self, x):
        if self.prefix is not None:
            x = self.prefix(x)
        x = self.encoder(x)["hidden_states"] #transformers
        return self.decoder(x)


def conv2d_relu(input_filters,output_filters,kernel_size=3, bias=True):
    return nn.Sequential(
        nn.Conv2d(input_filters, output_filters, kernel_size=kernel_size, padding=kernel_size//2, bias=bias),
        nn.LeakyReLU(0.2, inplace=True),
        nn.BatchNorm2d(output_filters)
    )

def up_and_add(x, y):
    return F.interpolate(x, size=(y.size(2), y.size(3)), mode='bilinear', align_corners=True) + y

class FPN_fuse(nn.Module):
    def __init__(self, feature_channels=[256, 512, 1024, 2048], fpn_out=256):
        super(FPN_fuse, self).__init__()
        assert feature_channels[0] == fpn_out
        self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1)
                                      for ft_size in feature_channels[1:]])
        self.smooth_conv = nn.ModuleList([nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)]
                                         * (len(feature_channels)-1))
        self.conv_fusion = nn.Sequential(
            nn.Conv2d(2*fpn_out, fpn_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(fpn_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, features):

        features[:-1] = [conv1x1(feature) for feature, conv1x1 in zip(features[:-1], self.conv1x1)]##
        feature=up_and_add(self.smooth_conv[0](features[0]),features[1])
        feature=up_and_add(self.smooth_conv[1](feature),features[2])
        feature=up_and_add(self.smooth_conv[2](feature),features[3])


        H, W = features[-1].size(2), features[-1].size(3)
        x = [feature,features[-1]]
        x = [F.interpolate(x_el, size=(H, W), mode='bilinear', align_corners=True) for x_el in x]

        x = self.conv_fusion(torch.cat(x, dim=1))
        #x = F.interpolate(x, size=(H*4, W*4), mode='bilinear', align_corners=True)
        return x

class PSPModule(nn.Module):
    # In the original inmplementation they use precise RoI pooling
    # Instead of using adaptative average pooling
    def __init__(self, in_channels, bin_sizes=[1, 2, 4, 6]):
        super(PSPModule, self).__init__()
        out_channels = in_channels // len(bin_sizes)
        self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s)
                                     for b_s in bin_sizes])
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), in_channels,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1)
        )

    def _make_stages(self, in_channels, out_channels, bin_sz):
        prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        bn = nn.BatchNorm2d(out_channels)
        relu = nn.ReLU(inplace=True)
        return nn.Sequential(prior, conv, bn, relu)

    def forward(self, features):
        h, w = features.size()[2], features.size()[3]
        pyramids = [features]
        pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear',
                                       align_corners=True) for stage in self.stages])
        output = self.bottleneck(torch.cat(pyramids, dim=1))
        return output
class UperNet_swin(nn.Module):
    # Implementing only the object path
    def __init__(self, backbone,pretrained=True):
        super(UperNet_swin, self).__init__()


        self.backbone = backbone
        feature_channels = [192,384,768,768]
        self.PPN = PSPModule(feature_channels[-1])
        self.FPN = FPN_fuse(feature_channels, fpn_out=feature_channels[0])
        self.head = nn.Conv2d(feature_channels[0], 1, kernel_size=3, padding=1)



    def forward(self, x):
        input_size = (x.size()[2], x.size()[3])
        features = self.backbone(x)["hidden_states"]
        features[-1] = self.PPN(features[-1])
        x = self.head(self.FPN(features))

        x = F.interpolate(x, size=input_size, mode='bilinear')
        return x

    def get_backbone_params(self):
        return self.backbone.parameters()

    def get_decoder_params(self):
        return chain(self.PPN.parameters(), self.FPN.parameters(), self.head.parameters())

class UnetDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels= (3,192,384,768,768),
        decoder_channels=(512,256,128,64),
        n_blocks=4,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]

        out_channels = decoder_channels

        if center:
            self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
        blocks = [
            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        ]
        self.blocks = nn.ModuleList(blocks)
        upscale_factor=4
        self.matting_head = nn.Sequential(
            nn.Conv2d(64,1, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=upscale_factor),
        )

    def preprocess_features(self,x):
        features=[]
        for out_tensor in x:
            bs,n,f=out_tensor.size()
            h = int(n**0.5)
            feature = out_tensor.view(-1,h,h,f).permute(0, 3, 1, 2).contiguous()
            features.append(feature)
        return features

    def forward(self, features):
        features = features[1:] # remove first skip with same spatial resolution
        features = features[::-1] # reverse channels to start from head of encoder

        features = self.preprocess_features(features)

        head = features[0]
        skips = features[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)
        #y_i = self.upsample1(y_i)
        #hypercol = torch.cat([y0,y1,y2,y3,y4], dim=1)
        x = self.matting_head(x)
        x=1-nn.ReLU()(1-x)
        return x


class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


class DecoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        self.conv1 = conv2d_relu(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3
        )
        self.conv2 = conv2d_relu(
            out_channels,
            out_channels,
            kernel_size=3,
        )
        self.in_channels=in_channels
        self.out_channels = out_channels
        self.skip_channels = skip_channels
    def forward(self, x, skip=None):
        if skip is None:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        else:
            if x.shape[-1]!=skip.shape[-1]:
                x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            #print(x.shape,skip.shape)
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class CenterBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        conv1 = conv2d_relu(
            in_channels,
            out_channels,
            kernel_size=3,
        )
        conv2 = conv2d_relu(
            out_channels,
            out_channels,
            kernel_size=3,
        )
        super().__init__(conv1, conv2)



class SegForm(nn.Module):
    def __init__(self):
        super(SegForm, self).__init__()
        # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
        # configuration.num_labels = 1 ## set output as 1
        # self.model = SegformerForSemanticSegmentation(config=configuration)

        self.model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0", num_labels=1, ignore_mismatched_sizes=True
        )
    def forward(self, image):
        img_segs = self.model(image)
        upsampled_logits = nn.functional.interpolate(img_segs.logits,
                                                     scale_factor=4,
                                                     mode='nearest',
                                                     )
        return upsampled_logits


class MaskForm(nn.Module):
    def __init__(self):
        super(MaskForm, self).__init__()
        # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
        # configuration.num_labels = 1 ## set output as 1
        self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256],fpn_out=256)
        self.pixel_decoder = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-tiny-coco-instance").base_model.pixel_level_module
        self.fgf = FastGuidedFilter()
        self.conv = nn.Conv2d(256,1,kernel_size=3,padding=1)
        # self.mean = torch.Tensor([0.43216, 0.394666, 0.37645]).float().view(-1, 1, 1)
        # self.register_buffer('image_net_mean', self.mean)
        # self.std = torch.Tensor([0.22803, 0.22145, 0.216989]).float().view(-1, 1, 1)
        # self.register_buffer('image_net_std', self.std)
    def forward(self, image, normalize=False):
        # if normalize:
        #     image.sub_(self.get_buffer("image_net_mean")).div_(self.get_buffer("image_net_std"))

        decoder_out = self.pixel_decoder(image)
        decoder_states=list(decoder_out.decoder_hidden_states)
        decoder_states.append(decoder_out.decoder_last_hidden_state)
        out_pure=self.fpn(decoder_states)

        image_lr=nn.functional.interpolate(image.mean(1, keepdim=True),
                                           scale_factor=0.25,
                                           mode='bicubic',
                                           align_corners=True
                                           )
        out = self.conv(out_pure)
        out = self.fgf(image_lr,out,image.mean(1, keepdim=True))#.clip(0,1)
        # out = nn.Sigmoid()(out)
        # out = nn.functional.interpolate(out,
        #                                 scale_factor=4,
        #                                 mode='bicubic',
        #                                 align_corners=True
        #                                 )

        return torch.sigmoid(out)

    def get_training_params(self):
        return list(self.fpn.parameters())+list(self.conv.parameters())#+list(self.fgf.parameters())

class GuidedFilter(nn.Module):
    def __init__(self, r, eps=1e-8):
        super(GuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)


    def forward(self, x, y):
        n_x, c_x, h_x, w_x = x.size()
        n_y, c_y, h_y, w_y = y.size()

        assert n_x == n_y
        assert c_x == 1 or c_x == c_y
        assert h_x == h_y and w_x == w_y
        assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1

        # N
        N = self.boxfilter((x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0)))

        # mean_x
        mean_x = self.boxfilter(x) / N
        # mean_y
        mean_y = self.boxfilter(y) / N
        # cov_xy
        cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
        # var_x
        var_x = self.boxfilter(x * x) / N - mean_x * mean_x

        # A
        A = cov_xy / (var_x + self.eps)
        # b
        b = mean_y - A * mean_x

        # mean_A; mean_b
        mean_A = self.boxfilter(A) / N
        mean_b = self.boxfilter(b) / N

        return mean_A * x + mean_b
class FastGuidedFilter(nn.Module):
    def __init__(self, r=1, eps=1e-8):
        super(FastGuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)


    def forward(self, lr_x, lr_y, hr_x):
        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
        n_lry, c_lry, h_lry, w_lry = lr_y.size()
        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()

        assert n_lrx == n_lry and n_lry == n_hrx
        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
        assert h_lrx == h_lry and w_lrx == w_lry
        assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1

        ## N
        N = self.boxfilter(lr_x.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0))

        ## mean_x
        mean_x = self.boxfilter(lr_x) / N
        ## mean_y
        mean_y = self.boxfilter(lr_y) / N
        ## cov_xy
        cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
        ## var_x
        var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x

        ## A
        A = cov_xy / (var_x + self.eps)
        ## b
        b = mean_y - A * mean_x

        ## mean_A; mean_b
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)

        return mean_A*hr_x+mean_b
class DeepGuidedFilterRefiner(nn.Module):
    def __init__(self, hid_channels=16):
        super().__init__()
        self.box_filter = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False, groups=4)
        self.box_filter.weight.data[...] = 1 / 9
        self.conv = nn.Sequential(
            nn.Conv2d(4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)
        )

    def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
        base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
        base_y = torch.cat([base_fgr, base_pha], dim=1)

        mean_x = self.box_filter(base_x)
        mean_y = self.box_filter(base_y)
        cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
        var_x = self.box_filter(base_x * base_x) - mean_x * mean_x

        A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
        b = mean_y - A * mean_x

        H, W = fine_src.shape[2:]
        A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
        b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)

        out = A * fine_x + b
        fgr, pha = out.split([3, 1], dim=1)
        return fgr, pha

def diff_x(input, r):
    assert input.dim() == 4

    left = input[:, :, r:2 * r + 1]
    middle = input[:, :, 2 * r + 1: ] - input[:, :, :-2 * r - 1]
    right = input[:, :, -1: ] - input[:, :, -2 * r - 1: -r - 1]

    output = torch.cat([left, middle, right], dim=2)

    return output

def diff_y(input, r):
    assert input.dim() == 4

    left = input[:, :, :, r:2 * r + 1]
    middle = input[:, :, :, 2 * r + 1: ] - input[:, :, :, :-2 * r - 1]
    right = input[:, :, :, -1: ] - input[:, :, :, -2 * r - 1: -r - 1]

    output = torch.cat([left, middle, right], dim=3)

    return output

class BoxFilter(nn.Module):
    def __init__(self, r):
        super(BoxFilter, self).__init__()

        self.r = r

    def forward(self, x):
        assert x.dim() == 4

        return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)

if __name__ == '__main__':
    model = MaskForm().cuda()
    out=model(torch.randn(1,3,640,480).cuda())
    print(out.shape)
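
MaskForm predicts the matte at 1/4 of the input resolution and relies on the FastGuidedFilter above to bring it back to full size, using the grayscale image as the guide. A small standalone sketch of that pattern (tensor shapes chosen only for illustration):

    import torch
    import torch.nn.functional as F
    from models import FastGuidedFilter

    fgf = FastGuidedFilter(r=1, eps=1e-8)

    hr_guide = torch.rand(1, 1, 640, 480)                          # full-resolution grayscale guide
    lr_guide = F.interpolate(hr_guide, scale_factor=0.25,
                             mode="bicubic", align_corners=True)   # 1/4-resolution guide (160x120)
    lr_alpha = torch.rand(1, 1, 160, 120)                          # coarse matte at 1/4 resolution

    hr_alpha = fgf(lr_guide, lr_alpha, hr_guide)                   # edge-aware upsampling to 640x480
    print(hr_alpha.shape)                                          # torch.Size([1, 1, 640, 480])
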
test.py
ADDED
@@ -0,0 +1,1004 @@
1 |
+
#modified from Github repo: https://github.com/JizhiziLi/P3M
|
2 |
+
#added inference code for other networks
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import cv2
|
7 |
+
import argparse
|
8 |
+
import numpy as np
|
9 |
+
from tqdm import tqdm
|
10 |
+
from PIL import Image
|
11 |
+
from skimage.transform import resize
|
12 |
+
from torchvision import transforms,models
|
13 |
+
import os
|
14 |
+
from models import *
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import math
|
19 |
+
from torch.autograd import Variable
|
20 |
+
import torch.nn.functional as fnn
|
21 |
+
import glob
|
22 |
+
import tqdm
|
23 |
+
from torch.autograd import Variable
|
24 |
+
from typing import Type, Any, Callable, Union, List, Optional
|
25 |
+
import logging
|
26 |
+
import time
|
27 |
+
from omegaconf import OmegaConf
|
28 |
+
config = OmegaConf.load(os.path.join(os.path.dirname(
|
29 |
+
os.path.abspath(__file__)), "config/base.yaml"))
|
30 |
+
device = "cuda"
|
31 |
+
|
32 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
33 |
+
"3x3 convolution with padding"
|
34 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
35 |
+
padding=1, bias=False)
|
36 |
+
class TFI(nn.Module):
|
37 |
+
expansion = 1
|
38 |
+
def __init__(self, planes,stride=1):
|
39 |
+
super(TFI, self).__init__()
|
40 |
+
middle_planes = int(planes/2)
|
41 |
+
self.transform = conv1x1(planes, middle_planes)
|
42 |
+
self.conv1 = conv3x3(middle_planes*3, planes, stride)
|
43 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
44 |
+
self.relu = nn.ReLU(inplace=True)
|
45 |
+
self.stride = stride
|
46 |
+
def forward(self, input_s_guidance, input_m_decoder, input_m_encoder):
|
47 |
+
input_s_guidance_transform = self.transform(input_s_guidance)
|
48 |
+
input_m_decoder_transform = self.transform(input_m_decoder)
|
49 |
+
input_m_encoder_transform = self.transform(input_m_encoder)
|
50 |
+
x = torch.cat((input_s_guidance_transform,input_m_decoder_transform,input_m_encoder_transform),1)
|
51 |
+
out = self.conv1(x)
|
52 |
+
out = self.bn1(out)
|
53 |
+
out = self.relu(out)
|
54 |
+
return out
|
55 |
+
class SBFI(nn.Module):
|
56 |
+
def __init__(self, planes,stride=1):
|
57 |
+
super(SBFI, self).__init__()
|
58 |
+
self.stride = stride
|
59 |
+
self.transform1 = conv1x1(planes, int(planes/2))
|
60 |
+
self.transform2 = conv1x1(64, int(planes/2))
|
61 |
+
self.maxpool = nn.MaxPool2d(2, stride=stride)
|
62 |
+
self.conv1 = conv3x3(planes, planes, 1)
|
63 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
64 |
+
self.relu = nn.ReLU(inplace=True)
|
65 |
+
def forward(self, input_m_decoder,e0):
|
66 |
+
input_m_decoder_transform = self.transform1(input_m_decoder)
|
67 |
+
e0_maxpool = self.maxpool(e0)
|
68 |
+
e0_transform = self.transform2(e0_maxpool)
|
69 |
+
x = torch.cat((input_m_decoder_transform,e0_transform),1)
|
70 |
+
out = self.conv1(x)
|
71 |
+
out = self.bn1(out)
|
72 |
+
out = self.relu(out)
|
73 |
+
out = out+input_m_decoder
|
74 |
+
return out
|
75 |
+
class DBFI(nn.Module):
|
76 |
+
def __init__(self, planes,stride=1):
|
77 |
+
super(DBFI, self).__init__()
|
78 |
+
self.stride = stride
|
79 |
+
self.transform1 = conv1x1(planes, int(planes/2))
|
80 |
+
self.transform2 = conv1x1(512, int(planes/2))
|
81 |
+
self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear')
|
82 |
+
self.conv1 = conv3x3(planes, planes, 1)
|
83 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
84 |
+
self.relu = nn.ReLU(inplace=True)
|
85 |
+
self.conv2 = conv3x3(planes, 3, 1)
|
86 |
+
self.upsample2 = nn.Upsample(scale_factor=int(32/stride), mode='bilinear')
|
87 |
+
def forward(self, input_s_decoder,e4):
|
88 |
+
input_s_decoder_transform = self.transform1(input_s_decoder)
|
89 |
+
e4_transform = self.transform2(e4)
|
90 |
+
e4_upsample = self.upsample(e4_transform)
|
91 |
+
x = torch.cat((input_s_decoder_transform,e4_upsample),1)
|
92 |
+
out = self.conv1(x)
|
93 |
+
out = self.bn1(out)
|
94 |
+
out = self.relu(out)
|
95 |
+
out = out+input_s_decoder
|
96 |
+
out_side = self.conv2(out)
|
97 |
+
out_side = self.upsample2(out_side)
|
98 |
+
return out, out_side
|
99 |
+
class P3mNet(nn.Module):
|
100 |
+
def __init__(self):
|
101 |
+
super().__init__()
|
102 |
+
self.resnet = resnet34_mp()
|
103 |
+
############################
|
104 |
+
### Encoder part - RESNETMP
|
105 |
+
############################
|
106 |
+
self.encoder0 = nn.Sequential(
|
107 |
+
self.resnet.conv1,
|
108 |
+
self.resnet.bn1,
|
109 |
+
self.resnet.relu,
|
110 |
+
)
|
111 |
+
self.mp0 = self.resnet.maxpool1
|
112 |
+
self.encoder1 = nn.Sequential(
|
113 |
+
self.resnet.layer1)
|
114 |
+
self.mp1 = self.resnet.maxpool2
|
115 |
+
self.encoder2 = self.resnet.layer2
|
116 |
+
self.mp2 = self.resnet.maxpool3
|
117 |
+
self.encoder3 = self.resnet.layer3
|
118 |
+
self.mp3 = self.resnet.maxpool4
|
119 |
+
self.encoder4 = self.resnet.layer4
|
120 |
+
self.mp4 = self.resnet.maxpool5
|
121 |
+
|
122 |
+
self.tfi_3 = TFI(256)
|
123 |
+
self.tfi_2 = TFI(128)
|
124 |
+
self.tfi_1 = TFI(64)
|
125 |
+
self.tfi_0 = TFI(64)
|
126 |
+
|
127 |
+
self.sbfi_2 = SBFI(128, 8)
|
128 |
+
self.sbfi_1 = SBFI(64, 4)
|
129 |
+
self.sbfi_0 = SBFI(64, 2)
|
130 |
+
|
131 |
+
self.dbfi_2 = DBFI(128, 4)
|
132 |
+
self.dbfi_1 = DBFI(64, 8)
|
133 |
+
self.dbfi_0 = DBFI(64, 16)
|
134 |
+
|
135 |
+
##########################
|
136 |
+
### Decoder part - GLOBAL
|
137 |
+
##########################
|
138 |
+
self.decoder4_g = nn.Sequential(
|
139 |
+
nn.Conv2d(512,512,3,padding=1),
|
140 |
+
nn.BatchNorm2d(512),
|
141 |
+
nn.ReLU(inplace=True),
|
142 |
+
nn.Conv2d(512,512,3,padding=1),
|
143 |
+
nn.BatchNorm2d(512),
|
144 |
+
nn.ReLU(inplace=True),
|
145 |
+
nn.Conv2d(512,256,3,padding=1),
|
146 |
+
nn.BatchNorm2d(256),
|
147 |
+
nn.ReLU(inplace=True),
|
148 |
+
nn.Upsample(scale_factor=2, mode='bilinear') )
|
149 |
+
self.decoder3_g = nn.Sequential(
|
150 |
+
nn.Conv2d(256,256,3,padding=1),
|
151 |
+
nn.BatchNorm2d(256),
|
152 |
+
nn.ReLU(inplace=True),
|
153 |
+
nn.Conv2d(256,256,3,padding=1),
|
154 |
+
nn.BatchNorm2d(256),
|
155 |
+
nn.ReLU(inplace=True),
|
156 |
+
nn.Conv2d(256,128,3,padding=1),
|
157 |
+
nn.BatchNorm2d(128),
|
158 |
+
nn.ReLU(inplace=True),
|
159 |
+
nn.Upsample(scale_factor=2, mode='bilinear') )
|
160 |
+
self.decoder2_g = nn.Sequential(
|
161 |
+
nn.Conv2d(128,128,3,padding=1),
|
162 |
+
nn.BatchNorm2d(128),
|
163 |
+
nn.ReLU(inplace=True),
|
164 |
+
nn.Conv2d(128,128,3,padding=1),
|
165 |
+
nn.BatchNorm2d(128),
|
166 |
+
nn.ReLU(inplace=True),
|
167 |
+
nn.Conv2d(128,64,3,padding=1),
|
168 |
+
nn.BatchNorm2d(64),
|
169 |
+
nn.ReLU(inplace=True),
|
170 |
+
nn.Upsample(scale_factor=2, mode='bilinear'))
|
171 |
+
self.decoder1_g = nn.Sequential(
|
172 |
+
nn.Conv2d(64,64,3,padding=1),
|
173 |
+
nn.BatchNorm2d(64),
|
174 |
+
nn.ReLU(inplace=True),
|
175 |
+
nn.Conv2d(64,64,3,padding=1),
|
176 |
+
nn.BatchNorm2d(64),
|
177 |
+
nn.ReLU(inplace=True),
|
178 |
+
nn.Conv2d(64,64,3,padding=1),
|
179 |
+
nn.BatchNorm2d(64),
|
180 |
+
nn.ReLU(inplace=True),
|
181 |
+
nn.Upsample(scale_factor=2, mode='bilinear'))
|
182 |
+
self.decoder0_g = nn.Sequential(
|
183 |
+
nn.Conv2d(64,64,3,padding=1),
|
184 |
+
nn.BatchNorm2d(64),
|
185 |
+
nn.ReLU(inplace=True),
|
186 |
+
nn.Conv2d(64,64,3,padding=1),
|
187 |
+
nn.BatchNorm2d(64),
|
188 |
+
nn.ReLU(inplace=True),
|
189 |
+
nn.Conv2d(64,3,3,padding=1),
|
190 |
+
nn.Upsample(scale_factor=2, mode='bilinear'))
|
191 |
+
|
192 |
+
##########################
|
193 |
+
### Decoder part - LOCAL
|
194 |
+
##########################
|
195 |
+
self.decoder4_l = nn.Sequential(
|
196 |
+
nn.Conv2d(512,512,3,padding=1),
|
197 |
+
nn.BatchNorm2d(512),
|
198 |
+
nn.ReLU(inplace=True),
|
199 |
+
nn.Conv2d(512,512,3,padding=1),
|
200 |
+
nn.BatchNorm2d(512),
|
201 |
+
nn.ReLU(inplace=True),
|
202 |
+
nn.Conv2d(512,256,3,padding=1),
|
203 |
+
nn.BatchNorm2d(256),
|
204 |
+
nn.ReLU(inplace=True))
|
205 |
+
self.decoder3_l = nn.Sequential(
|
206 |
+
nn.Conv2d(256,256,3,padding=1),
|
207 |
+
nn.BatchNorm2d(256),
|
208 |
+
nn.ReLU(inplace=True),
|
209 |
+
nn.Conv2d(256,256,3,padding=1),
|
210 |
+
nn.BatchNorm2d(256),
|
211 |
+
nn.ReLU(inplace=True),
|
212 |
+
nn.Conv2d(256,128,3,padding=1),
|
213 |
+
nn.BatchNorm2d(128),
|
214 |
+
nn.ReLU(inplace=True))
|
215 |
+
self.decoder2_l = nn.Sequential(
|
216 |
+
nn.Conv2d(128,128,3,padding=1),
|
217 |
+
nn.BatchNorm2d(128),
|
218 |
+
nn.ReLU(inplace=True),
|
219 |
+
nn.Conv2d(128,128,3,padding=1),
|
220 |
+
nn.BatchNorm2d(128),
|
221 |
+
nn.ReLU(inplace=True),
|
222 |
+
nn.Conv2d(128,64,3,padding=1),
|
223 |
+
nn.BatchNorm2d(64),
|
224 |
+
nn.ReLU(inplace=True))
|
225 |
+
self.decoder1_l = nn.Sequential(
|
226 |
+
nn.Conv2d(64,64,3,padding=1),
|
227 |
+
nn.BatchNorm2d(64),
|
228 |
+
nn.ReLU(inplace=True),
|
229 |
+
nn.Conv2d(64,64,3,padding=1),
|
230 |
+
nn.BatchNorm2d(64),
|
231 |
+
nn.ReLU(inplace=True),
|
232 |
+
nn.Conv2d(64,64,3,padding=1),
|
233 |
+
nn.BatchNorm2d(64),
|
234 |
+
nn.ReLU(inplace=True))
|
235 |
+
self.decoder0_l = nn.Sequential(
|
236 |
+
nn.Conv2d(64,64,3,padding=1),
|
237 |
+
nn.BatchNorm2d(64),
|
238 |
+
nn.ReLU(inplace=True),
|
239 |
+
nn.Conv2d(64,64,3,padding=1),
|
240 |
+
nn.BatchNorm2d(64),
|
241 |
+
nn.ReLU(inplace=True))
|
242 |
+
self.decoder_final_l = nn.Conv2d(64,1,3,padding=1)
|
243 |
+
|
244 |
+
|
245 |
+
def forward(self, input):
|
246 |
+
##########################
|
247 |
+
### Encoder part - RESNET
|
248 |
+
##########################
|
249 |
+
e0 = self.encoder0(input)
|
250 |
+
e0p, id0 = self.mp0(e0)
|
251 |
+
e1p, id1 = self.mp1(e0p)
|
252 |
+
e1 = self.encoder1(e1p)
|
253 |
+
e2p, id2 = self.mp2(e1)
|
254 |
+
e2 = self.encoder2(e2p)
|
255 |
+
e3p, id3 = self.mp3(e2)
|
256 |
+
e3 = self.encoder3(e3p)
|
257 |
+
e4p, id4 = self.mp4(e3)
|
258 |
+
e4 = self.encoder4(e4p)
|
259 |
+
###########################
|
260 |
+
### Decoder part - Global
|
261 |
+
###########################
|
262 |
+
d4_g = self.decoder4_g(e4)
|
263 |
+
d3_g = self.decoder3_g(d4_g)
|
264 |
+
d2_g, global_sigmoid_side2 = self.dbfi_2(d3_g, e4)
|
265 |
+
d2_g = self.decoder2_g(d2_g)
|
266 |
+
d1_g, global_sigmoid_side1 = self.dbfi_1(d2_g, e4)
|
267 |
+
d1_g = self.decoder1_g(d1_g)
|
268 |
+
d0_g, global_sigmoid_side0 = self.dbfi_0(d1_g, e4)
|
269 |
+
d0_g = self.decoder0_g(d0_g)
|
270 |
+
global_sigmoid = d0_g
|
271 |
+
###########################
|
272 |
+
### Decoder part - Local
|
273 |
+
###########################
|
274 |
+
d4_l = self.decoder4_l(e4)
|
275 |
+
d4_l = F.max_unpool2d(d4_l, id4, kernel_size=2, stride=2)
|
276 |
+
d3_l = self.tfi_3(d4_g, d4_l, e3)
|
277 |
+
d3_l = self.decoder3_l(d3_l)
|
278 |
+
d3_l = F.max_unpool2d(d3_l, id3, kernel_size=2, stride=2)
|
279 |
+
d2_l = self.tfi_2(d3_g, d3_l, e2)
|
280 |
+
d2_l = self.sbfi_2(d2_l, e0)
|
281 |
+
d2_l = self.decoder2_l(d2_l)
|
282 |
+
d2_l = F.max_unpool2d(d2_l, id2, kernel_size=2, stride=2)
|
283 |
+
d1_l = self.tfi_1(d2_g, d2_l, e1)
|
284 |
+
d1_l = self.sbfi_1(d1_l, e0)
|
285 |
+
d1_l = self.decoder1_l(d1_l)
|
286 |
+
d1_l = F.max_unpool2d(d1_l, id1, kernel_size=2, stride=2)
|
287 |
+
d0_l = self.tfi_0(d1_g, d1_l, e0p)
|
288 |
+
d0_l = self.sbfi_0(d0_l, e0)
|
289 |
+
d0_l = self.decoder0_l(d0_l)
|
290 |
+
d0_l = F.max_unpool2d(d0_l, id0, kernel_size=2, stride=2)
|
291 |
+
d0_l = self.decoder_final_l(d0_l)
|
292 |
+
local_sigmoid = F.sigmoid(d0_l)
|
293 |
+
##########################
|
294 |
+
### Fusion net - G/L
|
295 |
+
##########################
|
296 |
+
fusion_sigmoid = get_masked_local_from_global(global_sigmoid, local_sigmoid)
|
297 |
+
return global_sigmoid, local_sigmoid, fusion_sigmoid, global_sigmoid_side2, global_sigmoid_side1, global_sigmoid_side0
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
302 |
+
"""3x3 convolution with padding"""
|
303 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
304 |
+
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
305 |
+
|
306 |
+
|
307 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
308 |
+
"""1x1 convolution"""
|
309 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
310 |
+
|
311 |
+
|
312 |
+
class BasicBlock(nn.Module):
|
313 |
+
expansion: int = 1
|
314 |
+
|
315 |
+
def __init__(
|
316 |
+
self,
|
317 |
+
inplanes: int,
|
318 |
+
planes: int,
|
319 |
+
stride: int = 1,
|
320 |
+
downsample: Optional[nn.Module] = None,
|
321 |
+
groups: int = 1,
|
322 |
+
base_width: int = 64,
|
323 |
+
dilation: int = 1,
|
324 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None
|
325 |
+
) -> None:
|
326 |
+
super(BasicBlock, self).__init__()
|
327 |
+
if norm_layer is None:
|
328 |
+
norm_layer = nn.BatchNorm2d
|
329 |
+
if groups != 1 or base_width != 64:
|
330 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
331 |
+
if dilation > 1:
|
332 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
333 |
+
|
334 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
335 |
+
self.bn1 = norm_layer(planes)
|
336 |
+
self.relu = nn.ReLU(inplace=True)
|
337 |
+
self.conv2 = conv3x3(planes, planes)
|
338 |
+
self.bn2 = norm_layer(planes)
|
339 |
+
self.downsample = downsample
|
340 |
+
self.stride = stride
|
341 |
+
|
342 |
+
def forward(self, x):
|
343 |
+
identity = x
|
344 |
+
|
345 |
+
out = self.conv1(x)
|
346 |
+
out = self.bn1(out)
|
347 |
+
out = self.relu(out)
|
348 |
+
|
349 |
+
out = self.conv2(out)
|
350 |
+
out = self.bn2(out)
|
351 |
+
|
352 |
+
if self.downsample is not None:
|
353 |
+
identity = self.downsample(x)
|
354 |
+
out += identity
|
355 |
+
out = self.relu(out)
|
356 |
+
|
357 |
+
return out
|
358 |
+
|
359 |
+
|
360 |
+
class Bottleneck(nn.Module):
|
361 |
+
expansion = 4
|
362 |
+
__constants__ = ['downsample']
|
363 |
+
|
364 |
+
def __init__(self, inplanes, planes,stride=1, downsample=None, groups=1,
|
365 |
+
base_width=64, dilation=1, norm_layer=None):
|
366 |
+
super(Bottleneck, self).__init__()
|
367 |
+
if norm_layer is None:
|
368 |
+
norm_layer = nn.BatchNorm2d
|
369 |
+
width = int(planes * (base_width / 64.)) * groups
|
370 |
+
self.conv1 = conv1x1(inplanes, width)
|
371 |
+
self.bn1 = norm_layer(width)
|
372 |
+
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
373 |
+
self.bn2 = norm_layer(width)
|
374 |
+
self.conv3 = conv1x1(width, planes * self.expansion)
|
375 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
376 |
+
self.relu = nn.ReLU(inplace=True)
|
377 |
+
self.downsample = downsample
|
378 |
+
self.stride = stride
|
379 |
+
|
380 |
+
def forward(self, x):
|
381 |
+
identity = x
|
382 |
+
|
383 |
+
out = self.conv1(x)
|
384 |
+
out = self.bn1(out)
|
385 |
+
out = self.relu(out)
|
386 |
+
|
387 |
+
out = self.conv2(out)
|
388 |
+
out = self.bn2(out)
|
389 |
+
out = self.relu(out)
|
390 |
+
out = self.attention(out)
|
391 |
+
|
392 |
+
out = self.conv3(out)
|
393 |
+
out = self.bn3(out)
|
394 |
+
if self.downsample is not None:
|
395 |
+
identity = self.downsample(x)
|
396 |
+
out += identity
|
397 |
+
out = self.relu(out)
|
398 |
+
|
399 |
+
return out
|
400 |
+
|
401 |
+
|
402 |
+
class ResNet(nn.Module):
|
403 |
+
|
404 |
+
def __init__(self, block, layers, zero_init_residual=False,
|
405 |
+
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
406 |
+
norm_layer=None):
|
407 |
+
super(ResNet, self).__init__()
|
408 |
+
if norm_layer is None:
|
409 |
+
norm_layer = nn.BatchNorm2d
|
410 |
+
self._norm_layer = norm_layer
|
411 |
+
self.inplanes = 64
|
412 |
+
self.dilation = 1
|
413 |
+
if replace_stride_with_dilation is None:
|
414 |
+
replace_stride_with_dilation = [False, False, False]
|
415 |
+
if len(replace_stride_with_dilation) != 3:
|
416 |
+
raise ValueError("replace_stride_with_dilation should be None "
|
417 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
418 |
+
self.groups = groups
|
419 |
+
self.base_width = width_per_group
|
420 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=1, padding=3,
|
421 |
+
bias=False)
|
422 |
+
self.bn1 = norm_layer(self.inplanes)
|
423 |
+
self.relu = nn.ReLU(inplace=True)
|
424 |
+
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
|
425 |
+
self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
|
426 |
+
self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
|
427 |
+
self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
|
428 |
+
self.maxpool5 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
|
429 |
+
#pdb.set_trace()
|
430 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
431 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=1,
|
432 |
+
dilate=replace_stride_with_dilation[0])
|
433 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
|
434 |
+
dilate=replace_stride_with_dilation[1])
|
435 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
|
436 |
+
dilate=replace_stride_with_dilation[2])
|
437 |
+
|
438 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
439 |
+
self.fc = nn.Linear(512 * block.expansion, 1000)
|
440 |
+
|
441 |
+
for m in self.modules():
|
442 |
+
if isinstance(m, nn.Conv2d):
|
443 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
444 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
445 |
+
nn.init.constant_(m.weight, 1)
|
446 |
+
nn.init.constant_(m.bias, 0)
|
447 |
+
if zero_init_residual:
|
448 |
+
for m in self.modules():
|
449 |
+
if isinstance(m, Bottleneck):
|
450 |
+
nn.init.constant_(m.bn3.weight, 0)
|
451 |
+
elif isinstance(m, BasicBlock):
|
452 |
+
nn.init.constant_(m.bn2.weight, 0)
|
453 |
+
|
454 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
455 |
+
norm_layer = self._norm_layer
|
456 |
+
downsample = None
|
457 |
+
previous_dilation = self.dilation
|
458 |
+
if dilate:
|
459 |
+
self.dilation *= stride
|
460 |
+
stride = 1
|
461 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
462 |
+
downsample = nn.Sequential(
|
463 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
464 |
+
norm_layer(planes * block.expansion),
|
465 |
+
)
|
466 |
+
|
467 |
+
layers = []
|
468 |
+
layers.append(block(self.inplanes, planes,stride, downsample, self.groups,
|
469 |
+
self.base_width, previous_dilation, norm_layer))
|
470 |
+
self.inplanes = planes * block.expansion
|
471 |
+
for _ in range(1, blocks):
|
472 |
+
layers.append(block(self.inplanes, planes,groups=self.groups,
|
473 |
+
base_width=self.base_width, dilation=self.dilation,
|
474 |
+
norm_layer=norm_layer))
|
475 |
+
|
476 |
+
return nn.Sequential(*layers)
|
477 |
+
|
478 |
+
def _forward_impl(self, x):
|
479 |
+
x1 = self.conv1(x)
|
480 |
+
x1 = self.bn1(x1)
|
481 |
+
x1 = self.relu(x1)
|
482 |
+
x1, idx1 = self.maxpool1(x1)
|
483 |
+
|
484 |
+
x2, idx2 = self.maxpool2(x1)
|
485 |
+
x2 = self.layer1(x2)
|
486 |
+
|
487 |
+
x3, idx3 = self.maxpool3(x2)
|
488 |
+
x3 = self.layer2(x3)
|
489 |
+
|
490 |
+
x4, idx4 = self.maxpool4(x3)
|
491 |
+
x4 = self.layer3(x4)
|
492 |
+
|
493 |
+
x5, idx5 = self.maxpool5(x4)
|
494 |
+
x5 = self.layer4(x5)
|
495 |
+
|
496 |
+
x_cls = self.avgpool(x5)
|
497 |
+
x_cls = torch.flatten(x_cls, 1)
|
498 |
+
x_cls = self.fc(x_cls)
|
499 |
+
|
500 |
+
return x_cls
|
501 |
+
|
502 |
+
def forward(self, x):
|
503 |
+
return self._forward_impl(x)
|
504 |
+
|
505 |
+
|
506 |
+
def resnet34_mp(**kwargs):
|
507 |
+
r"""ResNet-34 model from
|
508 |
+
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`
|
509 |
+
"""
|
510 |
+
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
|
511 |
+
checkpoint = torch.load("checkpoints/r34mp_pretrained_imagenet.pth.tar")
|
512 |
+
model.load_state_dict(checkpoint)
|
513 |
+
return model
|
514 |
+
|
515 |
+
##############################
|
516 |
+
### Training loses for P3M-NET
|
517 |
+
##############################
|
518 |
+
def get_crossentropy_loss(gt,pre):
|
519 |
+
gt_copy = gt.clone()
|
520 |
+
gt_copy[gt_copy==0] = 0
|
521 |
+
gt_copy[gt_copy==255] = 2
|
522 |
+
gt_copy[gt_copy>2] = 1
|
523 |
+
gt_copy = gt_copy.long()
|
524 |
+
gt_copy = gt_copy[:,0,:,:]
|
525 |
+
criterion = nn.CrossEntropyLoss()
|
526 |
+
entropy_loss = criterion(pre, gt_copy)
|
527 |
+
return entropy_loss
|
528 |
+
|
529 |
+
def get_alpha_loss(predict, alpha, trimap):
|
530 |
+
weighted = torch.zeros(trimap.shape).cuda()
|
531 |
+
weighted[trimap == 128] = 1.
|
532 |
+
alpha_f = alpha / 255.
|
533 |
+
alpha_f = alpha_f.cuda()
|
534 |
+
diff = predict - alpha_f
|
535 |
+
diff = diff * weighted
|
536 |
+
alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
|
537 |
+
alpha_loss_weighted = alpha_loss.sum() / (weighted.sum() + 1.)
|
538 |
+
return alpha_loss_weighted
|
539 |
+
|
540 |
+
def get_alpha_loss_whole_img(predict, alpha):
|
541 |
+
weighted = torch.ones(alpha.shape).cuda()
|
542 |
+
alpha_f = alpha / 255.
|
543 |
+
alpha_f = alpha_f.cuda()
|
544 |
+
diff = predict - alpha_f
|
545 |
+
alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
|
546 |
+
alpha_loss = alpha_loss.sum()/(weighted.sum())
|
547 |
+
return alpha_loss
|
548 |
+
|
549 |
+
## Laplacian loss is refer to
|
550 |
+
## https://gist.github.com/MarcoForte/a07c40a2b721739bb5c5987671aa5270
|
551 |
+
def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=False):
|
552 |
+
if size % 2 != 1:
|
553 |
+
raise ValueError("kernel size must be uneven")
|
554 |
+
grid = np.float32(np.mgrid[0:size,0:size].T)
|
555 |
+
gaussian = lambda x: np.exp((x - size//2)**2/(-2*sigma**2))**2
|
556 |
+
kernel = np.sum(gaussian(grid), axis=2)
|
557 |
+
kernel /= np.sum(kernel)
|
558 |
+
kernel = np.tile(kernel, (n_channels, 1, 1))
|
559 |
+
kernel = torch.FloatTensor(kernel[:, None, :, :]).cuda()
|
560 |
+
return Variable(kernel, requires_grad=False)
|
561 |
+
|
562 |
+
def conv_gauss(img, kernel):
|
563 |
+
""" convolve img with a gaussian kernel that has been built with build_gauss_kernel """
|
564 |
+
n_channels, _, kw, kh = kernel.shape
|
565 |
+
img = fnn.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
|
566 |
+
return fnn.conv2d(img, kernel, groups=n_channels)
|
567 |
+
|
568 |
+
def laplacian_pyramid(img, kernel, max_levels=5):
|
569 |
+
current = img
|
570 |
+
pyr = []
|
571 |
+
for level in range(max_levels):
|
572 |
+
filtered = conv_gauss(current, kernel)
|
573 |
+
diff = current - filtered
|
574 |
+
pyr.append(diff)
|
575 |
+
current = fnn.avg_pool2d(filtered, 2)
|
576 |
+
pyr.append(current)
|
577 |
+
return pyr
|
578 |
+
|
579 |
+
def get_laplacian_loss(predict, alpha, trimap):
|
580 |
+
weighted = torch.zeros(trimap.shape).cuda()
|
581 |
+
weighted[trimap == 128] = 1.
|
582 |
+
alpha_f = alpha / 255.
|
583 |
+
alpha_f = alpha_f.cuda()
|
584 |
+
alpha_f = alpha_f.clone()*weighted
|
585 |
+
predict = predict.clone()*weighted
|
586 |
+
gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
|
587 |
+
pyr_alpha = laplacian_pyramid(alpha_f, gauss_kernel, 5)
|
588 |
+
pyr_predict = laplacian_pyramid(predict, gauss_kernel, 5)
|
589 |
+
laplacian_loss_weighted = sum(fnn.l1_loss(a, b) for a, b in zip(pyr_alpha, pyr_predict))
|
590 |
+
return laplacian_loss_weighted
|
591 |
+
|
592 |
+
def get_laplacian_loss_whole_img(predict, alpha):
|
593 |
+
alpha_f = alpha / 255.
|
594 |
+
alpha_f = alpha_f.cuda()
|
595 |
+
gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
|
596 |
+
pyr_alpha = laplacian_pyramid(alpha_f, gauss_kernel, 5)
|
597 |
+
pyr_predict = laplacian_pyramid(predict, gauss_kernel, 5)
|
598 |
+
laplacian_loss = sum(fnn.l1_loss(a, b) for a, b in zip(pyr_alpha, pyr_predict))
|
599 |
+
return laplacian_loss
|
600 |
+
|
601 |
+
def get_composition_loss_whole_img(img, alpha, fg, bg, predict):
|
602 |
+
weighted = torch.ones(alpha.shape).cuda()
|
603 |
+
predict_3 = torch.cat((predict, predict, predict), 1)
|
604 |
+
comp = predict_3 * fg + (1. - predict_3) * bg
|
605 |
+
comp_loss = torch.sqrt((comp - img) ** 2 + 1e-12)
|
606 |
+
comp_loss = comp_loss.sum()/(weighted.sum())
|
607 |
+
return comp_loss
|
608 |
+
|
609 |
+
##############################
|
610 |
+
### Test loss for matting
|
611 |
+
##############################
|
612 |
+
def calculate_sad_mse_mad(predict_old,alpha,trimap):
|
613 |
+
predict = np.copy(predict_old)
|
614 |
+
pixel = float((trimap == 128).sum())
|
615 |
+
predict[trimap == 255] = 1.
|
616 |
+
predict[trimap == 0 ] = 0.
|
617 |
+
sad_diff = np.sum(np.abs(predict - alpha))/1000
|
618 |
+
if pixel==0:
|
619 |
+
pixel = trimap.shape[0]*trimap.shape[1]-float((trimap==255).sum())-float((trimap==0).sum())
|
620 |
+
mse_diff = np.sum((predict - alpha) ** 2)/pixel
|
621 |
+
mad_diff = np.sum(np.abs(predict - alpha))/pixel
|
622 |
+
return sad_diff, mse_diff, mad_diff
|
623 |
+
|
624 |
+
def calculate_sad_mse_mad_whole_img(predict, alpha):
|
625 |
+
pixel = predict.shape[0]*predict.shape[1]
|
626 |
+
sad_diff = np.sum(np.abs(predict - alpha))/1000
|
627 |
+
mse_diff = np.sum((predict - alpha) ** 2)/pixel
|
628 |
+
mad_diff = np.sum(np.abs(predict - alpha))/pixel
|
629 |
+
return sad_diff, mse_diff, mad_diff
|
630 |
+
|
631 |
+
def calculate_sad_fgbg(predict, alpha, trimap):
|
632 |
+
sad_diff = np.abs(predict-alpha)
|
633 |
+
weight_fg = np.zeros(predict.shape)
|
634 |
+
weight_bg = np.zeros(predict.shape)
|
635 |
+
weight_trimap = np.zeros(predict.shape)
|
636 |
+
weight_fg[trimap==255] = 1.
|
637 |
+
weight_bg[trimap==0 ] = 1.
|
638 |
+
weight_trimap[trimap==128 ] = 1.
|
639 |
+
sad_fg = np.sum(sad_diff*weight_fg)/1000
|
640 |
+
sad_bg = np.sum(sad_diff*weight_bg)/1000
|
641 |
+
sad_trimap = np.sum(sad_diff*weight_trimap)/1000
|
642 |
+
return sad_fg, sad_bg
|
643 |
+
|
644 |
+
def compute_gradient_whole_image(pd, gt):
|
645 |
+
from scipy.ndimage import gaussian_filter
|
646 |
+
pd_x = gaussian_filter(pd, sigma=1.4, order=[1, 0], output=np.float32)
|
647 |
+
pd_y = gaussian_filter(pd, sigma=1.4, order=[0, 1], output=np.float32)
|
648 |
+
gt_x = gaussian_filter(gt, sigma=1.4, order=[1, 0], output=np.float32)
|
649 |
+
gt_y = gaussian_filter(gt, sigma=1.4, order=[0, 1], output=np.float32)
|
650 |
+
pd_mag = np.sqrt(pd_x**2 + pd_y**2)
|
651 |
+
gt_mag = np.sqrt(gt_x**2 + gt_y**2)
|
652 |
+
|
653 |
+
error_map = np.square(pd_mag - gt_mag)
|
654 |
+
loss = np.sum(error_map) / 10
|
655 |
+
return loss
|
656 |
+
|
657 |
+
def compute_connectivity_loss_whole_image(pd, gt, step=0.1):
|
658 |
+
|
659 |
+
from scipy.ndimage import morphology
|
660 |
+
from skimage.measure import label, regionprops
|
661 |
+
h, w = pd.shape
|
662 |
+
thresh_steps = np.arange(0, 1.1, step)
|
663 |
+
l_map = -1 * np.ones((h, w), dtype=np.float32)
|
664 |
+
lambda_map = np.ones((h, w), dtype=np.float32)
|
665 |
+
for i in range(1, thresh_steps.size):
|
666 |
+
pd_th = pd >= thresh_steps[i]
|
667 |
+
gt_th = gt >= thresh_steps[i]
|
668 |
+
label_image = label(pd_th & gt_th, connectivity=1)
|
669 |
+
cc = regionprops(label_image)
|
670 |
+
size_vec = np.array([c.area for c in cc])
|
671 |
+
if len(size_vec) == 0:
|
672 |
+
continue
|
673 |
+
max_id = np.argmax(size_vec)
|
674 |
+
coords = cc[max_id].coords
|
675 |
+
omega = np.zeros((h, w), dtype=np.float32)
|
676 |
+
omega[coords[:, 0], coords[:, 1]] = 1
|
677 |
+
flag = (l_map == -1) & (omega == 0)
|
678 |
+
l_map[flag == 1] = thresh_steps[i-1]
|
679 |
+
dist_maps = morphology.distance_transform_edt(omega==0)
|
680 |
+
dist_maps = dist_maps / dist_maps.max()
|
681 |
+
l_map[l_map == -1] = 1
|
682 |
+
d_pd = pd - l_map
|
683 |
+
d_gt = gt - l_map
|
684 |
+
phi_pd = 1 - d_pd * (d_pd >= 0.15).astype(np.float32)
|
685 |
+
phi_gt = 1 - d_gt * (d_gt >= 0.15).astype(np.float32)
|
686 |
+
loss = np.sum(np.abs(phi_pd - phi_gt)) / 1000
|
687 |
+
return loss
|
688 |
+
|
689 |
+
|
690 |
+
|
691 |
+
def gen_trimap_from_segmap_e2e(segmap):
|
692 |
+
trimap = np.argmax(segmap, axis=1)[0]
|
693 |
+
trimap = trimap.astype(np.int64)
|
694 |
+
trimap[trimap==1]=128
|
695 |
+
trimap[trimap==2]=255
|
696 |
+
return trimap.astype(np.uint8)
|
697 |
+
|
698 |
+
def get_masked_local_from_global(global_sigmoid, local_sigmoid):
|
699 |
+
values, index = torch.max(global_sigmoid,1)
|
700 |
+
index = index[:,None,:,:].float()
|
701 |
+
### index <===> [0, 1, 2]
|
702 |
+
### bg_mask <===> [1, 0, 0]
|
703 |
+
bg_mask = index.clone()
|
704 |
+
bg_mask[bg_mask==2]=1
|
705 |
+
bg_mask = 1- bg_mask
|
706 |
+
### trimap_mask <===> [0, 1, 0]
|
707 |
+
trimap_mask = index.clone()
|
708 |
+
trimap_mask[trimap_mask==2]=0
|
709 |
+
### fg_mask <===> [0, 0, 1]
|
710 |
+
fg_mask = index.clone()
|
711 |
+
fg_mask[fg_mask==1]=0
|
712 |
+
fg_mask[fg_mask==2]=1
|
713 |
+
fusion_sigmoid = local_sigmoid*trimap_mask+fg_mask
|
714 |
+
return fusion_sigmoid
|
715 |
+
|
716 |
+
def get_masked_local_from_global_test(global_result, local_result):
|
717 |
+
weighted_global = np.ones(global_result.shape)
|
718 |
+
weighted_global[global_result==255] = 0
|
719 |
+
weighted_global[global_result==0] = 0
|
720 |
+
fusion_result = global_result*(1.-weighted_global)/255+local_result*weighted_global
|
721 |
+
return fusion_result
|
722 |
+
def inference_once( model, scale_img, scale_trimap=None):
|
723 |
+
pred_list = []
|
724 |
+
tensor_img = torch.from_numpy(scale_img[:, :, :]).permute(2, 0, 1).cuda()
|
725 |
+
input_t = tensor_img
|
726 |
+
input_t = input_t/255.0
|
727 |
+
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
728 |
+
std=[0.229, 0.224, 0.225])
|
729 |
+
input_t = normalize(input_t)
|
730 |
+
input_t = input_t.unsqueeze(0).float()
|
731 |
+
# pred_global, pred_local, pred_fusion = model(input_t)[:3]
|
732 |
+
pred_fusion = model(input_t)[:3]
|
733 |
+
pred_global = pred_fusion
|
734 |
+
pred_local = pred_fusion
|
735 |
+
|
736 |
+
pred_global = pred_global.data.cpu().numpy()
|
737 |
+
pred_global = gen_trimap_from_segmap_e2e(pred_global)
|
738 |
+
pred_local = pred_local.data.cpu().numpy()[0,0,:,:]
|
739 |
+
pred_fusion = pred_fusion.data.cpu().numpy()[0,0,:,:]
|
740 |
+
return pred_global, pred_local, pred_fusion
|


# def inference_img(test_choice, model, img):
#     h, w, c = img.shape
#     new_h = min(config['datasets'].MAX_SIZE_H, h - (h % 32))
#     new_w = min(config['datasets'].MAX_SIZE_W, w - (w % 32))
#     if test_choice == 'HYBRID':
#         global_ratio = 1/2
#         local_ratio = 1
#         resize_h = int(h * global_ratio)
#         resize_w = int(w * global_ratio)
#         new_h = min(config['datasets'].MAX_SIZE_H, resize_h - (resize_h % 32))
#         new_w = min(config['datasets'].MAX_SIZE_W, resize_w - (resize_w % 32))
#         scale_img = resize(img, (new_h, new_w)) * 255.0
#         pred_coutour_1, pred_retouching_1, pred_fusion_1 = inference_once(model, scale_img)
#         pred_coutour_1 = resize(pred_coutour_1, (h, w)) * 255.0
#         resize_h = int(h * local_ratio)
#         resize_w = int(w * local_ratio)
#         new_h = min(config['datasets'].MAX_SIZE_H, resize_h - (resize_h % 32))
#         new_w = min(config['datasets'].MAX_SIZE_W, resize_w - (resize_w % 32))
#         scale_img = resize(img, (new_h, new_w)) * 255.0
#         pred_coutour_2, pred_retouching_2, pred_fusion_2 = inference_once(model, scale_img)
#         pred_retouching_2 = resize(pred_retouching_2, (h, w))
#         pred_fusion = get_masked_local_from_global_test(pred_coutour_1, pred_retouching_2)
#         return pred_fusion
#     else:
#         resize_h = int(h / 2)
#         resize_w = int(w / 2)
#         new_h = min(config['datasets'].MAX_SIZE_H, resize_h - (resize_h % 32))
#         new_w = min(config['datasets'].MAX_SIZE_W, resize_w - (resize_w % 32))
#         scale_img = resize(img, (new_h, new_w)) * 255.0
#         pred_global, pred_local, pred_fusion = inference_once(model, scale_img)
#         pred_local = resize(pred_local, (h, w))
#         pred_global = resize(pred_global, (h, w)) * 255.0
#         pred_fusion = resize(pred_fusion, (h, w))
#         return pred_fusion


def inference_img(model, img):
    h, w, _ = img.shape
    # Pad with reflection so both spatial sizes become multiples of 8 (the
    # stride the model appears to expect); padding is applied to the top/left.
    if h % 8 != 0 or w % 8 != 0:
        img = cv2.copyMakeBorder(img, 8 - h % 8, 0, 8 - w % 8, 0, cv2.BORDER_REFLECT)

    tensor_img = torch.from_numpy(img).permute(2, 0, 1).cuda()
    input_t = tensor_img / 255.0
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_t = normalize(input_t)
    input_t = input_t.unsqueeze(0).float()
    with torch.no_grad():
        out = model(input_t)
    # Crop the padding back off so the matte matches the input resolution.
    result = out[0][:, -h:, -w:].cpu().numpy()

    return result[0]
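
# A minimal usage sketch: `inference_img` takes an H x W x 3 uint8 RGB array
# and returns a single-channel float matte of shape (H, W). It assumes `model`
# has already been loaded onto the GPU, as in the __main__ block below; the
# file names here are placeholders, not assets shipped with this repo.
#
#   from PIL import Image
#   import numpy as np
#
#   rgb = np.array(Image.open("portrait.jpg").convert("RGB"))
#   alpha = inference_img(model, rgb)
#   Image.fromarray(np.uint8(alpha * 255)).save("portrait_alpha.png")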


def test_am2k(model):
    ############################
    # Some initial setting for paths
    ############################
    ORIGINAL_PATH = config['datasets']['am2k']['validation_original']
    MASK_PATH = config['datasets']['am2k']['validation_mask']
    TRIMAP_PATH = config['datasets']['am2k']['validation_trimap']
    img_paths = glob.glob(ORIGINAL_PATH + "/*.jpg")

    ############################
    # Start testing
    ############################
    sad_diffs = 0.
    mse_diffs = 0.
    mad_diffs = 0.
    grad_diffs = 0.
    conn_diffs = 0.
    sad_trimap_diffs = 0.
    mse_trimap_diffs = 0.
    mad_trimap_diffs = 0.
    sad_fg_diffs = 0.
    sad_bg_diffs = 0.

    total_number = len(img_paths)
    log("===============================")
    log(f'====> Start Testing\n\t--Dataset: AM2k\n\t--Number: {total_number}')

    for img_path in tqdm.tqdm(img_paths):
        img_name = (img_path.split("/")[-1])[:-4]
        alpha_path = MASK_PATH + img_name + '.png'
        trimap_path = TRIMAP_PATH + img_name + '.png'
        pil_img = Image.open(img_path)
        img = np.array(pil_img)
        trimap = np.array(Image.open(trimap_path))
        alpha = np.array(Image.open(alpha_path)) / 255.
        # Drop extra channels so image, trimap and alpha are consistent.
        img = img[:, :, :3] if img.ndim > 2 else img
        trimap = trimap[:, :, 0] if trimap.ndim > 2 else trimap
        alpha = alpha[:, :, 0] if alpha.ndim > 2 else alpha

        with torch.no_grad():
            torch.cuda.empty_cache()
            predict = inference_img(model, img)

        sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = calculate_sad_mse_mad(predict, alpha, trimap)
        sad_diff, mse_diff, mad_diff = calculate_sad_mse_mad_whole_img(predict, alpha)
        sad_fg_diff, sad_bg_diff = calculate_sad_fgbg(predict, alpha, trimap)
        conn_diff = compute_connectivity_loss_whole_image(predict, alpha)
        grad_diff = compute_gradient_whole_image(predict, alpha)

        log(f"[{img_paths.index(img_path)}/{total_number}]\nImage:{img_name}\nsad:{sad_diff}\nmse:{mse_diff}\nmad:{mad_diff}\nsad_trimap:{sad_trimap_diff}\nmse_trimap:{mse_trimap_diff}\nmad_trimap:{mad_trimap_diff}\nsad_fg:{sad_fg_diff}\nsad_bg:{sad_bg_diff}\nconn:{conn_diff}\ngrad:{grad_diff}\n-----------")

        sad_diffs += sad_diff
        mse_diffs += mse_diff
        mad_diffs += mad_diff
        mse_trimap_diffs += mse_trimap_diff
        sad_trimap_diffs += sad_trimap_diff
        mad_trimap_diffs += mad_trimap_diff
        sad_fg_diffs += sad_fg_diff
        sad_bg_diffs += sad_bg_diff
        conn_diffs += conn_diff
        grad_diffs += grad_diff
        Image.fromarray(np.uint8(predict * 255)).save(f"test/{img_name}.png")

    log("===============================")
    log(f"Testing numbers: {total_number}")

    log("SAD: {}".format(sad_diffs / total_number))
    log("MSE: {}".format(mse_diffs / total_number))
    log("MAD: {}".format(mad_diffs / total_number))
    log("GRAD: {}".format(grad_diffs / total_number))
    log("CONN: {}".format(conn_diffs / total_number))
    log("SAD TRIMAP: {}".format(sad_trimap_diffs / total_number))
    log("MSE TRIMAP: {}".format(mse_trimap_diffs / total_number))
    log("MAD TRIMAP: {}".format(mad_trimap_diffs / total_number))
    log("SAD FG: {}".format(sad_fg_diffs / total_number))
    log("SAD BG: {}".format(sad_bg_diffs / total_number))
    return sad_diffs / total_number, mse_diffs / total_number, grad_diffs / total_number


def test_p3m10k(model, dataset_choice, max_image=-1):
    ############################
    # Some initial setting for paths
    ############################
    if dataset_choice == 'P3M_500_P':
        val_option = 'VAL500P'
    else:
        val_option = 'VAL500NP'
    ORIGINAL_PATH = config['datasets']['p3m10k'] + "/validation/" + config['datasets']['p3m10k_test'][val_option]['ORIGINAL_PATH']
    MASK_PATH = config['datasets']['p3m10k'] + "/validation/" + config['datasets']['p3m10k_test'][val_option]['MASK_PATH']
    TRIMAP_PATH = config['datasets']['p3m10k'] + "/validation/" + config['datasets']['p3m10k_test'][val_option]['TRIMAP_PATH']
    ############################
    # Start testing
    ############################
    sad_diffs = 0.
    mse_diffs = 0.
    mad_diffs = 0.
    sad_trimap_diffs = 0.
    mse_trimap_diffs = 0.
    mad_trimap_diffs = 0.
    sad_fg_diffs = 0.
    sad_bg_diffs = 0.
    conn_diffs = 0.
    grad_diffs = 0.
    model.eval()
    img_paths = glob.glob(ORIGINAL_PATH + "/*.jpg")
    if max_image > 1:
        img_paths = img_paths[:max_image]
    total_number = len(img_paths)
    log("===============================")
    log(f'====> Start Testing\n\t----Test: {dataset_choice}\n\t--Number: {total_number}')

    for img_path in tqdm.tqdm(img_paths):
        img_name = (img_path.split("/")[-1])[:-4]
        alpha_path = MASK_PATH + img_name + '.png'
        trimap_path = TRIMAP_PATH + img_name + '.png'
        pil_img = Image.open(img_path)
        img = np.array(pil_img)

        trimap = np.array(Image.open(trimap_path))
        alpha = np.array(Image.open(alpha_path)) / 255.
        img = img[:, :, :3] if img.ndim > 2 else img
        trimap = trimap[:, :, 0] if trimap.ndim > 2 else trimap
        alpha = alpha[:, :, 0] if alpha.ndim > 2 else alpha
        with torch.no_grad():
            torch.cuda.empty_cache()
            start = time.time()

            predict = inference_img(model, img)  # HYBRID showed less accuracy

        # tensorimg = transforms.ToTensor()(pil_img)
        # input_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(tensorimg)
        # predict = model(input_img.unsqueeze(0).to(device))[0][0].detach().cpu().numpy()
        # if predict.shape != (pil_img.height, pil_img.width):
        #     print("resize for ", img_path)
        #     predict = resize(predict, (pil_img.height, pil_img.width))
        sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = calculate_sad_mse_mad(predict, alpha, trimap)
        sad_diff, mse_diff, mad_diff = calculate_sad_mse_mad_whole_img(predict, alpha)

        sad_fg_diff, sad_bg_diff = calculate_sad_fgbg(predict, alpha, trimap)
        conn_diff = compute_connectivity_loss_whole_image(predict, alpha)
        grad_diff = compute_gradient_whole_image(predict, alpha)
        log(f"[{img_paths.index(img_path)}/{total_number}]\nImage:{img_name}\nsad:{sad_diff}\nmse:{mse_diff}\nmad:{mad_diff}\nconn:{conn_diff}\ngrad:{grad_diff}\n-----------")
        sad_diffs += sad_diff
        mse_diffs += mse_diff
        mad_diffs += mad_diff
        mse_trimap_diffs += mse_trimap_diff
        sad_trimap_diffs += sad_trimap_diff
        mad_trimap_diffs += mad_trimap_diff
        sad_fg_diffs += sad_fg_diff
        sad_bg_diffs += sad_bg_diff
        conn_diffs += conn_diff
        grad_diffs += grad_diff

        Image.fromarray(np.uint8(predict * 255)).save(f"test/{img_name}.png")

    log("===============================")
    log(f"Testing numbers: {total_number}")
    log("SAD: {}".format(sad_diffs / total_number))
    log("MSE: {}".format(mse_diffs / total_number))
    log("MAD: {}".format(mad_diffs / total_number))
    log("SAD TRIMAP: {}".format(sad_trimap_diffs / total_number))
    log("MSE TRIMAP: {}".format(mse_trimap_diffs / total_number))
    log("MAD TRIMAP: {}".format(mad_trimap_diffs / total_number))
    log("SAD FG: {}".format(sad_fg_diffs / total_number))
    log("SAD BG: {}".format(sad_bg_diffs / total_number))
    log("CONN: {}".format(conn_diffs / total_number))
    log("GRAD: {}".format(grad_diffs / total_number))

    return sad_diffs / total_number, mse_diffs / total_number, grad_diffs / total_number


def log(msg):
    # Echo to stdout and to the report file configured in __main__.
    print(msg)
    logging.info(msg)


if __name__ == '__main__':
    print('*********************************')
    # Load the base config and allow command-line overrides via OmegaConf.
    config = OmegaConf.load(os.path.join(os.path.dirname(
        os.path.abspath(__file__)), "config/base.yaml"))
    config = OmegaConf.merge(config, OmegaConf.from_cli())
    print(config)
    model = MaskForm()
    model = model.to(device)
    checkpoint = f"{config.checkpoint_dir}/{config.checkpoint}"
    state_dict = torch.load(checkpoint, map_location=f'{device}')
    print("loaded", checkpoint)
    model.load_state_dict(state_dict)
    model.eval()
    logging.basicConfig(filename=f'report/{config.checkpoint.replace("/", "--")}.report',
                        encoding='utf-8', filemode='w', level=logging.INFO)
    # ckpt = torch.load("checkpoints/p3mnet_pretrained_on_p3m10k.pth")
    # model.load_state_dict(ckpt['state_dict'], strict=True)
    # model = model.cuda()
    if config.dataset_to_use == "AM2K":
        test_am2k(model)
    else:
        for dataset_choice in ['P3M_500_P', 'P3M_500_NP']:
            test_p3m10k(model, dataset_choice)
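
# A sketch of one way this script might be invoked, assuming it is saved as
# test.py and that checkpoint_dir, checkpoint and dataset_to_use are defined
# in config/base.yaml (the checkpoint name below is only a placeholder):
#
#   python test.py checkpoint_dir=checkpoints checkpoint=stylematte_synth.pth dataset_to_use=AM2K
#
# OmegaConf.from_cli() parses each key=value argument and merges it on top of
# the values loaded from config/base.yaml.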