zhaoyian01 committed on
Commit 6d1366a
1 Parent(s): d9b7cbf

Add application file

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +10 -0
  2. isegm/__init__.py +0 -0
  3. isegm/inference/clicker.py +118 -0
  4. isegm/inference/evaluation.py +197 -0
  5. isegm/inference/predictors/__init__.py +99 -0
  6. isegm/inference/predictors/base.py +191 -0
  7. isegm/inference/predictors/brs.py +307 -0
  8. isegm/inference/predictors/brs_functors.py +109 -0
  9. isegm/inference/predictors/brs_losses.py +58 -0
  10. isegm/inference/transforms/__init__.py +5 -0
  11. isegm/inference/transforms/base.py +38 -0
  12. isegm/inference/transforms/crops.py +97 -0
  13. isegm/inference/transforms/flip.py +37 -0
  14. isegm/inference/transforms/limit_longest_side.py +22 -0
  15. isegm/inference/transforms/zoom_in.py +190 -0
  16. isegm/inference/utils.py +149 -0
  17. isegm/model/__init__.py +0 -0
  18. isegm/model/build_sam.py +145 -0
  19. isegm/model/initializer.py +105 -0
  20. isegm/model/is_deeplab_model.py +25 -0
  21. isegm/model/is_hrformer_model.py +41 -0
  22. isegm/model/is_hrnet_model.py +26 -0
  23. isegm/model/is_model.py +114 -0
  24. isegm/model/is_plainvit_model.py +95 -0
  25. isegm/model/is_plainvit_model_lora.py +95 -0
  26. isegm/model/is_segformer_model.py +29 -0
  27. isegm/model/is_swinformer_model.py +21 -0
  28. isegm/model/is_text_graco_model.py +63 -0
  29. isegm/model/losses.py +195 -0
  30. isegm/model/metrics.py +101 -0
  31. isegm/model/modeling/__init__.py +0 -0
  32. isegm/model/modeling/basic_blocks.py +71 -0
  33. isegm/model/modeling/clip/__init__.py +1 -0
  34. isegm/model/modeling/clip/clip.py +245 -0
  35. isegm/model/modeling/clip/model.py +436 -0
  36. isegm/model/modeling/clip/simple_tokenizer.py +132 -0
  37. isegm/model/modeling/clip_text_encoding.py +29 -0
  38. isegm/model/modeling/deeplab_v3.py +176 -0
  39. isegm/model/modeling/hrformer.py +487 -0
  40. isegm/model/modeling/hrformer_helper/__init__.py +0 -0
  41. isegm/model/modeling/hrformer_helper/backbone_selector.py +54 -0
  42. isegm/model/modeling/hrformer_helper/hrt/__init__.py +0 -0
  43. isegm/model/modeling/hrformer_helper/hrt/hrt_backbone.py +661 -0
  44. isegm/model/modeling/hrformer_helper/hrt/hrt_config.py +123 -0
  45. isegm/model/modeling/hrformer_helper/hrt/logger.py +205 -0
  46. isegm/model/modeling/hrformer_helper/hrt/module_helper.py +310 -0
  47. isegm/model/modeling/hrformer_helper/hrt/modules/__init__.py +0 -0
  48. isegm/model/modeling/hrformer_helper/hrt/modules/bottleneck_block.py +128 -0
  49. isegm/model/modeling/hrformer_helper/hrt/modules/ffn_block.py +287 -0
  50. isegm/model/modeling/hrformer_helper/hrt/modules/multihead_attention.py +342 -0
app.py ADDED
@@ -0,0 +1,10 @@
from web_app import GraCoWebApplication


def main():
    app = GraCoWebApplication()
    app.launch()


if __name__ == '__main__':
    main()
isegm/__init__.py ADDED
File without changes
isegm/inference/clicker.py ADDED
@@ -0,0 +1,118 @@
import numpy as np
from copy import deepcopy
import cv2


class Clicker(object):
    def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0):
        self.click_indx_offset = click_indx_offset
        if gt_mask is not None:
            self.gt_mask = gt_mask == 1
            self.not_ignore_mask = gt_mask != ignore_label
        else:
            self.gt_mask = None

        self.reset_clicks()

        if init_clicks is not None:
            for click in init_clicks:
                self.add_click(click)

    def make_next_click(self, pred_mask):
        assert self.gt_mask is not None
        click = self._get_next_click(pred_mask)
        self.add_click(click)

    def get_clicks(self, clicks_limit=None):
        return self.clicks_list[:clicks_limit]

    def _get_next_click(self, pred_mask, padding=True):
        fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask)
        fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask)

        if padding:
            fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant')
            fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant')

        fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
        fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)

        if padding:
            fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
            fp_mask_dt = fp_mask_dt[1:-1, 1:-1]

        fn_mask_dt = fn_mask_dt * self.not_clicked_map
        fp_mask_dt = fp_mask_dt * self.not_clicked_map

        fn_max_dist = np.max(fn_mask_dt)
        fp_max_dist = np.max(fp_mask_dt)

        is_positive = fn_max_dist > fp_max_dist
        if is_positive:
            coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist)  # coords is [y, x]
        else:
            coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist)  # coords is [y, x]

        return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))

    def add_click(self, click):
        coords = click.coords

        click.indx = self.click_indx_offset + self.num_pos_clicks + self.num_neg_clicks
        if click.is_positive:
            self.num_pos_clicks += 1
        else:
            self.num_neg_clicks += 1

        self.clicks_list.append(click)
        if self.gt_mask is not None:
            self.not_clicked_map[coords[0], coords[1]] = False

    def _remove_last_click(self):
        click = self.clicks_list.pop()
        coords = click.coords

        if click.is_positive:
            self.num_pos_clicks -= 1
        else:
            self.num_neg_clicks -= 1

        if self.gt_mask is not None:
            self.not_clicked_map[coords[0], coords[1]] = True

    def reset_clicks(self):
        if self.gt_mask is not None:
            self.not_clicked_map = np.ones_like(self.gt_mask, dtype=bool)

        self.num_pos_clicks = 0
        self.num_neg_clicks = 0

        self.clicks_list = []

    def get_state(self):
        return deepcopy(self.clicks_list)

    def set_state(self, state):
        self.reset_clicks()
        for click in state:
            self.add_click(click)

    def __len__(self):
        return len(self.clicks_list)


class Click:
    def __init__(self, is_positive, coords, indx=None):
        self.is_positive = is_positive
        self.coords = coords
        self.indx = indx

    @property
    def coords_and_indx(self):
        return (*self.coords, self.indx)

    def copy(self, **kwargs):
        self_copy = deepcopy(self)
        for k, v in kwargs.items():
            setattr(self_copy, k, v)
        return self_copy
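The Clicker above simulates user input: given a ground-truth mask and the current prediction, _get_next_click places the next click at the most central point of the largest error region. A minimal sketch of that loop on toy data (illustrative only, not part of the commit):

import numpy as np
from isegm.inference.clicker import Clicker

# Toy ground truth: a 32x32 square object inside a 64x64 mask.
gt_mask = np.zeros((64, 64), dtype=np.int32)
gt_mask[16:48, 16:48] = 1

clicker = Clicker(gt_mask=gt_mask)
# With an empty prediction, the whole object is a false negative,
# so the first generated click is positive and lands inside the square.
clicker.make_next_click(pred_mask=np.zeros_like(gt_mask, dtype=bool))

click = clicker.get_clicks()[-1]
print(click.is_positive, click.coords)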
isegm/inference/evaluation.py ADDED
@@ -0,0 +1,197 @@
from time import time

import numpy as np
import torch
import cv2
from isegm.inference import utils
from isegm.inference.clicker import Click, Clicker

try:
    get_ipython()
    from tqdm import tqdm_notebook as tqdm
except NameError:
    from tqdm import tqdm


def evaluate_dataset(dataset, predictor, sam_type=None, oracle=False, gra_oracle=False, **kwargs):
    all_ious = []
    start_time = time()
    all_gras = {}

    for index in tqdm(range(len(dataset)), leave=False):
        sample = dataset.get_sample(index)

        for object_id in sample.objects_ids:
            if gra_oracle:
                sample_ious, gra_idx = evaluate_sample_oracle(sample.image, sample.gt_mask(object_id), predictor,
                                                              sample_id=index, sam_type=sam_type, oracle=oracle, **kwargs)
                all_gras[gra_idx] = all_gras.get(gra_idx, 0) + 1
            else:
                _, sample_ious, _ = evaluate_sample(sample.image, sample.gt_mask(object_id), predictor,
                                                    sample_id=index, sam_type=sam_type, oracle=oracle, **kwargs)
            all_ious.append(sample_ious)
    end_time = time()
    elapsed_time = end_time - start_time
    if len(all_gras) > 0:
        print(all_gras)

    return all_ious, elapsed_time


def evaluate_sample(image, gt_mask, predictor, max_iou_thr,
                    pred_thr=0.49, min_clicks=1, max_clicks=20,
                    sample_id=None, sam_type=False, oracle=False, callback=None):
    clicker = Clicker(gt_mask=gt_mask)
    pred_mask = np.zeros_like(gt_mask)
    ious_list = []
    with torch.no_grad():
        predictor.set_input_image(image)
        if sam_type == 'SAM':
            for click_indx in range(max_clicks):
                clicker.make_next_click(pred_mask)
                point_coords, point_labels = get_sam_input(clicker)
                if oracle:
                    ious = []
                    pred_masks = []
                    pred_probs, _, _ = predictor.predict(point_coords, point_labels, multimask_output=True, return_logits=True)
                    for idx in range(pred_probs.shape[0]):
                        pred_masks.append(pred_probs[idx] > predictor.model.mask_threshold)
                        ious.append(utils.get_iou(gt_mask, pred_masks[-1]))
                    tgt_idx = np.argmax(np.array(ious))
                    iou = ious[tgt_idx]
                    pred_mask = pred_masks[tgt_idx]
                else:
                    pred_probs, _, _ = predictor.predict(point_coords, point_labels, multimask_output=False, return_logits=True)
                    pred_probs = pred_probs[0]
                    pred_mask = pred_probs > predictor.model.mask_threshold
                    iou = utils.get_iou(gt_mask, pred_mask)

                if callback is not None:
                    callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list)

                ious_list.append(iou)
                if iou >= max_iou_thr and click_indx + 1 >= min_clicks:
                    break
            return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs
        else:
            for click_indx in range(max_clicks):
                clicker.make_next_click(pred_mask)
                pred_probs = predictor.get_prediction(clicker)
                pred_mask = pred_probs > pred_thr
                iou = utils.get_iou(gt_mask, pred_mask)

                if callback is not None:
                    callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list)

                ious_list.append(iou)
                if iou >= max_iou_thr and click_indx + 1 >= min_clicks:
                    break
            return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs


def evaluate_sample_oracle(image, gt_mask, predictor, max_iou_thr,
                           pred_thr=0.49, min_clicks=1, max_clicks=20,
                           sample_id=None, sam_type=False, oracle=False, callback=None):
    clicker = Clicker(gt_mask=gt_mask)
    ious_lists = []
    click_indxs = []
    with torch.no_grad():
        predictor.set_input_image(image)
        min_num = 100
        for gra in range(1, 11):
            cur_gra = round(gra * 0.1, 1)
            ious_list = []
            clicker.reset_clicks()
            pred_mask = np.zeros_like(gt_mask)
            if sam_type == 'SAM_GraCo':
                for click_indx in range(max_clicks):
                    clicker.make_next_click(pred_mask)
                    point_coords, point_labels = get_sam_input(clicker)
                    if oracle:
                        ious = []
                        pred_masks = []
                        pred_probs, _, _ = predictor.predict(point_coords, point_labels, gra=cur_gra, multimask_output=True, return_logits=True)
                        for idx in range(pred_probs.shape[0]):
                            pred_masks.append(pred_probs[idx] > predictor.model.mask_threshold)
                            ious.append(utils.get_iou(gt_mask, pred_masks[-1]))
                        tgt_idx = np.argmax(np.array(ious))
                        iou = ious[tgt_idx]
                        pred_mask = pred_masks[tgt_idx]
                    else:
                        pred_probs, _, _ = predictor.predict(point_coords, point_labels, gra=cur_gra, multimask_output=False, return_logits=True)
                        pred_probs = pred_probs[0]
                        pred_mask = pred_probs > predictor.model.mask_threshold
                        iou = utils.get_iou(gt_mask, pred_mask)

                    if callback is not None:
                        callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list)

                    ious_list.append(iou)
                    if iou >= max_iou_thr and click_indx + 1 >= min_clicks:
                        min_num = min(min_num, click_indx + 1)
                        break
                    if min_num <= max_clicks and click_indx + 1 > min_num:
                        break
            else:
                predictor.prev_prediction = torch.zeros_like(predictor.original_image[:, :1, :, :])
                for click_indx in range(max_clicks):
                    clicker.make_next_click(pred_mask)
                    pred_probs = predictor.get_prediction(clicker, gra=cur_gra)

                    pred_mask = pred_probs > pred_thr
                    iou = utils.get_iou(gt_mask, pred_mask)

                    if callback is not None:
                        callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list)

                    ious_list.append(iou)
                    if iou >= max_iou_thr and click_indx + 1 >= min_clicks:
                        min_num = min(min_num, click_indx + 1)
                        break
                    if min_num <= max_clicks and click_indx + 1 > min_num:
                        break
            ious_lists.append(np.array(ious_list, dtype=np.float32))
            click_indxs.append(click_indx)
    click_indxs = np.array(click_indxs)
    tgt_idxs = np.squeeze(np.argwhere(click_indxs == np.min(click_indxs)), axis=1)
    selected_ious = [ious_lists[i] for i in tgt_idxs]
    max_index = np.argmax([ious[0] for ious in selected_ious])
    ious = selected_ious[max_index]
    tgt_idx = tgt_idxs[max_index]

    return ious, tgt_idx


def get_sam_input(clicker, reverse=True):
    clicks_list = clicker.get_clicks()
    points_nd = get_points_nd([clicks_list])
    point_length = len(points_nd[0]) // 2
    point_coords = []
    point_labels = []
    for i, point in enumerate(points_nd[0]):
        if point[0] == -1:
            continue
        if i < point_length:
            point_labels.append(1)
        else:
            point_labels.append(0)
        if reverse:
            point_coords.append([point[1], point[0]])  # for SAM
    return np.array(point_coords), np.array(point_labels)


def get_points_nd(clicks_lists):
    total_clicks = []
    num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
    num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
    num_max_points = max(num_pos_clicks + num_neg_clicks)
    num_max_points = max(1, num_max_points)

    for clicks_list in clicks_lists:
        pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
        pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]

        neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
        neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
        total_clicks.append(pos_clicks + neg_clicks)

    return total_clicks
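get_sam_input and get_points_nd above repack the Clicker state into the point coordinates and 1/0 labels that a SAM-style predictor consumes. A small sketch with hand-made clicks (assuming the isegm package from this commit is on the path):

from isegm.inference.clicker import Click, Clicker
from isegm.inference.evaluation import get_sam_input

# One positive and one negative click created by hand (toy coordinates).
clicker = Clicker(init_clicks=[Click(is_positive=True, coords=(10, 20)),
                               Click(is_positive=False, coords=(30, 40))])
point_coords, point_labels = get_sam_input(clicker)
print(point_coords)  # [[20 10], [40 30]] -- (y, x) clicks reversed to (x, y) for SAM
print(point_labels)  # [1 0] -- 1 for positive clicks, 0 for negative clicks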
isegm/inference/predictors/__init__.py ADDED
@@ -0,0 +1,99 @@
from .base import BasePredictor
from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor
from .brs_functors import InputOptimizer, ScaleBiasOptimizer
from isegm.inference.transforms import ZoomIn
from isegm.model.is_hrnet_model import HRNetModel


def get_predictor(net, brs_mode, device,
                  gra=None, sam_type=None,
                  prob_thresh=0.49,
                  with_flip=True,
                  zoom_in_params=dict(),
                  predictor_params=None,
                  brs_opt_func_params=None,
                  lbfgs_params=None):
    lbfgs_params_ = {
        'm': 20,
        'factr': 0,
        'pgtol': 1e-8,
        'maxfun': 20,
    }

    predictor_params_ = {
        'optimize_after_n_clicks': 1
    }

    if zoom_in_params is not None:
        zoom_in = ZoomIn(**zoom_in_params)
    else:
        zoom_in = None

    if lbfgs_params is not None:
        lbfgs_params_.update(lbfgs_params)
    lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun']

    if brs_opt_func_params is None:
        brs_opt_func_params = dict()

    if isinstance(net, (list, tuple)):
        assert brs_mode == 'NoBRS', "Multi-stage models support only NoBRS mode."

    if brs_mode == 'NoBRS':
        if predictor_params is not None:
            predictor_params_.update(predictor_params)
        predictor = BasePredictor(net, device, gra=gra, sam_type=sam_type, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_)
    elif brs_mode.startswith('f-BRS'):
        predictor_params_.update({
            'net_clicks_limit': 8,
        })
        if predictor_params is not None:
            predictor_params_.update(predictor_params)

        insertion_mode = {
            'f-BRS-A': 'after_c4',
            'f-BRS-B': 'after_aspp',
            'f-BRS-C': 'after_deeplab'
        }[brs_mode]

        opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh,
                                         with_flip=with_flip,
                                         optimizer_params=lbfgs_params_,
                                         **brs_opt_func_params)

        if isinstance(net, HRNetModel):
            FeaturePredictor = HRNetFeatureBRSPredictor
            insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode]
        else:
            FeaturePredictor = FeatureBRSPredictor

        predictor = FeaturePredictor(net, device,
                                     opt_functor=opt_functor,
                                     with_flip=with_flip,
                                     insertion_mode=insertion_mode,
                                     zoom_in=zoom_in,
                                     **predictor_params_)
    elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS':
        use_dmaps = brs_mode == 'DistMap-BRS'

        predictor_params_.update({
            'net_clicks_limit': 5,
        })
        if predictor_params is not None:
            predictor_params_.update(predictor_params)

        opt_functor = InputOptimizer(prob_thresh=prob_thresh,
                                     with_flip=with_flip,
                                     optimizer_params=lbfgs_params_,
                                     **brs_opt_func_params)

        predictor = InputBRSPredictor(net, device,
                                      optimize_target='dmaps' if use_dmaps else 'rgb',
                                      opt_functor=opt_functor,
                                      with_flip=with_flip,
                                      zoom_in=zoom_in,
                                      **predictor_params_)
    else:
        raise NotImplementedError

    return predictor
isegm/inference/predictors/base.py ADDED
@@ -0,0 +1,191 @@
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from isegm.inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide


class BasePredictor(object):
    def __init__(self, model, device, gra=None, sam_type=None,
                 net_clicks_limit=None,
                 with_flip=False,
                 zoom_in=None,
                 max_size=None,
                 **kwargs):
        self.with_flip = with_flip
        self.net_clicks_limit = net_clicks_limit
        self.original_image = None
        self.device = device
        self.gra = gra if gra is not None and gra > 0 else None
        self.sam_type = sam_type
        self.zoom_in = zoom_in
        self.prev_prediction = None
        self.model_indx = 0
        self.click_models = None
        self.net_state_dict = None

        if isinstance(model, tuple):
            self.net, self.click_models = model
        else:
            self.net = model

        self.to_tensor = transforms.ToTensor()

        self.transforms = [zoom_in] if zoom_in is not None else []
        if max_size is not None:
            self.transforms.append(LimitLongestSide(max_size=max_size))
        self.transforms.append(SigmoidForPred())
        if with_flip:
            self.transforms.append(AddHorizontalFlip())

    def set_input_image(self, image):
        if not isinstance(image, torch.Tensor):
            image_nd = self.to_tensor(image)
        else:
            image_nd = image
        for transform in self.transforms:
            transform.reset()
        self.original_image = image_nd.to(self.device)
        if len(self.original_image.shape) == 3:
            self.original_image = self.original_image.unsqueeze(0)
        self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :])

    def get_prediction(self, clicker, prev_mask=None, gra=None):
        clicks_list = clicker.get_clicks()

        if self.click_models is not None:
            model_indx = min(clicker.click_indx_offset + len(clicks_list), len(self.click_models)) - 1
            if model_indx != self.model_indx:
                self.model_indx = model_indx
                self.net = self.click_models[model_indx]

        input_image = self.original_image
        if prev_mask is None:
            prev_mask = self.prev_prediction
        if (hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask) or self.sam_type is not None:
            input_image = torch.cat((input_image, prev_mask), dim=1)
        image_nd, clicks_lists, is_image_changed = self.apply_transforms(
            input_image, [clicks_list]
        )
        pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed, gra=gra)

        prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
                                   size=image_nd.size()[2:])

        for t in reversed(self.transforms):
            prediction = t.inv_transform(prediction)

        if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
            return self.get_prediction(clicker)

        self.prev_prediction = prediction
        return prediction.cpu().numpy()[0, 0]

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed, gra=None):
        points_nd = self.get_points_nd(clicks_lists)
        if gra is None:
            gra = self.gra
        if self.sam_type == 'SAM':
            batched_input = self.get_sam_batched_input(image_nd, points_nd)
            batched_output = self.net(batched_input, multimask_output=False, return_logits=True)
            return torch.cat([batch['masks'] for batch in batched_output], dim=0)

        if gra is not None:
            return self.net(image_nd, points_nd, torch.Tensor([gra]).to(self.device))['instances']
        else:
            return self.net(image_nd, points_nd)['instances']

    def _batch_infer(self, batch_image_tensor, batch_clickers, prev_mask=None):
        if prev_mask is None:
            prev_mask = self.prev_prediction

        if hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask:
            input_image = torch.cat((batch_image_tensor, prev_mask), dim=1)

        clicks_lists = [clicker.get_clicks() for clicker in batch_clickers]
        image_nd, clicks_lists, is_image_changed = self.apply_transforms(
            input_image, clicks_lists
        )
        points_nd = self.get_points_nd(clicks_lists)
        pred_logits = self.net(image_nd, points_nd)['instances']
        prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
                                   size=image_nd.size()[2:])

        for t in reversed(self.transforms):
            prediction = t.inv_transform(prediction)

        self.prev_prediction = prediction
        return prediction.cpu().numpy()[:, 0]

    def _get_transform_states(self):
        return [x.get_state() for x in self.transforms]

    def _set_transform_states(self, states):
        assert len(states) == len(self.transforms)
        for state, transform in zip(states, self.transforms):
            transform.set_state(state)

    def apply_transforms(self, image_nd, clicks_lists):
        is_image_changed = False
        for t in self.transforms:
            image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
            is_image_changed |= t.image_changed

        return image_nd, clicks_lists, is_image_changed

    def get_points_nd(self, clicks_lists):
        total_clicks = []
        num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
        num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
        num_max_points = max(num_pos_clicks + num_neg_clicks)
        if self.net_clicks_limit is not None:
            num_max_points = min(self.net_clicks_limit, num_max_points)
        num_max_points = max(1, num_max_points)

        for clicks_list in clicks_lists:
            clicks_list = clicks_list[:self.net_clicks_limit]
            pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
            pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]

            neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
            neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
            total_clicks.append(pos_clicks + neg_clicks)

        return torch.tensor(total_clicks, device=self.device)

    def get_sam_batched_input(self, image_nd, points_nd):
        batched_output = []
        for i in range(image_nd.shape[0]):
            image = image_nd[i]
            point_length = points_nd[i].shape[0] // 2
            point_coords = []
            point_labels = []
            for i, point in enumerate(points_nd[i]):
                point_np = point.cpu().numpy()
                if point_np[0] == -1:
                    continue
                if i < point_length:
                    point_labels.append(1)
                else:
                    point_labels.append(0)

                point_coords.append([point_np[1], point_np[0]])
            res = {
                'image': image[:3, :, :],
                'point_coords': torch.as_tensor(np.array(point_coords), dtype=torch.float, device=self.device)[None, :],
                'point_labels': torch.as_tensor(np.array(point_labels), dtype=torch.float, device=self.device)[None, :],
                'original_size': image.cpu().numpy().shape[1:],
                'mask_inputs': image[3, :, :][None, None, :]
            }
            batched_output.append(res)
        return batched_output

    def get_states(self):
        return {
            'transform_states': self._get_transform_states(),
            'prev_prediction': self.prev_prediction.clone()
        }

    def set_states(self, states):
        self._set_transform_states(states['transform_states'])
        self.prev_prediction = states['prev_prediction']
isegm/inference/predictors/brs.py ADDED
@@ -0,0 +1,307 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy.optimize import fmin_l_bfgs_b

from .base import BasePredictor


class BRSBasePredictor(BasePredictor):
    def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs):
        super().__init__(model, device, **kwargs)
        self.optimize_after_n_clicks = optimize_after_n_clicks
        self.opt_functor = opt_functor

        self.opt_data = None
        self.input_data = None

    def set_input_image(self, image):
        super().set_input_image(image)
        self.opt_data = None
        self.input_data = None

    def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
        pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
        neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)

        for list_indx, clicks_list in enumerate(clicks_lists):
            for click in clicks_list:
                y, x = click.coords
                y, x = int(round(y)), int(round(x))
                y1, x1 = y - radius, x - radius
                y2, x2 = y + radius + 1, x + radius + 1

                if click.is_positive:
                    pos_clicks_map[list_indx, 0, y1:y2, x1:x2] = True
                else:
                    neg_clicks_map[list_indx, 0, y1:y2, x1:x2] = True

        with torch.no_grad():
            pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device)
            neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device)

        return pos_clicks_map, neg_clicks_map

    def get_states(self):
        return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data}

    def set_states(self, states):
        self._set_transform_states(states['transform_states'])
        self.opt_data = states['opt_data']


class FeatureBRSPredictor(BRSBasePredictor):
    def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        self._c1_features = None

        if self.insertion_mode == 'after_deeplab':
            self.num_channels = model.feature_extractor.ch
        elif self.insertion_mode == 'after_c4':
            self.num_channels = model.feature_extractor.aspp_in_channels
        elif self.insertion_mode == 'after_aspp':
            self.num_channels = model.feature_extractor.ch + 32
        else:
            raise NotImplementedError

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])

        num_clicks = len(clicks_lists[0])
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            scale = scale.view(bs, -1, 1, 1)
            bias = bias.view(bs, -1, 1, 1)
            if self.with_flip:
                scale = scale.repeat(2, 1, 1, 1)
                bias = bias.repeat(2, 1, 1, 1)

            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == 'after_c4':
                x = self.net.feature_extractor.aspp(scaled_backbone_features)
                x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:],
                                  align_corners=True)
                x = torch.cat((x, self._c1_features), dim=1)
                scaled_backbone_features = self.net.feature_extractor.head(x)
            elif self.insertion_mode == 'after_aspp':
                scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features)

            pred_logits = self.net.head(scaled_backbone_features)
            pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
                                        align_corners=True)
            return pred_logits

        self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
                                       **self.opt_functor.optimizer_params)
            self.opt_data = opt_result[0]

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
                opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits

    def _get_head_input(self, image_nd, points):
        with torch.no_grad():
            image_nd, prev_mask = self.net.prepare_input(image_nd)
            coord_features = self.net.get_coord_features(image_nd, prev_mask, points)

            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
                additional_features = None
            elif hasattr(self.net, 'maps_transform'):
                x = image_nd
                additional_features = self.net.maps_transform(coord_features)

            if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp':
                c1, _, c3, c4 = self.net.feature_extractor.backbone(x, additional_features)
                c1 = self.net.feature_extractor.skip_project(c1)

                if self.insertion_mode == 'after_aspp':
                    x = self.net.feature_extractor.aspp(c4)
                    x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True)
                    x = torch.cat((x, c1), dim=1)
                    backbone_features = x
                else:
                    backbone_features = c4
                    self._c1_features = c1
            else:
                backbone_features = self.net.feature_extractor(x, additional_features)[0]

        return backbone_features


class HRNetFeatureBRSPredictor(BRSBasePredictor):
    def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        self._c1_features = None

        if self.insertion_mode == 'A':
            self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8])
        elif self.insertion_mode == 'C':
            self.num_channels = 2 * model.feature_extractor.ocr_width
        else:
            raise NotImplementedError

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
        num_clicks = len(clicks_lists[0])
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            scale = scale.view(bs, -1, 1, 1)
            bias = bias.view(bs, -1, 1, 1)
            if self.with_flip:
                scale = scale.repeat(2, 1, 1, 1)
                bias = bias.repeat(2, 1, 1, 1)

            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == 'A':
                if self.net.feature_extractor.ocr_width > 0:
                    out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features)
                    feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features)

                    context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
                    feats = self.net.feature_extractor.ocr_distri_head(feats, context)
                else:
                    feats = scaled_backbone_features
                pred_logits = self.net.feature_extractor.cls_head(feats)
            elif self.insertion_mode == 'C':
                pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features)
            else:
                raise NotImplementedError

            pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
                                        align_corners=True)
            return pred_logits

        self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
                                       **self.opt_functor.optimizer_params)
            self.opt_data = opt_result[0]

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
                opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits

    def _get_head_input(self, image_nd, points):
        with torch.no_grad():
            image_nd, prev_mask = self.net.prepare_input(image_nd)
            coord_features = self.net.get_coord_features(image_nd, prev_mask, points)

            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
                additional_features = None
            elif hasattr(self.net, 'maps_transform'):
                x = image_nd
                additional_features = self.net.maps_transform(coord_features)

            feats = self.net.feature_extractor.compute_hrnet_feats(x, additional_features)

            if self.insertion_mode == 'A':
                backbone_features = feats
            elif self.insertion_mode == 'C':
                out_aux = self.net.feature_extractor.aux_head(feats)
                feats = self.net.feature_extractor.conv3x3_ocr(feats)

                context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
                backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context)
            else:
                raise NotImplementedError

        return backbone_features


class InputBRSPredictor(BRSBasePredictor):
    def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.optimize_target = optimize_target

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
        num_clicks = len(clicks_lists[0])

        if self.opt_data is None or is_image_changed:
            if self.optimize_target == 'dmaps':
                opt_channels = self.net.coord_feature_ch - 1 if self.net.with_prev_mask else self.net.coord_feature_ch
            else:
                opt_channels = 3
            bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
            self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
                                        device=self.device, dtype=torch.float32)

        def get_prediction_logits(opt_bias):
            input_image, prev_mask = self.net.prepare_input(image_nd)
            dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd)

            if self.optimize_target == 'rgb':
                input_image = input_image + opt_bias
            elif self.optimize_target == 'dmaps':
                if self.net.with_prev_mask:
                    dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias
                else:
                    dmaps = dmaps + opt_bias

            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
                if self.optimize_target == 'all':
                    x = x + opt_bias
                coord_features = None
            elif hasattr(self.net, 'maps_transform'):
                x = input_image
                coord_features = self.net.maps_transform(dmaps)

            pred_logits = self.net.backbone_forward(x, coord_features=coord_features)['instances']
            pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True)

            return pred_logits

        self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device,
                                    shape=self.opt_data.shape)
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(),
                                       **self.opt_functor.optimizer_params)

            self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device)

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits
isegm/inference/predictors/brs_functors.py ADDED
@@ -0,0 +1,109 @@
import torch
import numpy as np

from isegm.model.metrics import _compute_iou
from .brs_losses import BRSMaskLoss


class BaseOptimizer:
    def __init__(self, optimizer_params,
                 prob_thresh=0.49,
                 reg_weight=1e-3,
                 min_iou_diff=0.01,
                 brs_loss=BRSMaskLoss(),
                 with_flip=False,
                 flip_average=False,
                 **kwargs):
        self.brs_loss = brs_loss
        self.optimizer_params = optimizer_params
        self.prob_thresh = prob_thresh
        self.reg_weight = reg_weight
        self.min_iou_diff = min_iou_diff
        self.with_flip = with_flip
        self.flip_average = flip_average

        self.best_prediction = None
        self._get_prediction_logits = None
        self._opt_shape = None
        self._best_loss = None
        self._click_masks = None
        self._last_mask = None
        self.device = None

    def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None):
        self.best_prediction = None
        self._get_prediction_logits = get_prediction_logits
        self._click_masks = (pos_mask, neg_mask)
        self._opt_shape = shape
        self._last_mask = None
        self.device = device

    def __call__(self, x):
        opt_params = torch.from_numpy(x).float().to(self.device)
        opt_params.requires_grad_(True)

        with torch.enable_grad():
            opt_vars, reg_loss = self.unpack_opt_params(opt_params)
            result_before_sigmoid = self._get_prediction_logits(*opt_vars)
            result = torch.sigmoid(result_before_sigmoid)

            pos_mask, neg_mask = self._click_masks
            if self.with_flip and self.flip_average:
                result, result_flipped = torch.chunk(result, 2, dim=0)
                result = 0.5 * (result + torch.flip(result_flipped, dims=[3]))
                pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]]

            loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask)
            loss = loss + reg_loss

        f_val = loss.detach().cpu().numpy()
        if self.best_prediction is None or f_val < self._best_loss:
            self.best_prediction = result_before_sigmoid.detach()
            self._best_loss = f_val

        if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh:
            return [f_val, np.zeros_like(x)]

        current_mask = result > self.prob_thresh
        if self._last_mask is not None and self.min_iou_diff > 0:
            diff_iou = _compute_iou(current_mask, self._last_mask)
            if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff:
                return [f_val, np.zeros_like(x)]
        self._last_mask = current_mask

        loss.backward()
        f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float64)

        return [f_val, f_grad]

    def unpack_opt_params(self, opt_params):
        raise NotImplementedError


class InputOptimizer(BaseOptimizer):
    def unpack_opt_params(self, opt_params):
        opt_params = opt_params.view(self._opt_shape)
        if self.with_flip:
            opt_params_flipped = torch.flip(opt_params, dims=[3])
            opt_params = torch.cat([opt_params, opt_params_flipped], dim=0)
        reg_loss = self.reg_weight * torch.sum(opt_params**2)

        return (opt_params,), reg_loss


class ScaleBiasOptimizer(BaseOptimizer):
    def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.scale_act = scale_act
        self.reg_bias_weight = reg_bias_weight

    def unpack_opt_params(self, opt_params):
        scale, bias = torch.chunk(opt_params, 2, dim=0)
        reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2))

        if self.scale_act == 'tanh':
            scale = torch.tanh(scale)
        elif self.scale_act == 'sin':
            scale = torch.sin(scale)

        return (1 + scale, bias), reg_loss
isegm/inference/predictors/brs_losses.py ADDED
@@ -0,0 +1,58 @@
import torch

from isegm.model.losses import SigmoidBinaryCrossEntropyLoss


class BRSMaskLoss(torch.nn.Module):
    def __init__(self, eps=1e-5):
        super().__init__()
        self._eps = eps

    def forward(self, result, pos_mask, neg_mask):
        pos_diff = (1 - result) * pos_mask
        pos_target = torch.sum(pos_diff ** 2)
        pos_target = pos_target / (torch.sum(pos_mask) + self._eps)

        neg_diff = result * neg_mask
        neg_target = torch.sum(neg_diff ** 2)
        neg_target = neg_target / (torch.sum(neg_mask) + self._eps)

        loss = pos_target + neg_target

        with torch.no_grad():
            f_max_pos = torch.max(torch.abs(pos_diff)).item()
            f_max_neg = torch.max(torch.abs(neg_diff)).item()

        return loss, f_max_pos, f_max_neg


class OracleMaskLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gt_mask = None
        self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)
        self.predictor = None
        self.history = []

    def set_gt_mask(self, gt_mask):
        self.gt_mask = gt_mask
        self.history = []

    def forward(self, result, pos_mask, neg_mask):
        gt_mask = self.gt_mask.to(result.device)
        if self.predictor.object_roi is not None:
            r1, r2, c1, c2 = self.predictor.object_roi[:4]
            gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1]
            gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True)

        if result.shape[0] == 2:
            gt_mask_flipped = torch.flip(gt_mask, dims=[3])
            gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0)

        loss = self.loss(result, gt_mask)
        self.history.append(loss.detach().cpu().numpy()[0])

        if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5:
            return 0, 0, 0

        return loss, 1.0, 1.0
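BRSMaskLoss above penalizes low probabilities at positive clicks and high probabilities at negative clicks, and also reports the largest per-click violation of each kind. A toy sketch (illustrative values only, assuming the isegm package from this commit is on the path):

import torch
from isegm.inference.predictors.brs_losses import BRSMaskLoss

# Uniform 0.8 probability map with one positive click at (1, 1)
# and one negative click at (2, 2).
result = torch.full((1, 1, 4, 4), 0.8)
pos_mask = torch.zeros((1, 1, 4, 4))
neg_mask = torch.zeros((1, 1, 4, 4))
pos_mask[0, 0, 1, 1] = 1
neg_mask[0, 0, 2, 2] = 1

loss, f_max_pos, f_max_neg = BRSMaskLoss()(result, pos_mask, neg_mask)
print(round(loss.item(), 3), f_max_pos, f_max_neg)  # approx. 0.68, 0.2, 0.8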
isegm/inference/transforms/__init__.py ADDED
@@ -0,0 +1,5 @@
from .base import SigmoidForPred
from .flip import AddHorizontalFlip
from .zoom_in import ZoomIn
from .limit_longest_side import LimitLongestSide
from .crops import Crops
isegm/inference/transforms/base.py ADDED
@@ -0,0 +1,38 @@
import torch


class BaseTransform(object):
    def __init__(self):
        self.image_changed = False

    def transform(self, image_nd, clicks_lists):
        raise NotImplementedError

    def inv_transform(self, prob_map):
        raise NotImplementedError

    def reset(self):
        raise NotImplementedError

    def get_state(self):
        raise NotImplementedError

    def set_state(self, state):
        raise NotImplementedError


class SigmoidForPred(BaseTransform):
    def transform(self, image_nd, clicks_lists):
        return image_nd, clicks_lists

    def inv_transform(self, prob_map):
        return torch.sigmoid(prob_map)

    def reset(self):
        pass

    def get_state(self):
        return None

    def set_state(self, state):
        pass
isegm/inference/transforms/crops.py ADDED
@@ -0,0 +1,97 @@
import math

import torch
import numpy as np
from typing import List

from isegm.inference.clicker import Click
from .base import BaseTransform


class Crops(BaseTransform):
    def __init__(self, crop_size=(320, 480), min_overlap=0.2):
        super().__init__()
        self.crop_height, self.crop_width = crop_size
        self.min_overlap = min_overlap

        self.x_offsets = None
        self.y_offsets = None
        self._counts = None

    def transform(self, image_nd, clicks_lists: List[List[Click]]):
        assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
        image_height, image_width = image_nd.shape[2:4]
        self._counts = None

        if image_height < self.crop_height or image_width < self.crop_width:
            return image_nd, clicks_lists

        self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap)
        self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap)
        self._counts = np.zeros((image_height, image_width))

        image_crops = []
        for dy in self.y_offsets:
            for dx in self.x_offsets:
                self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1
                image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width]
                image_crops.append(image_crop)
        image_crops = torch.cat(image_crops, dim=0)
        self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32)

        clicks_list = clicks_lists[0]
        clicks_lists = []
        for dy in self.y_offsets:
            for dx in self.x_offsets:
                crop_clicks = [x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx)) for x in clicks_list]
                clicks_lists.append(crop_clicks)

        return image_crops, clicks_lists

    def inv_transform(self, prob_map):
        if self._counts is None:
            return prob_map

        new_prob_map = torch.zeros((1, 1, *self._counts.shape),
                                   dtype=prob_map.dtype, device=prob_map.device)

        crop_indx = 0
        for dy in self.y_offsets:
            for dx in self.x_offsets:
                new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0]
                crop_indx += 1
        new_prob_map = torch.div(new_prob_map, self._counts)

        return new_prob_map

    def get_state(self):
        return self.x_offsets, self.y_offsets, self._counts

    def set_state(self, state):
        self.x_offsets, self.y_offsets, self._counts = state

    def reset(self):
        self.x_offsets = None
        self.y_offsets = None
        self._counts = None


def get_offsets(length, crop_size, min_overlap_ratio=0.2):
    if length == crop_size:
        return [0]

    N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio)
    N = math.ceil(N)

    overlap_ratio = (N - length / crop_size) / (N - 1)
    overlap_width = int(crop_size * overlap_ratio)

    offsets = [0]
    for i in range(1, N):
        new_offset = offsets[-1] + crop_size - overlap_width
        if new_offset + crop_size > length:
            new_offset = length - crop_size

        offsets.append(new_offset)

    return offsets
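get_offsets above tiles one image dimension with fixed-size crops that overlap by at least the requested ratio; Crops.transform calls it for both axes. A quick illustrative call (toy numbers, assuming the package is importable):

from isegm.inference.transforms.crops import get_offsets

# Cover a 1000-pixel side with 480-pixel crops and at least 20% overlap.
offsets = get_offsets(1000, 480, min_overlap_ratio=0.2)
print(offsets)  # three offsets; the last crop ends flush at 1000 pixels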
isegm/inference/transforms/flip.py ADDED
@@ -0,0 +1,37 @@
import torch

from typing import List
from isegm.inference.clicker import Click
from .base import BaseTransform


class AddHorizontalFlip(BaseTransform):
    def transform(self, image_nd, clicks_lists: List[List[Click]]):
        assert len(image_nd.shape) == 4
        image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0)

        image_width = image_nd.shape[3]
        clicks_lists_flipped = []
        for clicks_list in clicks_lists:
            clicks_list_flipped = [click.copy(coords=(click.coords[0], image_width - click.coords[1] - 1))
                                   for click in clicks_list]
            clicks_lists_flipped.append(clicks_list_flipped)
        clicks_lists = clicks_lists + clicks_lists_flipped

        return image_nd, clicks_lists

    def inv_transform(self, prob_map):
        assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0
        num_maps = prob_map.shape[0] // 2
        prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:]

        return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3]))

    def get_state(self):
        return None

    def set_state(self, state):
        pass

    def reset(self):
        pass
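AddHorizontalFlip above doubles the batch with a mirrored copy and mirrors the click x-coordinates for the flipped half; inv_transform then averages the two predictions back. A small sketch with a dummy tensor (illustrative only, assuming the full isegm package from this commit is importable):

import torch
from isegm.inference.clicker import Click
from isegm.inference.transforms import AddHorizontalFlip

flip = AddHorizontalFlip()
image = torch.zeros(1, 3, 4, 6)                      # dummy 4x6 image
clicks = [[Click(is_positive=True, coords=(1, 2))]]
image_out, clicks_out = flip.transform(image, clicks)
print(image_out.shape)            # torch.Size([2, 3, 4, 6])
print(clicks_out[1][0].coords)    # (1, 3): x mirrored to width - x - 1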
isegm/inference/transforms/limit_longest_side.py ADDED
@@ -0,0 +1,22 @@
from .zoom_in import ZoomIn, get_roi_image_nd


class LimitLongestSide(ZoomIn):
    def __init__(self, max_size=800):
        super().__init__(target_size=max_size, skip_clicks=0)

    def transform(self, image_nd, clicks_lists):
        assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
        image_max_size = max(image_nd.shape[2:4])
        self.image_changed = False

        if image_max_size <= self.target_size:
            return image_nd, clicks_lists
        self._input_image = image_nd

        self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1)
        self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
        self.image_changed = True

        tclicks_lists = [self._transform_clicks(clicks_lists[0])]
        return self._roi_image, tclicks_lists
isegm/inference/transforms/zoom_in.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from typing import List
4
+ from isegm.inference.clicker import Click
5
+ from isegm.utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox
6
+ from .base import BaseTransform
7
+
8
+
9
+ class ZoomIn(BaseTransform):
10
+ def __init__(self,
11
+ target_size=400,
12
+ skip_clicks=1,
13
+ expansion_ratio=1.4,
14
+ min_crop_size=200,
15
+ recompute_thresh_iou=0.5,
16
+ prob_thresh=0.50):
17
+ super().__init__()
18
+ self.target_size = target_size
19
+ self.min_crop_size = min_crop_size
20
+ self.skip_clicks = skip_clicks
21
+ self.expansion_ratio = expansion_ratio
22
+ self.recompute_thresh_iou = recompute_thresh_iou
23
+ self.prob_thresh = prob_thresh
24
+
25
+ self._input_image_shape = None
26
+ self._prev_probs = None
27
+ self._object_roi = None
28
+ self._roi_image = None
29
+
30
+ def transform(self, image_nd, clicks_lists: List[List[Click]]):
31
+ transformed_image = []
32
+ transformed_clicks_lists = []
33
+ for bindx in range(len(clicks_lists)):
34
+ new_image_nd, new_clicks_lists = self._transform(image_nd[bindx].unsqueeze(0), [clicks_lists[bindx]])
35
+ transformed_image.append(new_image_nd)
36
+ transformed_clicks_lists.append(new_clicks_lists[0])
37
+ return torch.cat(transformed_image, dim=0), transformed_clicks_lists
38
+
39
+ def _transform(self, image_nd, clicks_lists: List[List[Click]]):
40
+ assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
41
+ self.image_changed = False
42
+
43
+ clicks_list = clicks_lists[0]
44
+ if len(clicks_list) <= self.skip_clicks:
45
+ return image_nd, clicks_lists
46
+
47
+ self._input_image_shape = image_nd.shape
48
+
49
+ current_object_roi = None
50
+ if self._prev_probs is not None:
51
+ current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
52
+ if current_pred_mask.sum() > 0:
53
+ current_object_roi = get_object_roi(current_pred_mask, clicks_list,
54
+ self.expansion_ratio, self.min_crop_size)
55
+
56
+ if current_object_roi is None:
57
+ if self.skip_clicks >= 0:
58
+ return image_nd, clicks_lists
59
+ else:
60
+ current_object_roi = 0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1
61
+
62
+ update_object_roi = False
63
+ if self._object_roi is None:
64
+ update_object_roi = True
65
+ elif not check_object_roi(self._object_roi, clicks_list):
66
+ update_object_roi = True
67
+ elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou:
68
+ update_object_roi = True
69
+
70
+ if update_object_roi:
71
+ self._object_roi = current_object_roi
72
+ self.image_changed = True
73
+ self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
74
+ tclicks_lists = [self._transform_clicks(clicks_list)]
75
+ return self._roi_image.to(image_nd.device), tclicks_lists
76
+
77
+ def inv_transform(self, prob_map):
78
+ new_prob_maps = []
79
+ for bindx in range(prob_map.shape[0]):
80
+ new_prob_map = self._inv_transform(prob_map[bindx].unsqueeze(0))
81
+ new_prob_maps.append(new_prob_map)
82
+ return torch.cat(new_prob_maps, dim=0)
83
+
84
+ def _inv_transform(self, prob_map):
85
+ if self._object_roi is None:
86
+ self._prev_probs = prob_map.cpu().numpy()
87
+ return prob_map
88
+
89
+ assert prob_map.shape[0] == 1
90
+ rmin, rmax, cmin, cmax = self._object_roi
91
+ prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1),
92
+ mode='bilinear', align_corners=True)
93
+
94
+ if self._prev_probs is not None:
95
+ new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype)
96
+ new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map
97
+ else:
98
+ new_prob_map = prob_map
99
+
100
+ self._prev_probs = new_prob_map.cpu().numpy()
101
+
102
+ return new_prob_map
103
+
104
+ def check_possible_recalculation(self):
105
+ if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0:
106
+ return False
107
+
108
+ pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
109
+ if pred_mask.sum() > 0:
110
+ possible_object_roi = get_object_roi(pred_mask, [],
111
+ self.expansion_ratio, self.min_crop_size)
112
+ image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1)
113
+ if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
114
+ return True
115
+ return False
116
+
117
+ def get_state(self):
118
+ roi_image = self._roi_image.cpu() if self._roi_image is not None else None
119
+ return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed
120
+
121
+ def set_state(self, state):
122
+ self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state
123
+
124
+ def reset(self):
125
+ self._input_image_shape = None
126
+ self._object_roi = None
127
+ self._prev_probs = None
128
+ self._roi_image = None
129
+ self.image_changed = False
130
+
131
+ def _transform_clicks(self, clicks_list):
132
+ if self._object_roi is None:
133
+ return clicks_list
134
+
135
+ rmin, rmax, cmin, cmax = self._object_roi
136
+ crop_height, crop_width = self._roi_image.shape[2:]
137
+
138
+ transformed_clicks = []
139
+ for click in clicks_list:
140
+ new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1)
141
+ new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1)
142
+ transformed_clicks.append(click.copy(coords=(new_r, new_c)))
143
+ return transformed_clicks
144
+
145
+
146
+ def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size):
147
+ pred_mask = pred_mask.copy()
148
+
149
+ for click in clicks_list:
150
+ if click.is_positive:
151
+ pred_mask[int(click.coords[0]), int(click.coords[1])] = 1
152
+
153
+ bbox = get_bbox_from_mask(pred_mask)
154
+ bbox = expand_bbox(bbox, expansion_ratio, min_crop_size)
155
+ h, w = pred_mask.shape[0], pred_mask.shape[1]
156
+ bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1)
157
+
158
+ return bbox
159
+
160
+
161
+ def get_roi_image_nd(image_nd, object_roi, target_size):
162
+ rmin, rmax, cmin, cmax = object_roi
163
+
164
+ height = rmax - rmin + 1
165
+ width = cmax - cmin + 1
166
+
167
+ if isinstance(target_size, tuple):
168
+ new_height, new_width = target_size
169
+ else:
170
+ scale = target_size / max(height, width)
171
+ new_height = int(round(height * scale))
172
+ new_width = int(round(width * scale))
173
+
174
+ with torch.no_grad():
175
+ roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1]
176
+ roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width),
177
+ mode='bilinear', align_corners=True)
178
+
179
+ return roi_image_nd
180
+
181
+
182
+ def check_object_roi(object_roi, clicks_list):
183
+ for click in clicks_list:
184
+ if click.is_positive:
185
+ if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]:
186
+ return False
187
+ if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]:
188
+ return False
189
+
190
+ return True
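A minimal usage sketch of the ZoomIn round trip above (illustrative, not from the commit; tensor sizes and click coordinates are invented, the network call is stubbed with random logits, and the Click(is_positive, coords) constructor from isegm/inference/clicker.py is assumed). transform() is a no-op until a previous prediction exists, after which it crops around the predicted object; inv_transform() pastes the cropped prediction back into full-image coordinates.

import torch
from isegm.inference.clicker import Click
from isegm.inference.transforms.zoom_in import ZoomIn

zoom_in = ZoomIn(target_size=448, skip_clicks=1)

image_nd = torch.rand(1, 3, 600, 800)  # (B, C, H, W); channel count depends on the model
clicks_lists = [[Click(is_positive=True, coords=(300, 400)),
                 Click(is_positive=False, coords=(100, 120))]]

cropped_image, cropped_clicks = zoom_in.transform(image_nd, clicks_lists)
# ... run the network on cropped_image / cropped_clicks to obtain `logits` ...
logits = torch.rand(1, 1, *cropped_image.shape[2:])      # stand-in for model output
full_probs = zoom_in.inv_transform(torch.sigmoid(logits))
assert full_probs.shape[2:] == image_nd.shape[2:]        # back at full resolution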
isegm/inference/utils.py ADDED
@@ -0,0 +1,149 @@
1
+ from datetime import timedelta
2
+ from pathlib import Path
3
+ import torch
4
+ import numpy as np
5
+ from isegm.utils.serialization import load_model
6
+
7
+
8
+ def get_time_metrics(all_ious, elapsed_time):
9
+ n_images = len(all_ious)
10
+ n_clicks = sum(map(len, all_ious))
11
+
12
+ mean_spc = elapsed_time / n_clicks
13
+ mean_spi = elapsed_time / n_images
14
+
15
+ return mean_spc, mean_spi
16
+
17
+
18
+ def load_is_model(checkpoint, device, eval_ritm, lora_checkpoint=None, **kwargs):
19
+ if isinstance(checkpoint, (str, Path)):
20
+ state_dict = torch.load(checkpoint, map_location='cpu')
21
+ else:
22
+ state_dict = checkpoint
23
+ if isinstance(state_dict, list):
24
+ model = load_single_is_model(state_dict[0], device, eval_ritm, **kwargs)
25
+ models = [load_single_is_model(x, device, eval_ritm, **kwargs) for x in state_dict]
26
+
27
+ return model, models
28
+ else:
29
+ return load_single_is_model(state_dict, device, eval_ritm, lora_checkpoint=lora_checkpoint, **kwargs)
30
+
31
+
32
+ def load_single_is_model(state_dict, device, eval_ritm, lora_checkpoint=None, **kwargs):
33
+ if 'config' in state_dict.keys():
34
+ _config = state_dict['config']
35
+ if lora_checkpoint is not None:
36
+ lora_state_dict = torch.load(lora_checkpoint, map_location='cpu')
37
+ _config = lora_state_dict['config']
38
+
39
+ model = load_model(_config, eval_ritm, **kwargs)
40
+ print("Load predictor weights...")
41
+ if 'state_dict' in state_dict.keys():
42
+ msg = model.load_state_dict(state_dict['state_dict'], strict=False)
43
+ else:
44
+ try:
45
+ msg = model.load_state_dict(state_dict, strict=False)
46
+ except Exception:
47
+ current_state_dict = model.state_dict()
48
+
49
+ new_state_dict = {}
50
+ for k, v in state_dict.items():
51
+ if k in current_state_dict and v.shape == current_state_dict[k].shape:
52
+ new_state_dict[k] = v
53
+
54
+ msg = model.load_state_dict(new_state_dict, strict=False)
55
+ print(msg)
56
+
57
+ if lora_checkpoint is not None:
58
+ print("Load predictor LoRA weights...")
59
+ msg = model.load_state_dict(lora_state_dict['state_dict'], strict=False)
60
+ print(msg[1])
61
+
62
+ for param in model.parameters():
63
+ param.requires_grad = False
64
+ model.to(device)
65
+ model.eval()
66
+
67
+ return model
68
+
69
+
70
+ def get_iou(gt_mask, pred_mask, ignore_label=-1):
71
+ ignore_gt_mask_inv = gt_mask != ignore_label
72
+ obj_gt_mask = gt_mask == 1
73
+
74
+ intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
75
+ union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
76
+
77
+ return intersection / union
78
+
79
+
80
+ def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
81
+ def _get_noc(iou_arr, iou_thr):
82
+ vals = iou_arr >= iou_thr
83
+ return np.argmax(vals) + 1 if np.any(vals) else max_clicks
84
+
85
+ noc_list = []
86
+ noc_list_std = []
87
+ over_max_list = []
88
+ for iou_thr in iou_thrs:
89
+ scores_arr = np.array([_get_noc(iou_arr, iou_thr)
90
+ for iou_arr in all_ious], dtype=np.int_)
91
+
92
+ score = scores_arr.mean()
93
+ score_std = scores_arr.std()
94
+ over_max = (scores_arr == max_clicks).sum()
95
+
96
+ noc_list.append(score)
97
+ noc_list_std.append(score_std)
98
+ over_max_list.append(over_max)
99
+
100
+ return noc_list, noc_list_std, over_max_list
101
+
102
+
103
+ def find_checkpoint(weights_folder, checkpoint_name):
104
+ weights_folder = Path(weights_folder)
105
+ if ':' in checkpoint_name:
106
+ model_name, checkpoint_name = checkpoint_name.split(':')
107
+ models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()]
108
+ assert len(models_candidates) == 1
109
+ model_folder = models_candidates[0]
110
+ else:
111
+ model_folder = weights_folder
112
+
113
+ if checkpoint_name.endswith('.pth'):
114
+ if Path(checkpoint_name).exists():
115
+ checkpoint_path = checkpoint_name
116
+ else:
117
+ checkpoint_path = weights_folder / checkpoint_name
118
+ else:
119
+ model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth'))
120
+ assert len(model_checkpoints) == 1
121
+ checkpoint_path = model_checkpoints[0]
122
+
123
+ return str(checkpoint_path)
124
+
125
+
126
+ def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time, iou_first,
127
+ n_clicks=20, model_name=None):
128
+ table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|'
129
+ f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
130
+ f'{"IoU@1":^9}|'
131
+ f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
132
+ f'{"SPC,s":^7}|{"Time":^9}|')
133
+ row_width = len(table_header)
134
+
135
+ header = f'Eval results for model: {model_name}\n' if model_name is not None else ''
136
+ header += '-' * row_width + '\n'
137
+ header += table_header + '\n' + '-' * row_width
138
+
139
+ eval_time = str(timedelta(seconds=int(elapsed_time)))
140
+ table_row = f'|{brs_type:^13}|{dataset_name:^11}|'
141
+ table_row += f'{noc_list[0]:^9.2f}|'
142
+ table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|'
143
+ table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|'
144
+ table_row += f'{iou_first:^9.2f}|'
145
+ table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|'
146
+ table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|'
147
+ table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|'
148
+
149
+ return header, table_row
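A quick sketch of how the helpers above combine during evaluation, using invented per-click IoU trajectories in place of real results:

import numpy as np
from isegm.inference.utils import compute_noc_metric, get_time_metrics, get_results_table

# one IoU-per-click array per image (values invented for illustration)
all_ious = [np.array([0.55, 0.72, 0.84, 0.91]),
            np.array([0.40, 0.83, 0.88, 0.90, 0.93])]

noc_list, noc_std, over_max = compute_noc_metric(all_ious, iou_thrs=[0.80, 0.85, 0.90], max_clicks=20)
mean_spc, mean_spi = get_time_metrics(all_ious, elapsed_time=4.5)

header, row = get_results_table(noc_list, over_max, brs_type='NoBRS', dataset_name='GrabCut',
                                mean_spc=mean_spc, elapsed_time=4.5,
                                iou_first=float(np.mean([ious[0] for ious in all_ious])))
print(header)
print(row)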
isegm/model/__init__.py ADDED
File without changes
isegm/model/build_sam.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+ from functools import partial
10
+
11
+ from .sam_modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, SAMISWrapper
12
+
13
+
14
+ def build_sam_vit_h(checkpoint=None, enable_lora=False, enable_gra=False, mode='eval', image_size=1024):
15
+ return _build_sam(
16
+ encoder_embed_dim=1280,
17
+ encoder_depth=32,
18
+ encoder_num_heads=16,
19
+ encoder_global_attn_indexes=[7, 15, 23, 31],
20
+ checkpoint=checkpoint,
21
+ enable_lora=enable_lora,
22
+ enable_gra=enable_gra,
23
+ mode=mode,
24
+ image_size=image_size,
25
+ )
26
+
27
+
28
+ build_sam = build_sam_vit_h
29
+
30
+
31
+ def build_sam_vit_l(checkpoint=None, enable_lora=False, enable_gra=False, mode='eval', image_size=1024):
32
+ return _build_sam(
33
+ encoder_embed_dim=1024,
34
+ encoder_depth=24,
35
+ encoder_num_heads=16,
36
+ encoder_global_attn_indexes=[5, 11, 17, 23],
37
+ checkpoint=checkpoint,
38
+ enable_lora=enable_lora,
39
+ enable_gra=enable_gra,
40
+ mode=mode,
41
+ image_size=image_size,
42
+ )
43
+
44
+
45
+ def build_sam_vit_b(checkpoint=None, enable_lora=False, enable_gra=False, mode='eval', image_size=1024):
46
+ return _build_sam(
47
+ encoder_embed_dim=768,
48
+ encoder_depth=12,
49
+ encoder_num_heads=12,
50
+ encoder_global_attn_indexes=[2, 5, 8, 11],
51
+ checkpoint=checkpoint,
52
+ enable_lora=enable_lora,
53
+ enable_gra=enable_gra,
54
+ mode=mode,
55
+ image_size=image_size,
56
+ )
57
+
58
+
59
+ sam_model_registry = {
60
+ "default": build_sam_vit_h,
61
+ "vit_h": build_sam_vit_h,
62
+ "vit_l": build_sam_vit_l,
63
+ "vit_b": build_sam_vit_b,
64
+ }
65
+
66
+
67
+ def _build_sam(
68
+ encoder_embed_dim,
69
+ encoder_depth,
70
+ encoder_num_heads,
71
+ encoder_global_attn_indexes,
72
+ checkpoint=None,
73
+ enable_lora=False,
74
+ enable_gra=False,
75
+ mode='eval',
76
+ image_size=1024,
77
+ ):
78
+ prompt_embed_dim = 256
79
+ image_size = image_size
80
+ vit_patch_size = 16
81
+ image_embedding_size = image_size // vit_patch_size
82
+
83
+ if mode == 'train':
84
+ sam = SAMISWrapper(
85
+ encoder_embed_dim=encoder_embed_dim,
86
+ encoder_depth=encoder_depth,
87
+ encoder_num_heads=encoder_num_heads,
88
+ encoder_global_attn_indexes=encoder_global_attn_indexes,
89
+ enable_lora=enable_lora,
90
+ enable_gra=enable_gra,
91
+ with_prev_mask=True,
92
+ image_size=image_size,
93
+ pixel_mean=[123.675, 116.28, 103.53],
94
+ pixel_std=[58.395, 57.12, 57.375],
95
+ )
96
+ else:
97
+ sam = Sam(
98
+ image_encoder=ImageEncoderViT(
99
+ depth=encoder_depth,
100
+ embed_dim=encoder_embed_dim,
101
+ img_size=image_size,
102
+ mlp_ratio=4,
103
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
104
+ num_heads=encoder_num_heads,
105
+ patch_size=vit_patch_size,
106
+ qkv_bias=True,
107
+ use_rel_pos=True,
108
+ global_attn_indexes=encoder_global_attn_indexes,
109
+ window_size=14,
110
+ out_chans=prompt_embed_dim,
111
+ ),
112
+ prompt_encoder=PromptEncoder(
113
+ embed_dim=prompt_embed_dim,
114
+ image_embedding_size=(image_embedding_size, image_embedding_size),
115
+ input_image_size=(image_size, image_size),
116
+ mask_in_chans=16,
117
+ ),
118
+ mask_decoder=MaskDecoder(
119
+ num_multimask_outputs=3,
120
+ transformer=TwoWayTransformer(
121
+ depth=2,
122
+ embedding_dim=prompt_embed_dim,
123
+ mlp_dim=2048,
124
+ num_heads=8,
125
+ ),
126
+ transformer_dim=prompt_embed_dim,
127
+ iou_head_depth=3,
128
+ iou_head_hidden_dim=256,
129
+ ),
130
+ pixel_mean=[123.675, 116.28, 103.53],
131
+ pixel_std=[58.395, 57.12, 57.375],
132
+ )
133
+ sam.eval()
134
+ if checkpoint is not None:
135
+ with open(checkpoint, "rb") as f:
136
+ pretrained_dict = torch.load(f)
137
+
138
+ model_dict = sam.state_dict()
139
+ new_pretrained_dict = {}
140
+ for k, v in pretrained_dict.items():
141
+ if k in model_dict and v.shape == model_dict[k].shape:
142
+ new_pretrained_dict[k] = v
143
+ msg = sam.load_state_dict(new_pretrained_dict, strict=False)
144
+ print("SAM load Info: ", msg)
145
+ return sam
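A minimal sketch of building one SAM variant through the registry above (illustrative; the checkpoint path is a placeholder, and since weights are filtered by name and shape before loading, a standard SAM checkpoint file can be passed when available):

import torch
from isegm.model.build_sam import sam_model_registry

# pass e.g. checkpoint='weights/sam_vit_b.pth' (path illustrative) to load pretrained weights
sam = sam_model_registry['vit_b'](checkpoint=None, mode='eval', image_size=1024)
sam.eval()

with torch.no_grad():
    dummy = torch.rand(1, 3, 1024, 1024)      # already-preprocessed image tensor
    embeddings = sam.image_encoder(dummy)     # expected shape: (1, 256, 64, 64)
print(embeddings.shape)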
isegm/model/initializer.py ADDED
@@ -0,0 +1,105 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+
6
+ class Initializer(object):
7
+ def __init__(self, local_init=True, gamma=None):
8
+ self.local_init = local_init
9
+ self.gamma = gamma
10
+
11
+ def __call__(self, m):
12
+ if getattr(m, '__initialized', False):
13
+ return
14
+
15
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
16
+ nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
17
+ nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__:
18
+ if m.weight is not None:
19
+ self._init_gamma(m.weight.data)
20
+ if m.bias is not None:
21
+ self._init_beta(m.bias.data)
22
+ else:
23
+ if getattr(m, 'weight', None) is not None:
24
+ self._init_weight(m.weight.data)
25
+ if getattr(m, 'bias', None) is not None:
26
+ self._init_bias(m.bias.data)
27
+
28
+ if self.local_init:
29
+ object.__setattr__(m, '__initialized', True)
30
+
31
+ def _init_weight(self, data):
32
+ nn.init.uniform_(data, -0.07, 0.07)
33
+
34
+ def _init_bias(self, data):
35
+ nn.init.constant_(data, 0)
36
+
37
+ def _init_gamma(self, data):
38
+ if self.gamma is None:
39
+ nn.init.constant_(data, 1.0)
40
+ else:
41
+ nn.init.normal_(data, 1.0, self.gamma)
42
+
43
+ def _init_beta(self, data):
44
+ nn.init.constant_(data, 0)
45
+
46
+
47
+ class Bilinear(Initializer):
48
+ def __init__(self, scale, groups, in_channels, **kwargs):
49
+ super().__init__(**kwargs)
50
+ self.scale = scale
51
+ self.groups = groups
52
+ self.in_channels = in_channels
53
+
54
+ def _init_weight(self, data):
55
+ """Reset the weight and bias."""
56
+ bilinear_kernel = self.get_bilinear_kernel(self.scale)
57
+ weight = torch.zeros_like(data)
58
+ for i in range(self.in_channels):
59
+ if self.groups == 1:
60
+ j = i
61
+ else:
62
+ j = 0
63
+ weight[i, j] = bilinear_kernel
64
+ data[:] = weight
65
+
66
+ @staticmethod
67
+ def get_bilinear_kernel(scale):
68
+ """Generate a bilinear upsampling kernel."""
69
+ kernel_size = 2 * scale - scale % 2
70
+ scale = (kernel_size + 1) // 2
71
+ center = scale - 0.5 * (1 + kernel_size % 2)
72
+
73
+ og = np.ogrid[:kernel_size, :kernel_size]
74
+ kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale)
75
+
76
+ return torch.tensor(kernel, dtype=torch.float32)
77
+
78
+
79
+ class XavierGluon(Initializer):
80
+ def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs):
81
+ super().__init__(**kwargs)
82
+
83
+ self.rnd_type = rnd_type
84
+ self.factor_type = factor_type
85
+ self.magnitude = float(magnitude)
86
+
87
+ def _init_weight(self, arr):
88
+ fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr)
89
+
90
+ if self.factor_type == 'avg':
91
+ factor = (fan_in + fan_out) / 2.0
92
+ elif self.factor_type == 'in':
93
+ factor = fan_in
94
+ elif self.factor_type == 'out':
95
+ factor = fan_out
96
+ else:
97
+ raise ValueError('Incorrect factor type')
98
+ scale = np.sqrt(self.magnitude / factor)
99
+
100
+ if self.rnd_type == 'uniform':
101
+ nn.init.uniform_(arr, -scale, scale)
102
+ elif self.rnd_type == 'gaussian':
103
+ nn.init.normal_(arr, 0, scale)
104
+ else:
105
+ raise ValueError('Unknown random type')
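These initializers are designed to be applied module-by-module via nn.Module.apply; a short sketch with a toy network (the layer sizes are arbitrary):

import torch.nn as nn
from isegm.model.initializer import XavierGluon

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
    nn.Conv2d(16, 1, 1),
)

# Xavier-style init for conv weights; norm layers get gamma=1, beta=0, biases get 0
net.apply(XavierGluon(rnd_type='gaussian', magnitude=2.0))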
isegm/model/is_deeplab_model.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch.nn as nn
2
+
3
+ from isegm.utils.serialization import serialize
4
+ from .is_model import ISModel
5
+ from .modeling.deeplab_v3 import DeepLabV3Plus
6
+ from .modeling.basic_blocks import SepConvHead
7
+ from isegm.model.modifiers import LRMult
8
+
9
+
10
+ class DeeplabModel(ISModel):
11
+ @serialize
12
+ def __init__(self, backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5,
13
+ backbone_norm_layer=None, backbone_lr_mult=0.1, norm_layer=nn.BatchNorm2d, **kwargs):
14
+ super().__init__(norm_layer=norm_layer, **kwargs)
15
+
16
+ self.feature_extractor = DeepLabV3Plus(backbone=backbone, ch=deeplab_ch, project_dropout=aspp_dropout,
17
+ norm_layer=norm_layer, backbone_norm_layer=backbone_norm_layer)
18
+ self.feature_extractor.backbone.apply(LRMult(backbone_lr_mult))
19
+ self.head = SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2,
20
+ num_layers=2, norm_layer=norm_layer)
21
+
22
+ def backbone_forward(self, image, coord_features=None):
23
+ backbone_features = self.feature_extractor(image, coord_features)
24
+
25
+ return {'instances': self.head(backbone_features[0])}
isegm/model/is_hrformer_model.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from collections import OrderedDict
5
+
6
+ from isegm.utils.serialization import serialize
7
+ from .is_model import ISModel
8
+ from isegm.model.modifiers import LRMult
9
+ from .modeling.hrformer import HRT_B_OCR_V3
10
+
11
+ class HRFormerModel(ISModel):
12
+ @serialize
13
+ def __init__(
14
+ self,
15
+ num_classes=1,
16
+ in_ch=6,
17
+ backbone_lr_mult=0.1,
18
+ **kwargs
19
+ ):
20
+
21
+ super().__init__(**kwargs)
22
+
23
+ self.feature_extractor = HRT_B_OCR_V3(num_classes, in_ch)
24
+ self.feature_extractor.apply(LRMult(backbone_lr_mult))
25
+
26
+ def backbone_forward(self, image, coord_features=None):
27
+ backbone_features = self.feature_extractor(image)
28
+ return {'instances': backbone_features[0], 'instances_aux': backbone_features[1]}
29
+
30
+ def init_weight(self, pretrained=None):
31
+ if pretrained is not None:
32
+ state_dict = torch.load(pretrained)['model']
33
+ state_dict_rename = OrderedDict()
34
+ for k, v in state_dict.items():
35
+ state_dict_rename['backbone.' + k] = v
36
+
37
+ ori_proj_weight = state_dict_rename['backbone.conv1.weight']
38
+ state_dict_rename['backbone.conv1.weight'] = torch.cat([ori_proj_weight, ori_proj_weight], dim=1)
39
+
40
+ self.feature_extractor.load_state_dict(state_dict_rename, False)
41
+ print('Successfully loaded pretrained model.')
isegm/model/is_hrnet_model.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch.nn as nn
2
+
3
+ from isegm.utils.serialization import serialize
4
+ from .is_model import ISModel
5
+ from .modeling.hrnet_ocr import HighResolutionNet
6
+ from isegm.model.modifiers import LRMult
7
+
8
+
9
+ class HRNetModel(ISModel):
10
+ @serialize
11
+ def __init__(self, width=48, ocr_width=256, small=False, backbone_lr_mult=0.1,
12
+ norm_layer=nn.BatchNorm2d, **kwargs):
13
+ super().__init__(**kwargs)
14
+
15
+ self.feature_extractor = HighResolutionNet(width=width, ocr_width=ocr_width, small=small,
16
+ num_classes=1, norm_layer=norm_layer)
17
+ self.feature_extractor.apply(LRMult(backbone_lr_mult))
18
+ if ocr_width > 0:
19
+ self.feature_extractor.ocr_distri_head.apply(LRMult(1.0))
20
+ self.feature_extractor.ocr_gather_head.apply(LRMult(1.0))
21
+ self.feature_extractor.conv3x3_ocr.apply(LRMult(1.0))
22
+
23
+ def backbone_forward(self, image, coord_features=None):
24
+ net_outputs = self.feature_extractor(image, coord_features)
25
+
26
+ return {'instances': net_outputs[0], 'instances_aux': net_outputs[1]}
isegm/model/is_model.py ADDED
@@ -0,0 +1,114 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+ from isegm.model.ops import DistMaps, BatchImageNormalize, ScaleLayer
6
+
7
+
8
+ class ISModel(nn.Module):
9
+ def __init__(self, with_aux_output=False, norm_radius=5, use_disks=False, cpu_dist_maps=False,
10
+ use_rgb_conv=False, use_leaky_relu=False, # the two arguments only used for RITM
11
+ with_prev_mask=False, norm_mean_std=([.485, .456, .406], [.229, .224, .225])):
12
+ super().__init__()
13
+
14
+ self.with_aux_output = with_aux_output
15
+ self.with_prev_mask = with_prev_mask
16
+ self.normalization = BatchImageNormalize(norm_mean_std[0], norm_mean_std[1])
17
+
18
+ self.coord_feature_ch = 2
19
+ if self.with_prev_mask:
20
+ self.coord_feature_ch += 1
21
+
22
+ if use_rgb_conv:
23
+ # Only RITM models need to transform the coordinate features, though they don't use
24
+ # an exact 'rgb_conv'. We keep 'use_rgb_conv' only for backward compatibility.
25
+ # The SimpleClick models use a patch embedding layer instead.
26
+ mt_layers = [
27
+ nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1),
28
+ nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True),
29
+ nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1),
30
+ ScaleLayer(init_value=0.05, lr_mult=1)
31
+ ]
32
+ self.maps_transform = nn.Sequential(*mt_layers)
33
+ else:
34
+ self.maps_transform = nn.Identity()
35
+
36
+ self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0,
37
+ cpu_mode=cpu_dist_maps, use_disks=use_disks)
38
+
39
+ def forward(self, image, points, text=None, gra=None):
40
+ image, prev_mask = self.prepare_input(image)
41
+ coord_features = self.get_coord_features(image, prev_mask, points)
42
+ coord_features = self.maps_transform(coord_features)
43
+
44
+ if gra is not None and text is not None:
45
+ outputs = self.backbone_forward(image, coord_features, text=text, gra=gra)
46
+ elif gra is not None:
47
+ outputs = self.backbone_forward(image, coord_features, gra=gra)
48
+ elif text is not None:
49
+ outputs = self.backbone_forward(image, coord_features, text=text)
50
+ else:
51
+ outputs = self.backbone_forward(image, coord_features)
52
+
53
+ outputs['instances'] = nn.functional.interpolate(outputs['instances'], size=image.size()[2:],
54
+ mode='bilinear', align_corners=True)
55
+ if self.with_aux_output:
56
+ outputs['instances_aux'] = nn.functional.interpolate(outputs['instances_aux'], size=image.size()[2:],
57
+ mode='bilinear', align_corners=True)
58
+
59
+ return outputs
60
+
61
+ def prepare_input(self, image):
62
+ prev_mask = None
63
+ if self.with_prev_mask:
64
+ prev_mask = image[:, 3:, :, :]
65
+ image = image[:, :3, :, :]
66
+
67
+ image = self.normalization(image)
68
+ return image, prev_mask
69
+
70
+ def backbone_forward(self, image, coord_features=None):
71
+ raise NotImplementedError
72
+
73
+ def get_coord_features(self, image, prev_mask, points):
74
+ coord_features = self.dist_maps(image, points)
75
+ if prev_mask is not None:
76
+ coord_features = torch.cat((prev_mask, coord_features), dim=1)
77
+
78
+ return coord_features
79
+
80
+
81
+ def split_points_by_order(tpoints: torch.Tensor, groups):
82
+ points = tpoints.cpu().numpy()
83
+ num_groups = len(groups)
84
+ bs = points.shape[0]
85
+ num_points = points.shape[1] // 2
86
+
87
+ groups = [x if x > 0 else num_points for x in groups]
88
+ group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32)
89
+ for x in groups]
90
+
91
+ last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int_)
92
+ for group_indx, group_size in enumerate(groups):
93
+ last_point_indx_group[:, group_indx, 1] = group_size
94
+
95
+ for bindx in range(bs):
96
+ for pindx in range(2 * num_points):
97
+ point = points[bindx, pindx, :]
98
+ group_id = int(point[2])
99
+ if group_id < 0:
100
+ continue
101
+
102
+ is_negative = int(pindx >= num_points)
103
+ if group_id >= num_groups or (group_id == 0 and is_negative): # disable negative first click
104
+ group_id = num_groups - 1
105
+
106
+ new_point_indx = last_point_indx_group[bindx, group_id, is_negative]
107
+ last_point_indx_group[bindx, group_id, is_negative] += 1
108
+
109
+ group_points[group_id][bindx, new_point_indx, :] = point
110
+
111
+ group_points = [torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device)
112
+ for x in group_points]
113
+
114
+ return group_points
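For reference, the points tensor consumed by ISModel.forward packs positive clicks in the first half of the second dimension and negative clicks in the second half, each as a (row, col, click_index) triplet with -1 marking unused slots (this layout is inferred from DistMaps and split_points_by_order above; values below are purely illustrative):

import torch

# batch of 1, up to 2 positive + 2 negative clicks: shape (B, 2 * max_points, 3)
points = torch.tensor([[
    [120.0, 210.0,  0.0],   # first positive click: (row, col, click index)
    [ -1.0,  -1.0, -1.0],   # unused positive slot
    [ 40.0,  55.0,  1.0],   # first negative click
    [ -1.0,  -1.0, -1.0],   # unused negative slot
]])

image = torch.rand(1, 4, 448, 448)   # RGB + previous mask channel when with_prev_mask=True
# output = model(image, points)      # any ISModel subclass, e.g. PlainVitModel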
isegm/model/is_plainvit_model.py ADDED
@@ -0,0 +1,95 @@
1
+ import math
2
+ import torch.nn as nn
3
+ from isegm.utils.serialization import serialize
4
+ from .is_model import ISModel
5
+ from .modeling.models_vit import VisionTransformer, PatchEmbed
6
+ from .modeling.swin_transformer import SwinTransfomerSegHead
7
+
8
+
9
+ class SimpleFPN(nn.Module):
10
+ def __init__(self, in_dim=768, out_dims=[128, 256, 512, 1024]):
11
+ super().__init__()
12
+ self.down_4_chan = max(out_dims[0]*2, in_dim // 2)
13
+ self.down_4 = nn.Sequential(
14
+ nn.ConvTranspose2d(in_dim, self.down_4_chan, 2, stride=2),
15
+ nn.GroupNorm(1, self.down_4_chan),
16
+ nn.GELU(),
17
+ nn.ConvTranspose2d(self.down_4_chan, self.down_4_chan // 2, 2, stride=2),
18
+ nn.GroupNorm(1, self.down_4_chan // 2),
19
+ nn.Conv2d(self.down_4_chan // 2, out_dims[0], 1),
20
+ nn.GroupNorm(1, out_dims[0]),
21
+ nn.GELU()
22
+ )
23
+ self.down_8_chan = max(out_dims[1], in_dim // 2)
24
+ self.down_8 = nn.Sequential(
25
+ nn.ConvTranspose2d(in_dim, self.down_8_chan, 2, stride=2),
26
+ nn.GroupNorm(1, self.down_8_chan),
27
+ nn.Conv2d(self.down_8_chan, out_dims[1], 1),
28
+ nn.GroupNorm(1, out_dims[1]),
29
+ nn.GELU()
30
+ )
31
+ self.down_16 = nn.Sequential(
32
+ nn.Conv2d(in_dim, out_dims[2], 1),
33
+ nn.GroupNorm(1, out_dims[2]),
34
+ nn.GELU()
35
+ )
36
+ self.down_32_chan = max(out_dims[3], in_dim * 2)
37
+ self.down_32 = nn.Sequential(
38
+ nn.Conv2d(in_dim, self.down_32_chan, 2, stride=2),
39
+ nn.GroupNorm(1, self.down_32_chan),
40
+ nn.Conv2d(self.down_32_chan, out_dims[3], 1),
41
+ nn.GroupNorm(1, out_dims[3]),
42
+ nn.GELU()
43
+ )
44
+
45
+ self.init_weights()
46
+
47
+ def init_weights(self):
48
+ pass
49
+
50
+ def forward(self, x):
51
+ x_down_4 = self.down_4(x)
52
+ x_down_8 = self.down_8(x)
53
+ x_down_16 = self.down_16(x)
54
+ x_down_32 = self.down_32(x)
55
+
56
+ return [x_down_4, x_down_8, x_down_16, x_down_32]
57
+
58
+
59
+ class PlainVitModel(ISModel):
60
+ @serialize
61
+ def __init__(
62
+ self,
63
+ backbone_params={},
64
+ neck_params={},
65
+ head_params={},
66
+ random_split=False,
67
+ **kwargs
68
+ ):
69
+
70
+ super().__init__(**kwargs)
71
+ self.random_split = random_split
72
+
73
+ self.patch_embed_coords = PatchEmbed(
74
+ img_size= backbone_params['img_size'],
75
+ patch_size=backbone_params['patch_size'],
76
+ in_chans=3 if self.with_prev_mask else 2,
77
+ embed_dim=backbone_params['embed_dim'],
78
+ )
79
+
80
+ self.backbone = VisionTransformer(**backbone_params)
81
+ self.neck = SimpleFPN(**neck_params)
82
+ self.head = SwinTransfomerSegHead(**head_params)
83
+
84
+ def backbone_forward(self, image, coord_features=None, gra=None):
85
+ coord_features = self.patch_embed_coords(coord_features)
86
+ backbone_features = self.backbone.forward_backbone(image, coord_features, gra=gra, shuffle=self.random_split)
87
+
88
+ # Extract 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32
89
+ B, N, C = backbone_features.shape
90
+ grid_size = self.backbone.patch_embed.grid_size
91
+
92
+ backbone_features = backbone_features.transpose(-1,-2).view(B, C, grid_size[0], grid_size[1])
93
+ multi_scale_features = self.neck(backbone_features)
94
+
95
+ return {'instances': self.head(multi_scale_features), 'instances_aux': None}
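A shape-level sketch of SimpleFPN in isolation: it expands the single-scale ViT feature map (1/16 resolution for a 448-pixel input with 16-pixel patches) into four pyramid levels at 1/4, 1/8, 1/16 and 1/32 of the input (sizes below are illustrative):

import torch
from isegm.model.is_plainvit_model import SimpleFPN

fpn = SimpleFPN(in_dim=768, out_dims=[128, 256, 512, 1024])
vit_features = torch.rand(2, 768, 28, 28)   # (B, C, H/16, W/16) for a 448x448 input

p4, p8, p16, p32 = fpn(vit_features)
print(p4.shape, p8.shape, p16.shape, p32.shape)
# expected: (2, 128, 112, 112), (2, 256, 56, 56), (2, 512, 28, 28), (2, 1024, 14, 14)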
isegm/model/is_plainvit_model_lora.py ADDED
@@ -0,0 +1,95 @@
1
+ import math
2
+ import torch.nn as nn
3
+ from isegm.utils.serialization import serialize
4
+ from .is_model import ISModel
5
+ from .modeling.models_vit_lora import VisionTransformer_lora, PatchEmbed
6
+ from .modeling.swin_transformer import SwinTransfomerSegHead
7
+
8
+
9
+ class SimpleFPN(nn.Module):
10
+ def __init__(self, in_dim=768, out_dims=[128, 256, 512, 1024]):
11
+ super().__init__()
12
+ self.down_4_chan = max(out_dims[0]*2, in_dim // 2)
13
+ self.down_4 = nn.Sequential(
14
+ nn.ConvTranspose2d(in_dim, self.down_4_chan, 2, stride=2),
15
+ nn.GroupNorm(1, self.down_4_chan),
16
+ nn.GELU(),
17
+ nn.ConvTranspose2d(self.down_4_chan, self.down_4_chan // 2, 2, stride=2),
18
+ nn.GroupNorm(1, self.down_4_chan // 2),
19
+ nn.Conv2d(self.down_4_chan // 2, out_dims[0], 1),
20
+ nn.GroupNorm(1, out_dims[0]),
21
+ nn.GELU()
22
+ )
23
+ self.down_8_chan = max(out_dims[1], in_dim // 2)
24
+ self.down_8 = nn.Sequential(
25
+ nn.ConvTranspose2d(in_dim, self.down_8_chan, 2, stride=2),
26
+ nn.GroupNorm(1, self.down_8_chan),
27
+ nn.Conv2d(self.down_8_chan, out_dims[1], 1),
28
+ nn.GroupNorm(1, out_dims[1]),
29
+ nn.GELU()
30
+ )
31
+ self.down_16 = nn.Sequential(
32
+ nn.Conv2d(in_dim, out_dims[2], 1),
33
+ nn.GroupNorm(1, out_dims[2]),
34
+ nn.GELU()
35
+ )
36
+ self.down_32_chan = max(out_dims[3], in_dim * 2)
37
+ self.down_32 = nn.Sequential(
38
+ nn.Conv2d(in_dim, self.down_32_chan, 2, stride=2),
39
+ nn.GroupNorm(1, self.down_32_chan),
40
+ nn.Conv2d(self.down_32_chan, out_dims[3], 1),
41
+ nn.GroupNorm(1, out_dims[3]),
42
+ nn.GELU()
43
+ )
44
+
45
+ self.init_weights()
46
+
47
+ def init_weights(self):
48
+ pass
49
+
50
+ def forward(self, x):
51
+ x_down_4 = self.down_4(x)
52
+ x_down_8 = self.down_8(x)
53
+ x_down_16 = self.down_16(x)
54
+ x_down_32 = self.down_32(x)
55
+
56
+ return [x_down_4, x_down_8, x_down_16, x_down_32]
57
+
58
+
59
+ class PlainVitModel_lora(ISModel):
60
+ @serialize
61
+ def __init__(
62
+ self,
63
+ backbone_params={},
64
+ neck_params={},
65
+ head_params={},
66
+ random_split=False,
67
+ **kwargs
68
+ ):
69
+
70
+ super().__init__(**kwargs)
71
+ self.random_split = random_split
72
+
73
+ self.patch_embed_coords = PatchEmbed(
74
+ img_size= backbone_params['img_size'],
75
+ patch_size=backbone_params['patch_size'],
76
+ in_chans=3 if self.with_prev_mask else 2,
77
+ embed_dim=backbone_params['embed_dim'],
78
+ )
79
+
80
+ self.backbone = VisionTransformer_lora(**backbone_params)
81
+ self.neck = SimpleFPN(**neck_params)
82
+ self.head = SwinTransfomerSegHead(**head_params)
83
+
84
+ def backbone_forward(self, image, coord_features=None, gra=None):
85
+ coord_features = self.patch_embed_coords(coord_features)
86
+ backbone_features = self.backbone.forward_backbone(image, coord_features, gra=gra, shuffle=self.random_split)
87
+
88
+ # Extract 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32
89
+ B, N, C = backbone_features.shape
90
+ grid_size = self.backbone.patch_embed.grid_size
91
+
92
+ backbone_features = backbone_features.transpose(-1,-2).view(B, C, grid_size[0], grid_size[1])
93
+ multi_scale_features = self.neck(backbone_features)
94
+
95
+ return {'instances': self.head(multi_scale_features), 'instances_aux': None}
isegm/model/is_segformer_model.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch.nn as nn
2
+
3
+ from isegm.utils.serialization import serialize
4
+ from .is_model import ISModel
5
+ from isegm.model.modifiers import LRMult
6
+ from .modeling.segformer import MixVisionTransformer, SegformerHead
7
+
8
+
9
+ class SegformerModel(ISModel):
10
+ @serialize
11
+ def __init__(
12
+ self,
13
+ backbone_params=None,
14
+ decode_head_params=None,
15
+ backbone_lr_mult=0.1,
16
+ **kwargs
17
+ ):
18
+
19
+ super().__init__(**kwargs)
20
+
21
+ self.feature_extractor = MixVisionTransformer(**backbone_params)
22
+ self.feature_extractor.apply(LRMult(backbone_lr_mult))
23
+
24
+ self.head = SegformerHead(**decode_head_params)
25
+
26
+ def backbone_forward(self, image, coord_features=None):
27
+ backbone_features = self.feature_extractor(image, coord_features)
28
+ return {'instances': self.head(backbone_features), 'instances_aux': None}
29
+
isegm/model/is_swinformer_model.py ADDED
@@ -0,0 +1,21 @@
1
+ from isegm.utils.serialization import serialize
2
+ from .is_model import ISModel
3
+ from .modeling.swin_transformer import SwinTransformer, SwinTransfomerSegHead
4
+
5
+ class SwinformerModel(ISModel):
6
+ @serialize
7
+ def __init__(
8
+ self,
9
+ backbone_params={},
10
+ head_params={},
11
+ **kwargs
12
+ ):
13
+
14
+ super().__init__(**kwargs)
15
+
16
+ self.backbone = SwinTransformer(**backbone_params)
17
+ self.head = SwinTransfomerSegHead(**head_params)
18
+
19
+ def backbone_forward(self, image, coord_features=None):
20
+ backbone_features = self.backbone(image, coord_features)
21
+ return {'instances': self.head(backbone_features), 'instances_aux': None}
isegm/model/is_text_graco_model.py ADDED
@@ -0,0 +1,63 @@
1
+ import torch.nn as nn
2
+ from isegm.utils.serialization import serialize
3
+ from .is_model import ISModel
4
+ from .is_plainvit_model import SimpleFPN
5
+ from .modeling.models_vit import VisionTransformer, PatchEmbed
6
+ from .modeling.twoway_transformer import TwoWayTransformer, PositionEmbeddingRandom
7
+ from .modeling.swin_transformer import SwinTransfomerSegHead
8
+ from .modeling.clip_text_encoding import ClipTextEncoder
9
+
10
+
11
+ class TextGraCoModel(ISModel):
12
+ @serialize
13
+ def __init__(
14
+ self,
15
+ image_encoder_params={},
16
+ text_encoder_params={},
17
+ cross_encoder_params={},
18
+ neck_params={},
19
+ head_params={},
20
+ random_split=False,
21
+ **kwargs
22
+ ):
23
+
24
+ super().__init__(**kwargs)
25
+ self.random_split = random_split
26
+
27
+ self.patch_embed_coords = PatchEmbed(
28
+ img_size=image_encoder_params['img_size'],
29
+ patch_size=image_encoder_params['patch_size'],
30
+ in_chans=3 if self.with_prev_mask else 2,
31
+ embed_dim=image_encoder_params['embed_dim'],
32
+ )
33
+
34
+ self.image_encoder = VisionTransformer(**image_encoder_params)
35
+ self.text_encoder = ClipTextEncoder(**text_encoder_params)
36
+ self.cross_encoder = TwoWayTransformer(**cross_encoder_params)
37
+
38
+ self.pe_layer = PositionEmbeddingRandom(cross_encoder_params["embedding_dim"] // 2)
39
+ patch_size = image_encoder_params['patch_size'][0]
40
+ self.image_embedding_size = image_encoder_params['img_size'][0] // (patch_size if patch_size > 0 else 1)
41
+
42
+ self.neck = SimpleFPN(**neck_params)
43
+ self.head = SwinTransfomerSegHead(**head_params)
44
+
45
+ def backbone_forward(self, image, coord_features=None, text=None, gra=None):
46
+ coord_features = self.patch_embed_coords(coord_features)
47
+ backbone_features = self.image_encoder.forward_backbone(image, coord_features, gra=gra, shuffle=self.random_split)
48
+ text_features = self.text_encoder(text)
49
+
50
+ text_features, backbone_features = self.cross_encoder(
51
+ backbone_features,
52
+ self.pe_layer((self.image_embedding_size, self.image_embedding_size)).unsqueeze(0),
53
+ text_features)
54
+
55
+ # Extract 4 stage image_encoder feature map: 1/4, 1/8, 1/16, 1/32
56
+ B, N, C = backbone_features.shape
57
+ grid_size = self.image_encoder.patch_embed.grid_size
58
+
59
+ backbone_features = backbone_features.transpose(-1,-2).view(B, C, grid_size[0], grid_size[1])
60
+ multi_scale_features = self.neck(backbone_features)
61
+
62
+ return {'instances': self.head(multi_scale_features), 'instances_aux': None}
63
+
isegm/model/losses.py ADDED
@@ -0,0 +1,195 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from isegm.utils import misc
7
+
8
+
9
+ class NormalizedFocalLossSigmoid(nn.Module):
10
+ def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12,
11
+ from_sigmoid=False, detach_delimeter=True,
12
+ batch_axis=0, weight=None, size_average=True,
13
+ ignore_label=-1):
14
+ super(NormalizedFocalLossSigmoid, self).__init__()
15
+ self._axis = axis
16
+ self._alpha = alpha
17
+ self._gamma = gamma
18
+ self._ignore_label = ignore_label
19
+ self._weight = weight if weight is not None else 1.0
20
+ self._batch_axis = batch_axis
21
+
22
+ self._from_logits = from_sigmoid
23
+ self._eps = eps
24
+ self._size_average = size_average
25
+ self._detach_delimeter = detach_delimeter
26
+ self._max_mult = max_mult
27
+ self._k_sum = 0
28
+ self._m_max = 0
29
+
30
+ def forward(self, pred, label):
31
+ one_hot = label > 0.5
32
+ sample_weight = label != self._ignore_label
33
+
34
+ if not self._from_logits:
35
+ pred = torch.sigmoid(pred)
36
+
37
+ alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
38
+ pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred))
39
+
40
+ beta = (1 - pt) ** self._gamma
41
+
42
+ sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
43
+ beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
44
+ mult = sw_sum / (beta_sum + self._eps)
45
+ if self._detach_delimeter:
46
+ mult = mult.detach()
47
+ beta = beta * mult
48
+ if self._max_mult > 0:
49
+ beta = torch.clamp_max(beta, self._max_mult)
50
+
51
+ with torch.no_grad():
52
+ ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy()
53
+ sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
54
+ if np.any(ignore_area == 0):
55
+ self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()
56
+
57
+ beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1)
58
+ beta_pmax = beta_pmax.mean().item()
59
+ self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax
60
+
61
+ loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
62
+ loss = self._weight * (loss * sample_weight)
63
+
64
+ if self._size_average:
65
+ bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis))
66
+ loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps)
67
+ else:
68
+ loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
69
+
70
+ return loss
71
+
72
+ def log_states(self, sw, name, global_step):
73
+ sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step)
74
+ sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step)
75
+
76
+
77
+ class FocalLoss(nn.Module):
78
+ def __init__(self, axis=-1, alpha=0.25, gamma=2,
79
+ from_logits=False, batch_axis=0,
80
+ weight=None, num_class=None,
81
+ eps=1e-9, size_average=True, scale=1.0,
82
+ ignore_label=-1):
83
+ super(FocalLoss, self).__init__()
84
+ self._axis = axis
85
+ self._alpha = alpha
86
+ self._gamma = gamma
87
+ self._ignore_label = ignore_label
88
+ self._weight = weight if weight is not None else 1.0
89
+ self._batch_axis = batch_axis
90
+
91
+ self._scale = scale
92
+ self._num_class = num_class
93
+ self._from_logits = from_logits
94
+ self._eps = eps
95
+ self._size_average = size_average
96
+
97
+ def forward(self, pred, label, sample_weight=None):
98
+ one_hot = label > 0.5
99
+ sample_weight = label != self._ignore_label
100
+
101
+ if not self._from_logits:
102
+ pred = torch.sigmoid(pred)
103
+
104
+ alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
105
+ pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred))
106
+
107
+ beta = (1 - pt) ** self._gamma
108
+
109
+ loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
110
+ loss = self._weight * (loss * sample_weight)
111
+
112
+ if self._size_average:
113
+ tsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis))
114
+ loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps)
115
+ else:
116
+ loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
117
+
118
+ return self._scale * loss
119
+
120
+
121
+ class SoftIoU(nn.Module):
122
+ def __init__(self, from_sigmoid=False, ignore_label=-1):
123
+ super().__init__()
124
+ self._from_sigmoid = from_sigmoid
125
+ self._ignore_label = ignore_label
126
+
127
+ def forward(self, pred, label):
128
+ label = label.view(pred.size())
129
+ sample_weight = label != self._ignore_label
130
+
131
+ if not self._from_sigmoid:
132
+ pred = torch.sigmoid(pred)
133
+
134
+ loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) \
135
+ / (torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8)
136
+
137
+ return loss
138
+
139
+
140
+ class SigmoidBinaryCrossEntropyLoss(nn.Module):
141
+ def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1):
142
+ super(SigmoidBinaryCrossEntropyLoss, self).__init__()
143
+ self._from_sigmoid = from_sigmoid
144
+ self._ignore_label = ignore_label
145
+ self._weight = weight if weight is not None else 1.0
146
+ self._batch_axis = batch_axis
147
+
148
+ def forward(self, pred, label):
149
+ label = label.view(pred.size())
150
+ sample_weight = label != self._ignore_label
151
+ label = torch.where(sample_weight, label, torch.zeros_like(label))
152
+
153
+ if not self._from_sigmoid:
154
+ loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
155
+ else:
156
+ eps = 1e-12
157
+ loss = -(torch.log(pred + eps) * label
158
+ + torch.log(1. - pred + eps) * (1. - label))
159
+
160
+ loss = self._weight * (loss * sample_weight)
161
+ return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
162
+
163
+
164
+ class BinaryDiceLoss(nn.Module):
165
+ """ Dice Loss for binary segmentation
166
+ """
167
+
168
+ def forward(self, pred, label):
169
+ batchsize = pred.size(0)
170
+
171
+ # convert probability to binary label using maximum probability
172
+ input_pred, input_label = pred.max(1)
173
+ input_pred *= input_label.float()
174
+
175
+ # convert to floats
176
+ input_pred = input_pred.float()
177
+ target_label = label.float()
178
+
179
+ # convert to 1D
180
+ input_pred = input_pred.view(batchsize, -1)
181
+ target_label = target_label.view(batchsize, -1)
182
+
183
+ # compute dice score
184
+ intersect = torch.sum(input_pred * target_label, 1)
185
+ input_area = torch.sum(input_pred * input_pred, 1)
186
+ target_area = torch.sum(target_label * target_label, 1)
187
+
188
+ sum = input_area + target_area
189
+ epsilon = torch.tensor(1e-6)
190
+
191
+ # batch dice loss and ignore dice loss where target area = 0
192
+ batch_loss = torch.tensor(1.0) - (torch.tensor(2.0) * intersect + epsilon) / (sum + epsilon)
193
+ loss = batch_loss.mean()
194
+
195
+ return loss
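A quick sketch of driving one of these losses with random logits and a binary ground-truth mask (ignore_label defaults to -1, so pixels labeled -1 are excluded; the tensor sizes are arbitrary):

import torch
from isegm.model.losses import NormalizedFocalLossSigmoid

criterion = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2)

logits = torch.randn(4, 1, 96, 96)                  # raw network output
labels = (torch.rand(4, 1, 96, 96) > 0.5).float()   # binary GT; use -1 for ignored pixels

loss = criterion(logits, labels)                    # per-sample loss, shape (4,)
print(loss.mean())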
isegm/model/metrics.py ADDED
@@ -0,0 +1,101 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ from isegm.utils import misc
5
+
6
+
7
+ class TrainMetric(object):
8
+ def __init__(self, pred_outputs, gt_outputs):
9
+ self.pred_outputs = pred_outputs
10
+ self.gt_outputs = gt_outputs
11
+
12
+ def update(self, *args, **kwargs):
13
+ raise NotImplementedError
14
+
15
+ def get_epoch_value(self):
16
+ raise NotImplementedError
17
+
18
+ def reset_epoch_stats(self):
19
+ raise NotImplementedError
20
+
21
+ def log_states(self, sw, tag_prefix, global_step):
22
+ pass
23
+
24
+ @property
25
+ def name(self):
26
+ return type(self).__name__
27
+
28
+
29
+ class AdaptiveIoU(TrainMetric):
30
+ def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9,
31
+ ignore_label=-1, from_logits=True,
32
+ pred_output='instances', gt_output='instances'):
33
+ super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,))
34
+ self._ignore_label = ignore_label
35
+ self._from_logits = from_logits
36
+ self._iou_thresh = init_thresh
37
+ self._thresh_step = thresh_step
38
+ self._thresh_beta = thresh_beta
39
+ self._iou_beta = iou_beta
40
+ self._ema_iou = 0.0
41
+ self._epoch_iou_sum = 0.0
42
+ self._epoch_batch_count = 0
43
+
44
+ def update(self, pred, gt):
45
+ gt_mask = gt > 0.5
46
+ if self._from_logits:
47
+ pred = torch.sigmoid(pred)
48
+
49
+ gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy()
50
+ if np.all(gt_mask_area == 0):
51
+ return
52
+
53
+ ignore_mask = gt == self._ignore_label
54
+ max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean()
55
+ best_thresh = self._iou_thresh
56
+ for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]:
57
+ temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean()
58
+ if temp_iou > max_iou:
59
+ max_iou = temp_iou
60
+ best_thresh = t
61
+
62
+ self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh
63
+ self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou
64
+ self._epoch_iou_sum += max_iou
65
+ self._epoch_batch_count += 1
66
+
67
+ def get_epoch_value(self):
68
+ if self._epoch_batch_count > 0:
69
+ return self._epoch_iou_sum / self._epoch_batch_count
70
+ else:
71
+ return 0.0
72
+
73
+ def reset_epoch_stats(self):
74
+ self._epoch_iou_sum = 0.0
75
+ self._epoch_batch_count = 0
76
+
77
+ def log_states(self, sw, tag_prefix, global_step):
78
+ sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step)
79
+ sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step)
80
+
81
+ @property
82
+ def iou_thresh(self):
83
+ return self._iou_thresh
84
+
85
+
86
+ def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
87
+ if ignore_mask is not None:
88
+ pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)
89
+
90
+ reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)
91
+ union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
92
+ intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
93
+ nonzero = union > 0
94
+
95
+ iou = intersection[nonzero] / union[nonzero]
96
+ if not keep_ignore:
97
+ return iou
98
+ else:
99
+ result = np.full_like(intersection, -1)
100
+ result[nonzero] = iou
101
+ return result
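AdaptiveIoU is meant to be updated batch by batch during training, adjusting its binarization threshold as it goes; a minimal sketch of that loop with random tensors standing in for predictions and ground truth:

import torch
from isegm.model.metrics import AdaptiveIoU

metric = AdaptiveIoU(init_thresh=0.4)

for _ in range(10):                                   # stand-in for training batches
    logits = torch.randn(8, 64, 64)                   # raw predictions (from_logits=True)
    gt = (torch.rand(8, 64, 64) > 0.5).float()
    metric.update(logits, gt)

print('epoch IoU:', metric.get_epoch_value(), 'threshold:', metric.iou_thresh)
metric.reset_epoch_stats()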
isegm/model/modeling/__init__.py ADDED
File without changes
isegm/model/modeling/basic_blocks.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch.nn as nn
2
+
3
+ from isegm.model import ops
4
+
5
+
6
+ class ConvHead(nn.Module):
7
+ def __init__(self, out_channels, in_channels=32, num_layers=1,
8
+ kernel_size=3, padding=1,
9
+ norm_layer=nn.BatchNorm2d):
10
+ super(ConvHead, self).__init__()
11
+ convhead = []
12
+
13
+ for i in range(num_layers):
14
+ convhead.extend([
15
+ nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding),
16
+ nn.ReLU(),
17
+ norm_layer(in_channels) if norm_layer is not None else nn.Identity()
18
+ ])
19
+ convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0))
20
+
21
+ self.convhead = nn.Sequential(*convhead)
22
+
23
+ def forward(self, *inputs):
24
+ return self.convhead(inputs[0])
25
+
26
+
27
+ class SepConvHead(nn.Module):
28
+ def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1,
29
+ kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0,
30
+ norm_layer=nn.BatchNorm2d):
31
+ super(SepConvHead, self).__init__()
32
+
33
+ sepconvhead = []
34
+
35
+ for i in range(num_layers):
36
+ sepconvhead.append(
37
+ SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels,
38
+ out_channels=mid_channels,
39
+ dw_kernel=kernel_size, dw_padding=padding,
40
+ norm_layer=norm_layer, activation='relu')
41
+ )
42
+ if dropout_ratio > 0 and dropout_indx == i:
43
+ sepconvhead.append(nn.Dropout(dropout_ratio))
44
+
45
+ sepconvhead.append(
46
+ nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0)
47
+ )
48
+
49
+ self.layers = nn.Sequential(*sepconvhead)
50
+
51
+ def forward(self, *inputs):
52
+ x = inputs[0]
53
+
54
+ return self.layers(x)
55
+
56
+
57
+ class SeparableConv2d(nn.Module):
58
+ def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1,
59
+ activation=None, use_bias=False, norm_layer=None):
60
+ super(SeparableConv2d, self).__init__()
61
+ _activation = ops.select_activation_function(activation)
62
+ self.body = nn.Sequential(
63
+ nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride,
64
+ padding=dw_padding, bias=use_bias, groups=in_channels),
65
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias),
66
+ norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
67
+ _activation()
68
+ )
69
+
70
+ def forward(self, x):
71
+ return self.body(x)
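A shape sketch for SepConvHead, which DeeplabModel above attaches to 256-channel DeepLab features; the 3x3, padding-1 separable convolutions preserve spatial resolution (input sizes below are illustrative):

import torch
import torch.nn as nn
from isegm.model.modeling.basic_blocks import SepConvHead

head = SepConvHead(num_outputs=1, in_channels=256, mid_channels=128,
                   num_layers=2, norm_layer=nn.BatchNorm2d)

features = torch.rand(2, 256, 56, 56)
logits = head(features)
print(logits.shape)    # expected: torch.Size([2, 1, 56, 56])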
isegm/model/modeling/clip/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .clip import *
isegm/model/modeling/clip/clip.py ADDED
@@ -0,0 +1,245 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Any, Union, List
6
+ from pkg_resources import packaging
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+
13
+ from .model import build_model
14
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
+
16
+ try:
17
+ from torchvision.transforms import InterpolationMode
18
+ BICUBIC = InterpolationMode.BICUBIC
19
+ except ImportError:
20
+ BICUBIC = Image.BICUBIC
21
+
22
+
23
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25
+
26
+
27
+ __all__ = ["available_models", "load", "tokenize"]
28
+ _tokenizer = _Tokenizer()
29
+
30
+ _MODELS = {
31
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40
+ }
41
+
42
+
43
+ def _download(url: str, root: str):
44
+ os.makedirs(root, exist_ok=True)
45
+ filename = os.path.basename(url)
46
+
47
+ expected_sha256 = url.split("/")[-2]
48
+ download_target = os.path.join(root, filename)
49
+
50
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
51
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
52
+
53
+ if os.path.isfile(download_target):
54
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55
+ return download_target
56
+ else:
57
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58
+
59
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61
+ while True:
62
+ buffer = source.read(8192)
63
+ if not buffer:
64
+ break
65
+
66
+ output.write(buffer)
67
+ loop.update(len(buffer))
68
+
69
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
71
+
72
+ return download_target
73
+
74
+
75
+ def _convert_image_to_rgb(image):
76
+ return image.convert("RGB")
77
+
78
+
79
+ def _transform(n_px):
80
+ return Compose([
81
+ Resize(n_px, interpolation=BICUBIC),
82
+ CenterCrop(n_px),
83
+ _convert_image_to_rgb,
84
+ ToTensor(),
85
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86
+ ])
87
+
88
+
89
+ def available_models() -> List[str]:
90
+ """Returns the names of available CLIP models"""
91
+ return list(_MODELS.keys())
92
+
93
+
94
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
95
+ """Load a CLIP model
96
+
97
+ Parameters
98
+ ----------
99
+ name : str
100
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
101
+
102
+ device : Union[str, torch.device]
103
+ The device to put the loaded model
104
+
105
+ jit : bool
106
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
107
+
108
+ download_root: str
109
+ path to download the model files; by default, it uses "~/.cache/clip"
110
+
111
+ Returns
112
+ -------
113
+ model : torch.nn.Module
114
+ The CLIP model
115
+
116
+ preprocess : Callable[[PIL.Image], torch.Tensor]
117
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
118
+ """
119
+ if name in _MODELS:
120
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
121
+ elif os.path.isfile(name):
122
+ model_path = name
123
+ else:
124
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
125
+
126
+ with open(model_path, 'rb') as opened_file:
127
+ try:
128
+ # loading JIT archive
129
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
130
+ state_dict = None
131
+ except RuntimeError:
132
+ # loading saved state dict
133
+ if jit:
134
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
135
+ jit = False
136
+ state_dict = torch.load(opened_file, map_location="cpu")
137
+
138
+ if not jit:
139
+ model = build_model(state_dict or model.state_dict()).to(device)
140
+ if str(device) == "cpu":
141
+ model.float()
142
+ return model, _transform(model.visual.input_resolution)
143
+
144
+ # patch the device names
145
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
146
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
147
+
148
+ def _node_get(node: torch._C.Node, key: str):
149
+ """Gets attributes of a node which is polymorphic over return type.
150
+
151
+ From https://github.com/pytorch/pytorch/pull/82628
152
+ """
153
+ sel = node.kindOf(key)
154
+ return getattr(node, sel)(key)
155
+
156
+ def patch_device(module):
157
+ try:
158
+ graphs = [module.graph] if hasattr(module, "graph") else []
159
+ except RuntimeError:
160
+ graphs = []
161
+
162
+ if hasattr(module, "forward1"):
163
+ graphs.append(module.forward1.graph)
164
+
165
+ for graph in graphs:
166
+ for node in graph.findAllNodes("prim::Constant"):
167
+ if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
168
+ node.copyAttributes(device_node)
169
+
170
+ model.apply(patch_device)
171
+ patch_device(model.encode_image)
172
+ patch_device(model.encode_text)
173
+
174
+ # patch dtype to float32 on CPU
175
+ if str(device) == "cpu":
176
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
177
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
178
+ float_node = float_input.node()
179
+
180
+ def patch_float(module):
181
+ try:
182
+ graphs = [module.graph] if hasattr(module, "graph") else []
183
+ except RuntimeError:
184
+ graphs = []
185
+
186
+ if hasattr(module, "forward1"):
187
+ graphs.append(module.forward1.graph)
188
+
189
+ for graph in graphs:
190
+ for node in graph.findAllNodes("aten::to"):
191
+ inputs = list(node.inputs())
192
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
193
+ if _node_get(inputs[i].node(), "value") == 5:
194
+ inputs[i].node().copyAttributes(float_node)
195
+
196
+ model.apply(patch_float)
197
+ patch_float(model.encode_image)
198
+ patch_float(model.encode_text)
199
+
200
+ model.float()
201
+
202
+ return model, _transform(model.input_resolution.item())
203
+
204
+
205
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
206
+ """
207
+ Returns the tokenized representation of given input string(s)
208
+
209
+ Parameters
210
+ ----------
211
+ texts : Union[str, List[str]]
212
+ An input string or a list of input strings to tokenize
213
+
214
+ context_length : int
215
+ The context length to use; all CLIP models use 77 as the context length
216
+
217
+ truncate: bool
218
+ Whether to truncate the text in case its encoding is longer than the context length
219
+
220
+ Returns
221
+ -------
222
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
223
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
224
+ """
225
+ if isinstance(texts, str):
226
+ texts = [texts]
227
+
228
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
229
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
230
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
231
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
232
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
233
+ else:
234
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
235
+
236
+ for i, tokens in enumerate(all_tokens):
237
+ if len(tokens) > context_length:
238
+ if truncate:
239
+ tokens = tokens[:context_length]
240
+ tokens[-1] = eot_token
241
+ else:
242
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
243
+ result[i, :len(tokens)] = torch.tensor(tokens)
244
+
245
+ return result
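Putting `load` and `tokenize` together, a rough usage sketch follows. It assumes the ViT-B/32 checkpoint can be downloaded and uses a placeholder image path; note that `encode_text` in this repository's model.py is modified to return per-token features rather than the pooled, projected embedding:

import torch
from PIL import Image
from isegm.model.modeling.clip import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)   # placeholder path
tokens = clip.tokenize(["a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)   # (1, 512) for ViT-B/32
    text_features = model.encode_text(tokens)    # (2, 77, 512) with the modified encode_text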
isegm/model/modeling/clip/model.py ADDED
@@ -0,0 +1,436 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.relu1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.relu2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.relu3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.relu1(self.bn1(self.conv1(x)))
46
+ out = self.relu2(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.relu3(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
+ x, _ = F.multi_head_attention_forward(
73
+ query=x[:1], key=x, value=x,
74
+ embed_dim_to_check=x.shape[-1],
75
+ num_heads=self.num_heads,
76
+ q_proj_weight=self.q_proj.weight,
77
+ k_proj_weight=self.k_proj.weight,
78
+ v_proj_weight=self.v_proj.weight,
79
+ in_proj_weight=None,
80
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
+ bias_k=None,
82
+ bias_v=None,
83
+ add_zero_attn=False,
84
+ dropout_p=0,
85
+ out_proj_weight=self.c_proj.weight,
86
+ out_proj_bias=self.c_proj.bias,
87
+ use_separate_proj_weight=True,
88
+ training=self.training,
89
+ need_weights=False
90
+ )
91
+ return x.squeeze(0)
92
+
93
+
94
+ class ModifiedResNet(nn.Module):
95
+ """
96
+ A ResNet class that is similar to torchvision's but contains the following changes:
97
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99
+ - The final pooling layer is a QKV attention instead of an average pool
100
+ """
101
+
102
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103
+ super().__init__()
104
+ self.output_dim = output_dim
105
+ self.input_resolution = input_resolution
106
+
107
+ # the 3-layer stem
108
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109
+ self.bn1 = nn.BatchNorm2d(width // 2)
110
+ self.relu1 = nn.ReLU(inplace=True)
111
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112
+ self.bn2 = nn.BatchNorm2d(width // 2)
113
+ self.relu2 = nn.ReLU(inplace=True)
114
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115
+ self.bn3 = nn.BatchNorm2d(width)
116
+ self.relu3 = nn.ReLU(inplace=True)
117
+ self.avgpool = nn.AvgPool2d(2)
118
+
119
+ # residual layers
120
+ self._inplanes = width # this is a *mutable* variable used during construction
121
+ self.layer1 = self._make_layer(width, layers[0])
122
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125
+
126
+ embed_dim = width * 32 # the ResNet feature dimension
127
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128
+
129
+ def _make_layer(self, planes, blocks, stride=1):
130
+ layers = [Bottleneck(self._inplanes, planes, stride)]
131
+
132
+ self._inplanes = planes * Bottleneck.expansion
133
+ for _ in range(1, blocks):
134
+ layers.append(Bottleneck(self._inplanes, planes))
135
+
136
+ return nn.Sequential(*layers)
137
+
138
+ def forward(self, x):
139
+ def stem(x):
140
+ x = self.relu1(self.bn1(self.conv1(x)))
141
+ x = self.relu2(self.bn2(self.conv2(x)))
142
+ x = self.relu3(self.bn3(self.conv3(x)))
143
+ x = self.avgpool(x)
144
+ return x
145
+
146
+ x = x.type(self.conv1.weight.dtype)
147
+ x = stem(x)
148
+ x = self.layer1(x)
149
+ x = self.layer2(x)
150
+ x = self.layer3(x)
151
+ x = self.layer4(x)
152
+ x = self.attnpool(x)
153
+
154
+ return x
155
+
156
+
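# Illustrative shapes for the standard RN50 configuration (layers=(3, 4, 6, 3), width=64, 224x224 input):
# stem + avgpool -> (N, 64, 56, 56); layer4 -> (N, 2048, 7, 7); attnpool -> (N, output_dim)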
157
+ class LayerNorm(nn.LayerNorm):
158
+ """Subclass torch's LayerNorm to handle fp16."""
159
+
160
+ def forward(self, x: torch.Tensor):
161
+ orig_type = x.dtype
162
+ ret = super().forward(x.type(torch.float32))
163
+ return ret.type(orig_type)
164
+
165
+
166
+ class QuickGELU(nn.Module):
167
+ def forward(self, x: torch.Tensor):
168
+ return x * torch.sigmoid(1.702 * x)
169
+
170
+
171
+ class ResidualAttentionBlock(nn.Module):
172
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173
+ super().__init__()
174
+
175
+ self.attn = nn.MultiheadAttention(d_model, n_head)
176
+ self.ln_1 = LayerNorm(d_model)
177
+ self.mlp = nn.Sequential(OrderedDict([
178
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
179
+ ("gelu", QuickGELU()),
180
+ ("c_proj", nn.Linear(d_model * 4, d_model))
181
+ ]))
182
+ self.ln_2 = LayerNorm(d_model)
183
+ self.attn_mask = attn_mask
184
+
185
+ def attention(self, x: torch.Tensor):
186
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188
+
189
+ def forward(self, x: torch.Tensor):
190
+ x = x + self.attention(self.ln_1(x))
191
+ x = x + self.mlp(self.ln_2(x))
192
+ return x
193
+
194
+
195
+ class Transformer(nn.Module):
196
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197
+ super().__init__()
198
+ self.width = width
199
+ self.layers = layers
200
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201
+
202
+ def forward(self, x: torch.Tensor):
203
+ return self.resblocks(x)
204
+
205
+
206
+ class VisionTransformer(nn.Module):
207
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
208
+ super().__init__()
209
+ self.input_resolution = input_resolution
210
+ self.output_dim = output_dim
211
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
212
+
213
+ scale = width ** -0.5
214
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
215
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
216
+ self.ln_pre = LayerNorm(width)
217
+
218
+ self.transformer = Transformer(width, layers, heads)
219
+
220
+ self.ln_post = LayerNorm(width)
221
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
222
+
223
+ def forward(self, x: torch.Tensor):
224
+ x = self.conv1(x) # shape = [*, width, grid, grid]
225
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
226
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
227
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
228
+ x = x + self.positional_embedding.to(x.dtype)
229
+ x = self.ln_pre(x)
230
+
231
+ x = x.permute(1, 0, 2) # NLD -> LND
232
+ x = self.transformer(x)
233
+ x = x.permute(1, 0, 2) # LND -> NLD
234
+
235
+ x = self.ln_post(x[:, 0, :])
236
+
237
+ if self.proj is not None:
238
+ x = x @ self.proj
239
+
240
+ return x
241
+
242
+
243
+ class CLIP(nn.Module):
244
+ def __init__(self,
245
+ embed_dim: int,
246
+ # vision
247
+ image_resolution: int,
248
+ vision_layers: Union[Tuple[int, int, int, int], int],
249
+ vision_width: int,
250
+ vision_patch_size: int,
251
+ # text
252
+ context_length: int,
253
+ vocab_size: int,
254
+ transformer_width: int,
255
+ transformer_heads: int,
256
+ transformer_layers: int
257
+ ):
258
+ super().__init__()
259
+
260
+ self.context_length = context_length
261
+
262
+ if isinstance(vision_layers, (tuple, list)):
263
+ vision_heads = vision_width * 32 // 64
264
+ self.visual = ModifiedResNet(
265
+ layers=vision_layers,
266
+ output_dim=embed_dim,
267
+ heads=vision_heads,
268
+ input_resolution=image_resolution,
269
+ width=vision_width
270
+ )
271
+ else:
272
+ vision_heads = vision_width // 64
273
+ self.visual = VisionTransformer(
274
+ input_resolution=image_resolution,
275
+ patch_size=vision_patch_size,
276
+ width=vision_width,
277
+ layers=vision_layers,
278
+ heads=vision_heads,
279
+ output_dim=embed_dim
280
+ )
281
+
282
+ self.transformer = Transformer(
283
+ width=transformer_width,
284
+ layers=transformer_layers,
285
+ heads=transformer_heads,
286
+ attn_mask=self.build_attention_mask()
287
+ )
288
+
289
+ self.vocab_size = vocab_size
290
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
291
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
292
+ self.ln_final = LayerNorm(transformer_width)
293
+
294
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
295
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
296
+
297
+ self.initialize_parameters()
298
+
299
+ def initialize_parameters(self):
300
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
301
+ nn.init.normal_(self.positional_embedding, std=0.01)
302
+
303
+ if isinstance(self.visual, ModifiedResNet):
304
+ if self.visual.attnpool is not None:
305
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
306
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
307
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
308
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
309
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
310
+
311
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
312
+ for name, param in resnet_block.named_parameters():
313
+ if name.endswith("bn3.weight"):
314
+ nn.init.zeros_(param)
315
+
316
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
317
+ attn_std = self.transformer.width ** -0.5
318
+ fc_std = (2 * self.transformer.width) ** -0.5
319
+ for block in self.transformer.resblocks:
320
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
321
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
322
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
323
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
324
+
325
+ if self.text_projection is not None:
326
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
327
+
328
+ def build_attention_mask(self):
329
+ # lazily create causal attention mask, with full attention between the vision tokens
330
+ # pytorch uses additive attention mask; fill with -inf
331
+ mask = torch.empty(self.context_length, self.context_length)
332
+ mask.fill_(float("-inf"))
333
+ mask.triu_(1)  # keep -inf strictly above the diagonal; zero out the diagonal and below
334
+ return mask
335
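# Illustrative example: with context_length = 4 the mask built above is
#   [[0., -inf, -inf, -inf],
#    [0.,   0., -inf, -inf],
#    [0.,   0.,   0., -inf],
#    [0.,   0.,   0.,   0.]]
# so each text token attends only to itself and to earlier positions.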
+
336
+ @property
337
+ def dtype(self):
338
+ return self.visual.conv1.weight.dtype
339
+
340
+ def encode_image(self, image):
341
+ return self.visual(image.type(self.dtype))
342
+
343
+ def encode_text(self, text):
344
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
345
+
346
+ x = x + self.positional_embedding.type(self.dtype)
347
+ x = x.permute(1, 0, 2) # NLD -> LND
348
+ x = self.transformer(x)
349
+ x = x.permute(1, 0, 2) # LND -> NLD
350
+ x = self.ln_final(x).type(self.dtype)
351
+
352
+ # x.shape = [batch_size, n_ctx, transformer.width]
353
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
354
+ # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
355
+
356
+ return x
357
+
358
+ def forward(self, image, text):
359
+ image_features = self.encode_image(image)
360
+ text_features = self.encode_text(text)
361
+
362
+ # normalized features
363
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
364
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
365
+
366
+ # cosine similarity as logits
367
+ logit_scale = self.logit_scale.exp()
368
+ logits_per_image = logit_scale * image_features @ text_features.t()
369
+ logits_per_text = logits_per_image.t()
370
+
371
+ # shape = [global_batch_size, global_batch_size]
372
+ return logits_per_image, logits_per_text
373
+
374
+
375
+ def convert_weights(model: nn.Module):
376
+ """Convert applicable model parameters to fp16"""
377
+
378
+ def _convert_weights_to_fp16(l):
379
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
380
+ l.weight.data = l.weight.data.half()
381
+ if l.bias is not None:
382
+ l.bias.data = l.bias.data.half()
383
+
384
+ if isinstance(l, nn.MultiheadAttention):
385
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
386
+ tensor = getattr(l, attr)
387
+ if tensor is not None:
388
+ tensor.data = tensor.data.half()
389
+
390
+ for name in ["text_projection", "proj"]:
391
+ if hasattr(l, name):
392
+ attr = getattr(l, name)
393
+ if attr is not None:
394
+ attr.data = attr.data.half()
395
+
396
+ model.apply(_convert_weights_to_fp16)
397
+
398
+
399
+ def build_model(state_dict: dict):
400
+ vit = "visual.proj" in state_dict
401
+
402
+ if vit:
403
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
404
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
405
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
406
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
407
+ image_resolution = vision_patch_size * grid_size
408
+ else:
409
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
410
+ vision_layers = tuple(counts)
411
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
412
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
413
+ vision_patch_size = None
414
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
415
+ image_resolution = output_width * 32
416
+
417
+ embed_dim = state_dict["text_projection"].shape[1]
418
+ context_length = state_dict["positional_embedding"].shape[0]
419
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
420
+ transformer_width = state_dict["ln_final.weight"].shape[0]
421
+ transformer_heads = transformer_width // 64
422
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
423
+
424
+ model = CLIP(
425
+ embed_dim,
426
+ image_resolution, vision_layers, vision_width, vision_patch_size,
427
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
428
+ )
429
+
430
+ for key in ["input_resolution", "context_length", "vocab_size"]:
431
+ if key in state_dict:
432
+ del state_dict[key]
433
+
434
+ convert_weights(model)
435
+ model.load_state_dict(state_dict)
436
+ return model.eval()
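A rough sketch of driving `build_model` directly from a downloaded checkpoint, mirroring the logic of `load()` in clip.py; the checkpoint path below is a placeholder:

import torch

jit_archive = torch.jit.load("ViT-B-32.pt", map_location="cpu")   # placeholder checkpoint path
model = build_model(jit_archive.state_dict()).float()             # fp32 copy for CPU inference
print(model.visual.input_resolution, model.context_length)        # e.g. 224 77 for ViT-B/32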
isegm/model/modeling/clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a significant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ It also avoids mapping to whitespace/control characters that the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
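# Illustrative sanity check: the mapping is a bijection over all 256 byte values, so any UTF-8
# byte sequence can be represented losslessly with "visible" unicode characters:
#   b2u = bytes_to_unicode()
#   assert len(b2u) == 256 and len(set(b2u.values())) == 256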
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except ValueError:  # word.index() raises ValueError when `first` is not found in the remainder
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
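A small round-trip sketch for the tokenizer above (the exact ids depend on the bundled BPE vocabulary):

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("a photo of a cat")   # BPE ids only; tokenize() in clip.py adds <|startoftext|>/<|endoftext|>
text = tokenizer.decode(ids)                 # roughly "a photo of a cat " (decode re-inserts word-end spaces)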
isegm/model/modeling/clip_text_encoding.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch
2
+ from torch import nn
3
+ from .clip import clip
4
+
5
+ class ClipTextEncoder(nn.Module):
6
+ def __init__(self, clip_enocder_name="ViT-B/32", embedding_dim=512, out_dim=768):
7
+ super().__init__()
8
+ assert clip_enocder_name in ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
9
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ self.model, self.preprocess = clip.load(clip_enocder_name, device=self.device)
11
+
12
+ # freeze model
13
+ for _, param in self.model.named_parameters():
14
+ param.requires_grad = False
15
+ self.out_proj = nn.Linear(embedding_dim, out_dim)
16
+ nn.init.zeros_(self.out_proj.bias)
17
+
18
+ @torch.no_grad()
19
+ def forward(self, prompt):
20
+ '''
21
+ prompt: text tokens
22
+ '''
23
+ text_features = self.model.encode_text(prompt).type(torch.float32)
24
+ # norm
25
+ # text_features /= text_features.norm(dim=-1, keepdim=True) # [bs, 1024]
26
+ # proj
27
+ text_features = self.out_proj(text_features)
28
+ return text_features
29
+
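A usage sketch for the encoder above; it assumes the CLIP ViT-B/32 weights are available (they are downloaded on first use) and that prompts are tokenized with the bundled tokenizer:

import torch
from isegm.model.modeling.clip import clip
from isegm.model.modeling.clip_text_encoding import ClipTextEncoder

device = "cuda" if torch.cuda.is_available() else "cpu"
encoder = ClipTextEncoder().to(device)                       # frozen CLIP text tower + linear projection
tokens = clip.tokenize(["the mug on the left"]).to(device)
features = encoder(tokens)                                   # (1, 77, 768) with the modified encode_text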
isegm/model/modeling/deeplab_v3.py ADDED
@@ -0,0 +1,176 @@
1
+ from contextlib import ExitStack
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+
7
+ from .basic_blocks import SeparableConv2d
8
+ from .resnet import ResNetBackbone
9
+ from isegm.model import ops
10
+
11
+
12
+ class DeepLabV3Plus(nn.Module):
13
+ def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d,
14
+ backbone_norm_layer=None,
15
+ ch=256,
16
+ project_dropout=0.5,
17
+ inference_mode=False,
18
+ **kwargs):
19
+ super(DeepLabV3Plus, self).__init__()
20
+ if backbone_norm_layer is None:
21
+ backbone_norm_layer = norm_layer
22
+
23
+ self.backbone_name = backbone
24
+ self.norm_layer = norm_layer
25
+ self.backbone_norm_layer = backbone_norm_layer
26
+ self.inference_mode = False
27
+ self.ch = ch
28
+ self.aspp_in_channels = 2048
29
+ self.skip_project_in_channels = 256 # layer 1 out_channels
30
+
31
+ self._kwargs = kwargs
32
+ if backbone == 'resnet34':
33
+ self.aspp_in_channels = 512
34
+ self.skip_project_in_channels = 64
35
+
36
+ self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False,
37
+ norm_layer=self.backbone_norm_layer, **kwargs)
38
+
39
+ self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch,
40
+ norm_layer=self.norm_layer)
41
+ self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer)
42
+ self.aspp = _ASPP(in_channels=self.aspp_in_channels,
43
+ atrous_rates=[12, 24, 36],
44
+ out_channels=ch,
45
+ project_dropout=project_dropout,
46
+ norm_layer=self.norm_layer)
47
+
48
+ if inference_mode:
49
+ self.set_prediction_mode()
50
+
51
+ def load_pretrained_weights(self):
52
+ pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True,
53
+ norm_layer=self.backbone_norm_layer, **self._kwargs)
54
+ backbone_state_dict = self.backbone.state_dict()
55
+ pretrained_state_dict = pretrained.state_dict()
56
+
57
+ backbone_state_dict.update(pretrained_state_dict)
58
+ self.backbone.load_state_dict(backbone_state_dict)
59
+
60
+ if self.inference_mode:
61
+ for param in self.backbone.parameters():
62
+ param.requires_grad = False
63
+
64
+ def set_prediction_mode(self):
65
+ self.inference_mode = True
66
+ self.eval()
67
+
68
+ def forward(self, x, additional_features=None):
69
+ with ExitStack() as stack:
70
+ if self.inference_mode:
71
+ stack.enter_context(torch.no_grad())
72
+
73
+ c1, _, c3, c4 = self.backbone(x, additional_features)
74
+ c1 = self.skip_project(c1)
75
+
76
+ x = self.aspp(c4)
77
+ x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True)
78
+ x = torch.cat((x, c1), dim=1)
79
+ x = self.head(x)
80
+
81
+ return x,
82
+
83
+
84
+ class _SkipProject(nn.Module):
85
+ def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
86
+ super(_SkipProject, self).__init__()
87
+ _activation = ops.select_activation_function("relu")
88
+
89
+ self.skip_project = nn.Sequential(
90
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
91
+ norm_layer(out_channels),
92
+ _activation()
93
+ )
94
+
95
+ def forward(self, x):
96
+ return self.skip_project(x)
97
+
98
+
99
+ class _DeepLabHead(nn.Module):
100
+ def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d):
101
+ super(_DeepLabHead, self).__init__()
102
+
103
+ self.block = nn.Sequential(
104
+ SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3,
105
+ dw_padding=1, activation='relu', norm_layer=norm_layer),
106
+ SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3,
107
+ dw_padding=1, activation='relu', norm_layer=norm_layer),
108
+ nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
109
+ )
110
+
111
+ def forward(self, x):
112
+ return self.block(x)
113
+
114
+
115
+ class _ASPP(nn.Module):
116
+ def __init__(self, in_channels, atrous_rates, out_channels=256,
117
+ project_dropout=0.5, norm_layer=nn.BatchNorm2d):
118
+ super(_ASPP, self).__init__()
119
+
120
+ b0 = nn.Sequential(
121
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False),
122
+ norm_layer(out_channels),
123
+ nn.ReLU()
124
+ )
125
+
126
+ rate1, rate2, rate3 = tuple(atrous_rates)
127
+ b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
128
+ b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
129
+ b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
130
+ b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
131
+
132
+ self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])
133
+
134
+ project = [
135
+ nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels,
136
+ kernel_size=1, bias=False),
137
+ norm_layer(out_channels),
138
+ nn.ReLU()
139
+ ]
140
+ if project_dropout > 0:
141
+ project.append(nn.Dropout(project_dropout))
142
+ self.project = nn.Sequential(*project)
143
+
144
+ def forward(self, x):
145
+ x = torch.cat([block(x) for block in self.concurent], dim=1)
146
+
147
+ return self.project(x)
148
+
149
+
150
+ class _AsppPooling(nn.Module):
151
+ def __init__(self, in_channels, out_channels, norm_layer):
152
+ super(_AsppPooling, self).__init__()
153
+
154
+ self.gap = nn.Sequential(
155
+ nn.AdaptiveAvgPool2d((1, 1)),
156
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
157
+ kernel_size=1, bias=False),
158
+ norm_layer(out_channels),
159
+ nn.ReLU()
160
+ )
161
+
162
+ def forward(self, x):
163
+ pool = self.gap(x)
164
+ return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True)
165
+
166
+
167
+ def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
168
+ block = nn.Sequential(
169
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
170
+ kernel_size=3, padding=atrous_rate,
171
+ dilation=atrous_rate, bias=False),
172
+ norm_layer(out_channels),
173
+ nn.ReLU()
174
+ )
175
+
176
+ return block
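A rough sketch of the forward contract, assuming the bundled ResNetBackbone (defined in resnet.py, not shown here) accepts a plain image tensor; `load_pretrained_weights()` is optional and pulls pretrained weights into the backbone:

import torch

model = DeepLabV3Plus(backbone='resnet50', ch=256)
x = torch.randn(1, 3, 320, 320)
feats, = model(x)      # forward() returns a one-element tuple containing a ch-channel feature map
print(feats.shape)     # spatial size follows the backbone's first-stage stride (typically 1/4 of the input)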
isegm/model/modeling/hrformer.py ADDED
@@ -0,0 +1,487 @@
1
+ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2
+ ## Created by: RainbowSecret
3
+ ## Microsoft Research
4
5
+ ## Copyright (c) 2021
6
+ ##
7
+ ## This source code is licensed under the MIT-style license found in the
8
+ ## LICENSE file in the root directory of this source tree
9
+ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ # from .hrformer_helper.backbone_selector import BackboneSelector
18
+ from .hrformer_helper.hrt.module_helper import ModuleHelper
19
+ from .hrformer_helper.hrt.modules.spatial_ocr_block import SpatialGather_Module, SpatialOCR_Module
20
+
21
+ from .hrformer_helper.hrt.logger import Logger as Log
22
+ from .hrformer_helper.hrt.hrt_backbone import HRTBackbone, HRTBackbone_v2
23
+
24
+
25
+ class BackboneSelector(object):
26
+ def __init__(self, configer):
27
+ self.configer = configer
28
+
29
+ def get_backbone(self, **params):
30
+ backbone = self.configer.get("network", "backbone")
31
+
32
+ model = None
33
+ # if (
34
+ # "resnet" in backbone or "resnext" in backbone or "resnest" in backbone
35
+ # ) and "senet" not in backbone:
36
+ # model = ResNetBackbone(self.configer)(**params)
37
+
38
+ if "hrt" in backbone:
39
+ model = HRTBackbone(self.configer)(**params)
40
+ pass
41
+
42
+ # elif "hrnet" in backbone:
43
+ # model = HRNetBackbone(self.configer)(**params)
44
+
45
+ # elif "swin" in backbone:
46
+ # model = SwinTransformerBackbone(self.configer)(**params)
47
+
48
+ else:
49
+ Log.error("Backbone {} is invalid.".format(backbone))
50
+ exit(1)
51
+
52
+ return model
53
+
54
+
55
+ class HRT_B_OCR_V3(nn.Module):
56
+ def __init__(self, num_classes, in_ch=3, backbone='hrt_base', bn_type="torchbn", pretrained=None):
57
+ super(HRT_B_OCR_V3, self).__init__()
58
+ self.num_classes = num_classes
59
+ self.bn_type = bn_type
60
+ self.backbone = HRTBackbone_v2(backbone, pretrained, in_ch)()
61
+
62
+ in_channels = 1170
63
+ hidden_dim = 512
64
+ group_channel = math.gcd(in_channels, hidden_dim)
65
+ self.conv3x3 = nn.Sequential(
66
+ nn.Conv2d(
67
+ in_channels,
68
+ hidden_dim,
69
+ kernel_size=7,
70
+ stride=1,
71
+ padding=3,
72
+ groups=group_channel,
73
+ ),
74
+ ModuleHelper.BNReLU(
75
+ hidden_dim, bn_type=self.bn_type
76
+ ),
77
+ )
78
+ self.ocr_gather_head = SpatialGather_Module(self.num_classes)
79
+ self.ocr_distri_head = SpatialOCR_Module(
80
+ in_channels=hidden_dim,
81
+ key_channels=hidden_dim // 2,
82
+ out_channels=hidden_dim,
83
+ scale=1,
84
+ dropout=0.05,
85
+ bn_type=self.bn_type,
86
+ )
87
+ self.cls_head = nn.Conv2d(
88
+ hidden_dim, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
89
+ )
90
+ self.aux_head = nn.Sequential(
91
+ nn.Conv2d(
92
+ in_channels,
93
+ hidden_dim,
94
+ kernel_size=7,
95
+ stride=1,
96
+ padding=3,
97
+ groups=group_channel,
98
+ ),
99
+ ModuleHelper.BNReLU(
100
+ hidden_dim, bn_type=self.bn_type
101
+ ),
102
+ nn.Conv2d(
103
+ hidden_dim,
104
+ self.num_classes,
105
+ kernel_size=1,
106
+ stride=1,
107
+ padding=0,
108
+ bias=True,
109
+ ),
110
+ )
111
+
112
+ def forward(self, x_):
113
+ x = self.backbone(x_)
114
+ _, _, h, w = x[0].size()
115
+
116
+ feat1 = x[0]
117
+ feat2 = F.interpolate(x[1], size=(h, w), mode="bilinear", align_corners=True)
118
+ feat3 = F.interpolate(x[2], size=(h, w), mode="bilinear", align_corners=True)
119
+ feat4 = F.interpolate(x[3], size=(h, w), mode="bilinear", align_corners=True)
120
+
121
+ feats = torch.cat([feat1, feat2, feat3, feat4], 1)
122
+ out_aux = self.aux_head(feats)
123
+
124
+ feats = self.conv3x3(feats)
125
+
126
+ context = self.ocr_gather_head(feats, out_aux)
127
+ feats = self.ocr_distri_head(feats, context)
128
+
129
+ out = self.cls_head(feats)
130
+
131
+ out_aux = F.interpolate(
132
+ out_aux, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True
133
+ )
134
+ out = F.interpolate(
135
+ out, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True
136
+ )
137
+ return out_aux, out
138
+
139
+
140
+ class HRT_S_OCR_V2(nn.Module):
141
+ def __init__(self, num_classes, backbone='hrt_small', bn_type="torchbn", pretrained=None):
142
+ super(HRT_S_OCR_V2, self).__init__()
143
+ self.num_classes = num_classes
144
+ self.bn_type = bn_type
145
+ self.backbone = HRTBackbone_v2(backbone, pretrained)()
146
+
147
+ in_channels = 480
148
+ self.conv3x3 = nn.Sequential(
149
+ nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1),
150
+ ModuleHelper.BNReLU(512, bn_type=self.bn_type),
151
+ )
152
+ self.ocr_gather_head = SpatialGather_Module(self.num_classes)
153
+ self.ocr_distri_head = SpatialOCR_Module(
154
+ in_channels=512,
155
+ key_channels=256,
156
+ out_channels=512,
157
+ scale=1,
158
+ dropout=0.05,
159
+ bn_type=self.bn_type,
160
+ )
161
+ self.cls_head = nn.Conv2d(
162
+ 512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
163
+ )
164
+ self.aux_head = nn.Sequential(
165
+ nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1),
166
+ ModuleHelper.BNReLU(512, bn_type=self.bn_type),
167
+ nn.Conv2d(
168
+ 512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
169
+ ),
170
+ )
171
+
172
+ def forward(self, x_):
173
+ x = self.backbone(x_)
174
+ _, _, h, w = x[0].size()
175
+
176
+ feat1 = x[0]
177
+ feat2 = F.interpolate(x[1], size=(h, w), mode="bilinear", align_corners=True)
178
+ feat3 = F.interpolate(x[2], size=(h, w), mode="bilinear", align_corners=True)
179
+ feat4 = F.interpolate(x[3], size=(h, w), mode="bilinear", align_corners=True)
180
+
181
+ feats = torch.cat([feat1, feat2, feat3, feat4], 1)
182
+ out_aux = self.aux_head(feats)
183
+
184
+ feats = self.conv3x3(feats)
185
+
186
+ context = self.ocr_gather_head(feats, out_aux)
187
+ feats = self.ocr_distri_head(feats, context)
188
+
189
+ out = self.cls_head(feats)
190
+
191
+ out_aux = F.interpolate(
192
+ out_aux, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True
193
+ )
194
+ out = F.interpolate(
195
+ out, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True
196
+ )
197
+ return out_aux, out
198
+
199
+
200
+ class HRT_SMALL_OCR_V2(nn.Module):
201
+ def __init__(self, configer):
202
+ super(HRT_SMALL_OCR_V2, self).__init__()
203
+ self.configer = configer
204
+ self.num_classes = self.configer.get("data", "num_classes")
205
+ self.backbone = BackboneSelector(configer).get_backbone()
206
+
207
+ in_channels = 480
208
+ self.conv3x3 = nn.Sequential(
209
+ nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1),
210
+ ModuleHelper.BNReLU(512, bn_type=self.configer.get("network", "bn_type")),
211
+ )
212
+ self.ocr_gather_head = SpatialGather_Module(self.num_classes)
213
+ self.ocr_distri_head = SpatialOCR_Module(
214
+ in_channels=512,
215
+ key_channels=256,
216
+ out_channels=512,
217
+ scale=1,
218
+ dropout=0.05,
219
+ bn_type=self.configer.get("network", "bn_type"),
220
+ )
221
+ self.cls_head = nn.Conv2d(
222
+ 512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
223
+ )
224
+ self.aux_head = nn.Sequential(
225
+ nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1),
226
+ ModuleHelper.BNReLU(512, bn_type=self.configer.get("network", "bn_type")),
227
+ nn.Conv2d(
228
+ 512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
229
+ ),
230
+ )
231
+
232
+ def forward(self, x_):
233
+ x = self.backbone(x_)
234
+ _, _, h, w = x[0].size()
235
+
236
+ feat1 = x[0]
237
+ feat2 = F.interpolate(x[1], size=(h, w), mode="bilinear", align_corners=True)
238
+ feat3 = F.interpolate(x[2], size=(h, w), mode="bilinear", align_corners=True)
239
+ feat4 = F.interpolate(x[3], size=(h, w), mode="bilinear", align_corners=True)
240
+
241
+ feats = torch.cat([feat1, feat2, feat3, feat4], 1)
242
+ out_aux = self.aux_head(feats)
243
+
244
+ feats = self.conv3x3(feats)
245
+
246
+ context = self.ocr_gather_head(feats, out_aux)
247
+ feats = self.ocr_distri_head(feats, context)
248
+
249
+ out = self.cls_head(feats)
250
+
251
+ out_aux = F.interpolate(
252
+ out_aux, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True
253
+ )
254
+ out = F.interpolate(
255
+ out, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True
256
+ )
257
+ return out_aux, out
258
+
259
+
260
+ class HRT_BASE_OCR_V2(nn.Module):
261
+ def __init__(self, configer):
262
+ super(HRT_BASE_OCR_V2, self).__init__()
263
+ self.configer = configer
264
+ self.num_classes = self.configer.get("data", "num_classes")
265
+ self.backbone = BackboneSelector(configer).get_backbone()
266
+
267
+ in_channels = 1170
268
+ self.conv3x3 = nn.Sequential(
269
+ nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1),
270
+ ModuleHelper.BNReLU(512, bn_type=self.configer.get("network", "bn_type")),
271
+ )
272
+ self.ocr_gather_head = SpatialGather_Module(self.num_classes)
273
+ self.ocr_distri_head = SpatialOCR_Module(
274
+ in_channels=512,
275
+ key_channels=256,
276
+ out_channels=512,
277
+ scale=1,
278
+ dropout=0.05,
279
+ bn_type=self.configer.get("network", "bn_type"),
280
+ )
281
+ self.cls_head = nn.Conv2d(
282
+ 512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
283
+ )
284
+ self.aux_head = nn.Sequential(
285
+ nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1),
286
+ ModuleHelper.BNReLU(512, bn_type=self.configer.get("network", "bn_type")),
287
+ nn.Conv2d(
288
+ 512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
289
+ ),
290
+ )
291
+
292
+ def forward(self, x_):
293
+ x = self.backbone(x_)
294
+ _, _, h, w = x[0].size()
295
+
296
+ feat1 = x[0]
297
+ feat2 = F.interpolate(x[1], size=(h, w), mode="bilinear", align_corners=True)
298
+ feat3 = F.interpolate(x[2], size=(h, w), mode="bilinear", align_corners=True)
299
+ feat4 = F.interpolate(x[3], size=(h, w), mode="bilinear", align_corners=True)
300
+
301
+ feats = torch.cat([feat1, feat2, feat3, feat4], 1)
302
+ out_aux = self.aux_head(feats)
303
+
304
+ feats = self.conv3x3(feats)
305
+
306
+ context = self.ocr_gather_head(feats, out_aux)
307
+ feats = self.ocr_distri_head(feats, context)
308
+
309
+ out = self.cls_head(feats)
310
+
311
+ out_aux = F.interpolate(
312
+ out_aux, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True
313
+ )
314
+ out = F.interpolate(
315
+ out, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True
316
+ )
317
+ return out_aux, out
318
+
319
+
320
+ class HRT_SMALL_OCR_V3(nn.Module):
321
+ def __init__(self, configer):
322
+ super(HRT_SMALL_OCR_V3, self).__init__()
323
+ self.configer = configer
324
+ self.num_classes = self.configer.get("data", "num_classes")
325
+ self.backbone = BackboneSelector(configer).get_backbone()
326
+
327
+ in_channels = 480
328
+ hidden_dim = 512
329
+ group_channel = math.gcd(in_channels, hidden_dim)
330
+ self.conv3x3 = nn.Sequential(
331
+ nn.Conv2d(
332
+ in_channels,
333
+ hidden_dim,
334
+ kernel_size=7,
335
+ stride=1,
336
+ padding=3,
337
+ groups=group_channel,
338
+ ),
339
+ ModuleHelper.BNReLU(
340
+ hidden_dim, bn_type=self.configer.get("network", "bn_type")
341
+ ),
342
+ )
343
+ self.ocr_gather_head = SpatialGather_Module(self.num_classes)
344
+ self.ocr_distri_head = SpatialOCR_Module(
345
+ in_channels=hidden_dim,
346
+ key_channels=hidden_dim // 2,
347
+ out_channels=hidden_dim,
348
+ scale=1,
349
+ dropout=0.05,
350
+ bn_type=self.configer.get("network", "bn_type"),
351
+ )
352
+ self.cls_head = nn.Conv2d(
353
+ hidden_dim, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
354
+ )
355
+ self.aux_head = nn.Sequential(
356
+ nn.Conv2d(
357
+ in_channels,
358
+ hidden_dim,
359
+ kernel_size=7,
360
+ stride=1,
361
+ padding=3,
362
+ groups=group_channel,
363
+ ),
364
+ ModuleHelper.BNReLU(
365
+ hidden_dim, bn_type=self.configer.get("network", "bn_type")
366
+ ),
367
+ nn.Conv2d(
368
+ hidden_dim,
369
+ self.num_classes,
370
+ kernel_size=1,
371
+ stride=1,
372
+ padding=0,
373
+ bias=True,
374
+ ),
375
+ )
376
+
377
+ def forward(self, x_):
378
+ x = self.backbone(x_)
379
+ _, _, h, w = x[0].size()
380
+
381
+ feat1 = x[0]
382
+ feat2 = F.interpolate(x[1], size=(h, w), mode="bilinear", align_corners=True)
383
+ feat3 = F.interpolate(x[2], size=(h, w), mode="bilinear", align_corners=True)
384
+ feat4 = F.interpolate(x[3], size=(h, w), mode="bilinear", align_corners=True)
385
+
386
+ feats = torch.cat([feat1, feat2, feat3, feat4], 1)
387
+ out_aux = self.aux_head(feats)
388
+
389
+ feats = self.conv3x3(feats)
390
+
391
+ context = self.ocr_gather_head(feats, out_aux)
392
+ feats = self.ocr_distri_head(feats, context)
393
+
394
+ out = self.cls_head(feats)
395
+
396
+ out_aux = F.interpolate(
397
+ out_aux, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True
398
+ )
399
+ out = F.interpolate(
400
+ out, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True
401
+ )
402
+ return out_aux, out
403
+
404
+
405
+ class HRT_BASE_OCR_V3(nn.Module):
406
+ def __init__(self, configer):
407
+ super(HRT_BASE_OCR_V3, self).__init__()
408
+ self.configer = configer
409
+ self.num_classes = self.configer.get("data", "num_classes")
410
+ self.backbone = BackboneSelector(configer).get_backbone()
411
+
412
+ in_channels = 1170
413
+ hidden_dim = 512
414
+ group_channel = math.gcd(in_channels, hidden_dim)
415
+ self.conv3x3 = nn.Sequential(
416
+ nn.Conv2d(
417
+ in_channels,
418
+ hidden_dim,
419
+ kernel_size=7,
420
+ stride=1,
421
+ padding=3,
422
+ groups=group_channel,
423
+ ),
424
+ ModuleHelper.BNReLU(
425
+ hidden_dim, bn_type=self.configer.get("network", "bn_type")
426
+ ),
427
+ )
428
+ self.ocr_gather_head = SpatialGather_Module(self.num_classes)
429
+ self.ocr_distri_head = SpatialOCR_Module(
430
+ in_channels=hidden_dim,
431
+ key_channels=hidden_dim // 2,
432
+ out_channels=hidden_dim,
433
+ scale=1,
434
+ dropout=0.05,
435
+ bn_type=self.configer.get("network", "bn_type"),
436
+ )
437
+ self.cls_head = nn.Conv2d(
438
+ hidden_dim, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
439
+ )
440
+ self.aux_head = nn.Sequential(
441
+ nn.Conv2d(
442
+ in_channels,
443
+ hidden_dim,
444
+ kernel_size=7,
445
+ stride=1,
446
+ padding=3,
447
+ groups=group_channel,
448
+ ),
449
+ ModuleHelper.BNReLU(
450
+ hidden_dim, bn_type=self.configer.get("network", "bn_type")
451
+ ),
452
+ nn.Conv2d(
453
+ hidden_dim,
454
+ self.num_classes,
455
+ kernel_size=1,
456
+ stride=1,
457
+ padding=0,
458
+ bias=True,
459
+ ),
460
+ )
461
+
462
+ def forward(self, x_):
463
+ x = self.backbone(x_)
464
+ _, _, h, w = x[0].size()
465
+
466
+ feat1 = x[0]
467
+ feat2 = F.interpolate(x[1], size=(h, w), mode="bilinear", align_corners=True)
468
+ feat3 = F.interpolate(x[2], size=(h, w), mode="bilinear", align_corners=True)
469
+ feat4 = F.interpolate(x[3], size=(h, w), mode="bilinear", align_corners=True)
470
+
471
+ feats = torch.cat([feat1, feat2, feat3, feat4], 1)
472
+ out_aux = self.aux_head(feats)
473
+
474
+ feats = self.conv3x3(feats)
475
+
476
+ context = self.ocr_gather_head(feats, out_aux)
477
+ feats = self.ocr_distri_head(feats, context)
478
+
479
+ out = self.cls_head(feats)
480
+
481
+ out_aux = F.interpolate(
482
+ out_aux, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True
483
+ )
484
+ out = F.interpolate(
485
+ out, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True
486
+ )
487
+ return out_aux, out
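All of the OCR-head variants above share the same forward contract: given an image batch `x_` of shape (N, 3, H, W), they return `(out_aux, out)`, each of shape (N, num_classes, H, W) after the final bilinear upsampling. A very rough sketch (the HRT backbone hard-codes nn.SyncBatchNorm, so this is normally run on GPU, typically inside a distributed setup, with a matching `hrt_base` config):

import torch

net = HRT_B_OCR_V3(num_classes=1, backbone='hrt_base', pretrained=None).cuda().eval()
x = torch.randn(2, 3, 256, 256, device='cuda')
out_aux, out = net(x)   # both of shape (2, 1, 256, 256)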
isegm/model/modeling/hrformer_helper/__init__.py ADDED
File without changes
isegm/model/modeling/hrformer_helper/backbone_selector.py ADDED
@@ -0,0 +1,54 @@
1
+ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2
+ ## Created by: Donny You, RainbowSecret
3
+ ## Microsoft Research
4
5
+ ## Copyright (c) 2019
6
+ ##
7
+ ## This source code is licensed under the MIT-style license found in the
8
+ ## LICENSE file in the root directory of this source tree
9
+ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10
+
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import division
14
+ from __future__ import print_function
15
+
16
+ # from lib.models.backbones.resnet.resnet_backbone import ResNetBackbone
17
+ # from lib.models.backbones.hrnet.hrnet_backbone import HRNetBackbone
18
+ from .hrt.hrt_backbone import HRTBackbone
19
+ # from lib.models.backbones.swin.swin_backbone import SwinTransformerBackbone
20
+ from .hrt.logger import Logger as Log
21
+
22
+
23
+ class BackboneSelector(object):
24
+ def __init__(self, configer):
25
+ self.configer = configer
26
+
27
+ def get_backbone(self, **params):
28
+ backbone = self.configer.get("network", "backbone")
29
+
30
+ model = None
31
+ # if (
32
+ # "resnet" in backbone or "resnext" in backbone or "resnest" in backbone
33
+ # ) and "senet" not in backbone:
34
+ # model = ResNetBackbone(self.configer)(**params)
35
+
36
+ if "hrt" in backbone:
37
+ # model = HRTBackbone(self.configer)(**params)
38
+ pass
39
+
40
+ # elif "hrnet" in backbone:
41
+ # model = HRNetBackbone(self.configer)(**params)
42
+
43
+ # elif "swin" in backbone:
44
+ # model = SwinTransformerBackbone(self.configer)(**params)
45
+
46
+ else:
47
+ Log.error("Backbone {} is invalid.".format(backbone))
48
+ exit(1)
49
+
50
+ return model
51
+
52
+ class Test():
53
+ def __init__(self):
54
+ pass
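`BackboneSelector` only needs a configer object exposing `get(section, key)`. A hypothetical minimal stand-in is sketched below; note that in this copy of the selector the HRT branch is stubbed out with `pass`, so `get_backbone()` returns None here, while the working selector is the one re-defined in hrformer.py:

class DictConfiger:
    """Hypothetical stand-in exposing the get(section, key) interface the selector expects."""
    def __init__(self, cfg):
        self._cfg = cfg

    def get(self, *keys):
        node = self._cfg
        for key in keys:
            node = node[key]
        return node


configer = DictConfiger({"network": {"backbone": "hrt_small", "bn_type": "torchbn"}})
backbone = BackboneSelector(configer).get_backbone()   # None in this stub; a real model in hrformer.py's version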
isegm/model/modeling/hrformer_helper/hrt/__init__.py ADDED
File without changes
isegm/model/modeling/hrformer_helper/hrt/hrt_backbone.py ADDED
@@ -0,0 +1,661 @@
1
+ import os
2
+ import pdb
3
+ import argparse
4
+ import torch
5
+ import logging
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from .modules.bottleneck_block import Bottleneck, BottleneckDWP
10
+ from .modules.transformer_block import GeneralTransformerBlock
11
+
12
+ from .module_helper import ModuleHelper
13
+ from .logger import Logger as Log
14
+
15
+ blocks_dict = {
16
+ "BOTTLENECK": Bottleneck,
17
+ "TRANSFORMER_BLOCK": GeneralTransformerBlock,
18
+ }
19
+
20
+
21
+ BN_MOMENTUM = 0.1
22
+
23
+
24
+ class HighResolutionTransformerModule(nn.Module):
25
+ def __init__(
26
+ self,
27
+ num_branches,
28
+ blocks,
29
+ num_blocks,
30
+ num_inchannels,
31
+ num_channels,
32
+ num_heads,
33
+ num_window_sizes,
34
+ num_mlp_ratios,
35
+ multi_scale_output=True,
36
+ drop_path=0.0,
37
+ ):
38
+ """Based on Local-Attention & FFN-DW-BN
39
+ num_heads: the number of head witin each MHSA
40
+ num_window_sizes: the window size for the local self-attention
41
+ num_halo_sizes: the halo size around the local window
42
+ - reference: ``Scaling Local Self-Attention for Parameter Efficient Visual Backbones''
43
+ num_sr_ratios: the spatial reduction ratios of PVT/SRA scheme.
44
+ - reference: ``Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions''
45
+ """
46
+ super(HighResolutionTransformerModule, self).__init__()
47
+ self._check_branches(
48
+ num_branches, blocks, num_blocks, num_inchannels, num_channels
49
+ )
50
+
51
+ self.num_inchannels = num_inchannels
52
+ self.num_branches = num_branches
53
+
54
+ self.multi_scale_output = multi_scale_output
55
+ self.branches = self._make_branches(
56
+ num_branches,
57
+ blocks,
58
+ num_blocks,
59
+ num_channels,
60
+ num_heads,
61
+ num_window_sizes,
62
+ num_mlp_ratios,
63
+ drop_path,
64
+ )
65
+ self.fuse_layers = self._make_fuse_layers()
66
+ self.relu = nn.ReLU(inplace=True)
67
+
68
+ self.num_heads = num_heads
69
+ self.num_window_sizes = num_window_sizes
70
+ self.num_mlp_ratios = num_mlp_ratios
71
+
72
+ def _check_branches(
73
+ self, num_branches, blocks, num_blocks, num_inchannels, num_channels
74
+ ):
75
+ if num_branches != len(num_blocks):
76
+ error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(
77
+ num_branches, len(num_blocks)
78
+ )
79
+ Log.error(error_msg)
80
+ raise ValueError(error_msg)
81
+
82
+ if num_branches != len(num_channels):
83
+ error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
84
+ num_branches, len(num_channels)
85
+ )
86
+ Log.error(error_msg)
87
+ raise ValueError(error_msg)
88
+
89
+ if num_branches != len(num_inchannels):
90
+ error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
91
+ num_branches, len(num_inchannels)
92
+ )
93
+ Log.error(error_msg)
94
+ raise ValueError(error_msg)
95
+
96
+ def _make_one_branch(
97
+ self,
98
+ branch_index,
99
+ block,
100
+ num_blocks,
101
+ num_channels,
102
+ num_heads,
103
+ num_window_sizes,
104
+ num_mlp_ratios,
105
+ drop_paths,
106
+ stride=1,
107
+ ):
108
+ downsample = None
109
+ if (
110
+ stride != 1
111
+ or self.num_inchannels[branch_index]
112
+ != num_channels[branch_index] * block.expansion
113
+ ):
114
+ downsample = nn.Sequential(
115
+ nn.Conv2d(
116
+ self.num_inchannels[branch_index],
117
+ num_channels[branch_index] * block.expansion,
118
+ kernel_size=1,
119
+ stride=stride,
120
+ bias=False,
121
+ ),
122
+ nn.SyncBatchNorm(
123
+ num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM
124
+ ),
125
+ )
126
+
127
+ layers = []
128
+ layers.append(
129
+ block(
130
+ self.num_inchannels[branch_index],
131
+ num_channels[branch_index],
132
+ num_heads=num_heads[branch_index],
133
+ window_size=num_window_sizes[branch_index],
134
+ mlp_ratio=num_mlp_ratios[branch_index],
135
+ drop_path=drop_paths[0],
136
+ )
137
+ )
138
+
139
+ self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
140
+ for i in range(1, num_blocks[branch_index]):
141
+ layers.append(
142
+ block(
143
+ self.num_inchannels[branch_index],
144
+ num_channels[branch_index],
145
+ num_heads=num_heads[branch_index],
146
+ window_size=num_window_sizes[branch_index],
147
+ mlp_ratio=num_mlp_ratios[branch_index],
148
+ drop_path=drop_paths[i],
149
+ )
150
+ )
151
+ return nn.Sequential(*layers)
152
+
153
+ def _make_branches(
154
+ self,
155
+ num_branches,
156
+ block,
157
+ num_blocks,
158
+ num_channels,
159
+ num_heads,
160
+ num_window_sizes,
161
+ num_mlp_ratios,
162
+ drop_paths,
163
+ ):
164
+ branches = []
165
+
166
+ for i in range(num_branches):
167
+ branches.append(
168
+ self._make_one_branch(
169
+ i,
170
+ block,
171
+ num_blocks,
172
+ num_channels,
173
+ num_heads,
174
+ num_window_sizes,
175
+ num_mlp_ratios,
176
+ drop_paths=[_ * (2 ** i) for _ in drop_paths]
177
+ if os.environ.get("multi_res_drop_path", False)
178
+ else drop_paths,
179
+ )
180
+ )
181
+
182
+ return nn.ModuleList(branches)
183
+
184
+ def _make_fuse_layers(self):
185
+ if self.num_branches == 1:
186
+ return None
187
+ num_branches = self.num_branches
188
+ num_inchannels = self.num_inchannels
189
+ fuse_layers = []
190
+ for i in range(num_branches if self.multi_scale_output else 1):
191
+ fuse_layer = []
192
+ for j in range(num_branches):
193
+ if j > i:
194
+ fuse_layer.append(
195
+ nn.Sequential(
196
+ nn.Conv2d(
197
+ num_inchannels[j],
198
+ num_inchannels[i],
199
+ kernel_size=1,
200
+ stride=1,
201
+ bias=False,
202
+ ),
203
+ nn.SyncBatchNorm(num_inchannels[i], momentum=BN_MOMENTUM),
204
+ nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"),
205
+ )
206
+ )
207
+ elif j == i:
208
+ fuse_layer.append(None)
209
+ else:
210
+ conv3x3s = []
211
+ for k in range(i - j):
212
+ if k == i - j - 1:
213
+ num_outchannels_conv3x3 = num_inchannels[i]
214
+ conv3x3s.append(
215
+ nn.Sequential(
216
+ nn.Conv2d(
217
+ num_inchannels[j],
218
+ num_inchannels[j],
219
+ kernel_size=3,
220
+ stride=2,
221
+ padding=1,
222
+ groups=num_inchannels[j],
223
+ bias=False,
224
+ ),
225
+ nn.SyncBatchNorm(
226
+ num_inchannels[j], momentum=BN_MOMENTUM
227
+ ),
228
+ nn.Conv2d(
229
+ num_inchannels[j],
230
+ num_outchannels_conv3x3,
231
+ kernel_size=1,
232
+ stride=1,
233
+ bias=False,
234
+ ),
235
+ nn.SyncBatchNorm(
236
+ num_outchannels_conv3x3, momentum=BN_MOMENTUM
237
+ ),
238
+ )
239
+ )
240
+ else:
241
+ num_outchannels_conv3x3 = num_inchannels[j]
242
+ conv3x3s.append(
243
+ nn.Sequential(
244
+ nn.Conv2d(
245
+ num_inchannels[j],
246
+ num_inchannels[j],
247
+ kernel_size=3,
248
+ stride=2,
249
+ padding=1,
250
+ groups=num_inchannels[j],
251
+ bias=False,
252
+ ),
253
+ nn.SyncBatchNorm(
254
+ num_inchannels[j], momentum=BN_MOMENTUM
255
+ ),
256
+ nn.Conv2d(
257
+ num_inchannels[j],
258
+ num_outchannels_conv3x3,
259
+ kernel_size=1,
260
+ stride=1,
261
+ bias=False,
262
+ ),
263
+ nn.SyncBatchNorm(
264
+ num_outchannels_conv3x3, momentum=BN_MOMENTUM
265
+ ),
266
+ nn.ReLU(False),
267
+ )
268
+ )
269
+ fuse_layer.append(nn.Sequential(*conv3x3s))
270
+ fuse_layers.append(nn.ModuleList(fuse_layer))
271
+
272
+ return nn.ModuleList(fuse_layers)
273
+
274
+ def get_num_inchannels(self):
275
+ return self.num_inchannels
276
+
277
+ def forward(self, x):
278
+ if self.num_branches == 1:
279
+ return [self.branches[0](x[0])]
280
+
281
+ for i in range(self.num_branches):
282
+ x[i] = self.branches[i](x[i])
283
+
284
+ x_fuse = []
285
+ for i in range(len(self.fuse_layers)):
286
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
287
+ for j in range(1, self.num_branches):
288
+ if i == j:
289
+ y = y + x[j]
290
+ elif j > i:
291
+ width_output = x[i].shape[-1]
292
+ height_output = x[i].shape[-2]
293
+ y = y + F.interpolate(
294
+ self.fuse_layers[i][j](x[j]),
295
+ size=[height_output, width_output],
296
+ mode="bilinear",
297
+ align_corners=True,
298
+ )
299
+ else:
300
+ y = y + self.fuse_layers[i][j](x[j])
301
+ x_fuse.append(self.relu(y))
302
+
303
+ return x_fuse
304
+
305
+
306
+ class HighResolutionTransformer(nn.Module):
307
+ def __init__(self, cfg, in_ch=3, **kwargs):
308
+ super(HighResolutionTransformer, self).__init__()
309
+
310
+ self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=3, stride=2, padding=1, bias=False)
311
+ self.bn1 = nn.SyncBatchNorm(64, momentum=BN_MOMENTUM)
312
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
313
+ self.bn2 = nn.SyncBatchNorm(64, momentum=BN_MOMENTUM)
314
+ self.relu = nn.ReLU(inplace=True)
315
+
316
+ # stochastic depth
317
+ depth_s2 = cfg["STAGE2"]["NUM_BLOCKS"][0] * cfg["STAGE2"]["NUM_MODULES"]
318
+ depth_s3 = cfg["STAGE3"]["NUM_BLOCKS"][0] * cfg["STAGE3"]["NUM_MODULES"]
319
+ depth_s4 = cfg["STAGE4"]["NUM_BLOCKS"][0] * cfg["STAGE4"]["NUM_MODULES"]
320
+ depths = [depth_s2, depth_s3, depth_s4]
321
+ drop_path_rate = cfg["DROP_PATH_RATE"]
322
+ if os.environ.get("drop_path_rate") is not None:
323
+ drop_path_rate = float(os.environ.get("drop_path_rate"))
324
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
325
+
326
+ self.stage1_cfg = cfg["STAGE1"]
327
+ num_channels = self.stage1_cfg["NUM_CHANNELS"][0]
328
+ block = blocks_dict[self.stage1_cfg["BLOCK"]]
329
+ num_blocks = self.stage1_cfg["NUM_BLOCKS"][0]
330
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
331
+ stage1_out_channel = block.expansion * num_channels
332
+
333
+ self.stage2_cfg = cfg["STAGE2"]
334
+ num_channels = self.stage2_cfg["NUM_CHANNELS"]
335
+ block = blocks_dict[self.stage2_cfg["BLOCK"]]
336
+ num_channels = [
337
+ num_channels[i] * block.expansion for i in range(len(num_channels))
338
+ ]
339
+ self.transition1 = self._make_transition_layer(
340
+ [stage1_out_channel], num_channels
341
+ )
342
+ self.stage2, pre_stage_channels = self._make_stage(
343
+ self.stage2_cfg, num_channels, drop_path=dpr[0:depth_s2]
344
+ )
345
+
346
+ self.stage3_cfg = cfg["STAGE3"]
347
+ num_channels = self.stage3_cfg["NUM_CHANNELS"]
348
+ block = blocks_dict[self.stage3_cfg["BLOCK"]]
349
+ num_channels = [
350
+ num_channels[i] * block.expansion for i in range(len(num_channels))
351
+ ]
352
+ self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
353
+ self.stage3, pre_stage_channels = self._make_stage(
354
+ self.stage3_cfg, num_channels, drop_path=dpr[depth_s2 : depth_s2 + depth_s3]
355
+ )
356
+
357
+ self.stage4_cfg = cfg["STAGE4"]
358
+ num_channels = self.stage4_cfg["NUM_CHANNELS"]
359
+ block = blocks_dict[self.stage4_cfg["BLOCK"]]
360
+ num_channels = [
361
+ num_channels[i] * block.expansion for i in range(len(num_channels))
362
+ ]
363
+ self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
364
+ self.stage4, pre_stage_channels = self._make_stage(
365
+ self.stage4_cfg,
366
+ num_channels,
367
+ multi_scale_output=True,
368
+ drop_path=dpr[depth_s2 + depth_s3 :],
369
+ )
370
+
371
+ if os.environ.get("keep_imagenet_head"):
372
+ (
373
+ self.incre_modules,
374
+ self.downsamp_modules,
375
+ self.final_layer,
376
+ ) = self._make_head(pre_stage_channels)
377
+
378
+ def _make_head(self, pre_stage_channels):
379
+ head_block = BottleneckDWP
380
+ head_channels = [32, 64, 128, 256]
381
+
382
+ # Increasing the #channels on each resolution
383
+ # from C, 2C, 4C, 8C to 128, 256, 512, 1024
384
+ incre_modules = []
385
+ for i, channels in enumerate(pre_stage_channels):
386
+ incre_module = self._make_layer(
387
+ head_block, channels, head_channels[i], 1, stride=1
388
+ )
389
+ incre_modules.append(incre_module)
390
+ incre_modules = nn.ModuleList(incre_modules)
391
+
392
+ # downsampling modules
393
+ downsamp_modules = []
394
+ for i in range(len(pre_stage_channels) - 1):
395
+ in_channels = head_channels[i] * head_block.expansion
396
+ out_channels = head_channels[i + 1] * head_block.expansion
397
+ downsamp_module = nn.Sequential(
398
+ nn.Conv2d(
399
+ in_channels,
400
+ in_channels,
401
+ kernel_size=3,
402
+ stride=2,
403
+ padding=1,
404
+ groups=in_channels,
405
+ ),
406
+ nn.SyncBatchNorm(in_channels, momentum=BN_MOMENTUM),
407
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
408
+ nn.SyncBatchNorm(out_channels, momentum=BN_MOMENTUM),
409
+ nn.ReLU(inplace=True),
410
+ )
411
+ downsamp_modules.append(downsamp_module)
412
+ downsamp_modules = nn.ModuleList(downsamp_modules)
413
+
414
+ final_layer = nn.Sequential(
415
+ nn.Conv2d(
416
+ in_channels=head_channels[3] * head_block.expansion,
417
+ out_channels=2048,
418
+ kernel_size=1,
419
+ stride=1,
420
+ padding=0,
421
+ ),
422
+ nn.SyncBatchNorm(2048, momentum=BN_MOMENTUM),
423
+ nn.ReLU(inplace=True),
424
+ )
425
+
426
+ return incre_modules, downsamp_modules, final_layer
427
+
428
+ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
429
+ num_branches_cur = len(num_channels_cur_layer)
430
+ num_branches_pre = len(num_channels_pre_layer)
431
+
432
+ transition_layers = []
433
+ for i in range(num_branches_cur):
434
+ if i < num_branches_pre:
435
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
436
+ transition_layers.append(
437
+ nn.Sequential(
438
+ nn.Conv2d(
439
+ num_channels_pre_layer[i],
440
+ num_channels_cur_layer[i],
441
+ 3,
442
+ 1,
443
+ 1,
444
+ bias=False,
445
+ ),
446
+ nn.SyncBatchNorm(
447
+ num_channels_cur_layer[i], momentum=BN_MOMENTUM
448
+ ),
449
+ nn.ReLU(inplace=True),
450
+ )
451
+ )
452
+ else:
453
+ transition_layers.append(None)
454
+ else:
455
+ conv3x3s = []
456
+ for j in range(i + 1 - num_branches_pre):
457
+ inchannels = num_channels_pre_layer[-1]
458
+ outchannels = (
459
+ num_channels_cur_layer[i]
460
+ if j == i - num_branches_pre
461
+ else inchannels
462
+ )
463
+ conv3x3s.append(
464
+ nn.Sequential(
465
+ nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
466
+ nn.SyncBatchNorm(outchannels, momentum=BN_MOMENTUM),
467
+ nn.ReLU(inplace=True),
468
+ )
469
+ )
470
+ transition_layers.append(nn.Sequential(*conv3x3s))
471
+
472
+ return nn.ModuleList(transition_layers)
473
+
474
+ def _make_layer(
475
+ self,
476
+ block,
477
+ inplanes,
478
+ planes,
479
+ blocks,
480
+ num_heads=1,
481
+ stride=1,
482
+ window_size=7,
483
+ mlp_ratio=4.0,
484
+ ):
485
+ downsample = None
486
+ if stride != 1 or inplanes != planes * block.expansion:
487
+ downsample = nn.Sequential(
488
+ nn.Conv2d(
489
+ inplanes,
490
+ planes * block.expansion,
491
+ kernel_size=1,
492
+ stride=stride,
493
+ bias=False,
494
+ ),
495
+ nn.SyncBatchNorm(planes * block.expansion, momentum=BN_MOMENTUM),
496
+ )
497
+ layers = []
498
+
499
+ if isinstance(block, GeneralTransformerBlock):
500
+ layers.append(
501
+ block(
502
+ inplanes,
503
+ planes,
504
+ num_heads,
505
+ window_size,
506
+ mlp_ratio,
507
+ )
508
+ )
509
+ else:
510
+ layers.append(block(inplanes, planes, stride, downsample))
511
+
512
+ inplanes = planes * block.expansion
513
+ for i in range(1, blocks):
514
+ layers.append(block(inplanes, planes))
515
+
516
+ return nn.Sequential(*layers)
517
+
518
+ def _make_stage(
519
+ self, layer_config, num_inchannels, multi_scale_output=True, drop_path=0.0
520
+ ):
521
+ num_modules = layer_config["NUM_MODULES"]
522
+ num_branches = layer_config["NUM_BRANCHES"]
523
+ num_blocks = layer_config["NUM_BLOCKS"]
524
+ num_channels = layer_config["NUM_CHANNELS"]
525
+ block = blocks_dict[layer_config["BLOCK"]]
526
+ num_heads = layer_config["NUM_HEADS"]
527
+ num_window_sizes = layer_config["NUM_WINDOW_SIZES"]
528
+ num_mlp_ratios = layer_config["NUM_MLP_RATIOS"]
529
+
530
+ modules = []
531
+ for i in range(num_modules):
532
+ # multi_scale_output is only used in the last module
533
+ if not multi_scale_output and i == num_modules - 1:
534
+ reset_multi_scale_output = False
535
+ else:
536
+ reset_multi_scale_output = True
537
+
538
+ modules.append(
539
+ HighResolutionTransformerModule(
540
+ num_branches,
541
+ block,
542
+ num_blocks,
543
+ num_inchannels,
544
+ num_channels,
545
+ num_heads,
546
+ num_window_sizes,
547
+ num_mlp_ratios,
548
+ reset_multi_scale_output,
549
+ drop_path=drop_path[num_blocks[0] * i : num_blocks[0] * (i + 1)],
550
+ )
551
+ )
552
+ num_inchannels = modules[-1].get_num_inchannels()
553
+
554
+ return nn.Sequential(*modules), num_inchannels
555
+
556
+ def forward(self, x):
557
+ x = self.conv1(x)
558
+ x = self.bn1(x)
559
+ x = self.relu(x)
560
+ x = self.conv2(x)
561
+ x = self.bn2(x)
562
+ x = self.relu(x)
563
+ x = self.layer1(x)
564
+
565
+ x_list = []
566
+ for i in range(self.stage2_cfg["NUM_BRANCHES"]):
567
+ if self.transition1[i] is not None:
568
+ x_list.append(self.transition1[i](x))
569
+ else:
570
+ x_list.append(x)
571
+ y_list = self.stage2(x_list)
572
+
573
+ x_list = []
574
+ for i in range(self.stage3_cfg["NUM_BRANCHES"]):
575
+ if self.transition2[i] is not None:
576
+ x_list.append(self.transition2[i](y_list[-1]))
577
+ else:
578
+ x_list.append(y_list[i])
579
+ y_list = self.stage3(x_list)
580
+
581
+ x_list = []
582
+ for i in range(self.stage4_cfg["NUM_BRANCHES"]):
583
+ if self.transition3[i] is not None:
584
+ x_list.append(self.transition3[i](y_list[-1]))
585
+ else:
586
+ x_list.append(y_list[i])
587
+ y_list = self.stage4(x_list)
588
+
589
+ if os.environ.get("keep_imagenet_head"):
590
+ x_list = []
591
+ y = self.incre_modules[0](y_list[0])
592
+ x_list.append(y)
593
+ for i in range(len(self.downsamp_modules)):
594
+ y = self.incre_modules[i + 1](y_list[i + 1]) + self.downsamp_modules[i](
595
+ y
596
+ )
597
+ x_list.append(y)
598
+
599
+ y = self.final_layer(y)
600
+ del x_list[-1]
601
+ x_list.append(y)
602
+ return x_list
603
+
604
+ else:
605
+ return y_list
606
+
607
+
608
+ class HRTBackbone(object):
609
+ def __init__(self, configer):
610
+ self.configer = configer
611
+
612
+ def __call__(self):
613
+ arch = self.configer.get("network", "backbone")
614
+ from .hrt_config import MODEL_CONFIGS
615
+
616
+ if arch in [
617
+ "hrt_small",
618
+ "hrt_base",
619
+ "hrt_base_win13",
620
+ "hrt_base_win15",
621
+ ]:
622
+ arch_net = HighResolutionTransformer(MODEL_CONFIGS[arch])
623
+ arch_net = ModuleHelper.load_model(
624
+ arch_net,
625
+ pretrained=self.configer.get("network", "pretrained"),
626
+ all_match=False,
627
+ network="hrt_window" if "win" in arch else "hrt",
628
+ )
629
+
630
+ else:
631
+ raise Exception("Architecture undefined!")
632
+
633
+ return arch_net
634
+
635
+
636
+ class HRTBackbone_v2(object):
637
+ def __init__(self, backbone='hrt_small', pretrained=None, in_ch=3):
638
+ self.backbone = backbone
639
+ self.pretrained = pretrained
640
+ self.in_ch = in_ch
641
+
642
+ def __call__(self):
643
+ from .hrt_config import MODEL_CONFIGS
644
+ if self.backbone in [
645
+ "hrt_small",
646
+ "hrt_base",
647
+ "hrt_base_win13",
648
+ "hrt_base_win15",
649
+ ]:
650
+ arch_net = HighResolutionTransformer(MODEL_CONFIGS[self.backbone], in_ch=self.in_ch)
651
+ arch_net = ModuleHelper.load_model(
652
+ arch_net,
653
+ pretrained=self.pretrained,
654
+ all_match=False,
655
+ network="hrt_window" if "win" in self.backbone else "hrt",
656
+ )
657
+
658
+ else:
659
+ raise Exception("ARCHITECTURE UNDEFINED!")
660
+
661
+ return arch_net
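
The builder above only wires a config and an optional checkpoint together, so a quick sanity check is to construct the backbone without pretrained weights. A minimal sketch, assuming the full isegm package (including the hrt modules referenced by the imports at the top of this file) is importable and no checkpoint is supplied:

from isegm.model.modeling.hrformer_helper.hrt.hrt_backbone import HRTBackbone_v2

builder = HRTBackbone_v2(backbone='hrt_small', pretrained=None, in_ch=3)
model = builder()  # instantiates HighResolutionTransformer from MODEL_CONFIGS['hrt_small']
num_params = sum(p.numel() for p in model.parameters())
print(f"hrt_small backbone parameters: {num_params / 1e6:.1f}M")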
isegm/model/modeling/hrformer_helper/hrt/hrt_config.py ADDED
@@ -0,0 +1,123 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft
3
+ # Licensed under the MIT License.
4
+ # Rainbowsecret ([email protected])
5
+ # ------------------------------------------------------------------------------
6
+
7
+ from __future__ import absolute_import
8
+ from __future__ import division
9
+ from __future__ import print_function
10
+
11
+ from yacs.config import CfgNode as CN
12
+
13
+ # configs for HRT_SMALL
14
+ HRT_SMALL = CN()
15
+ HRT_SMALL.DROP_PATH_RATE = 0.2
16
+
17
+ HRT_SMALL.STAGE1 = CN()
18
+ HRT_SMALL.STAGE1.NUM_MODULES = 1
19
+ HRT_SMALL.STAGE1.NUM_BRANCHES = 1
20
+ HRT_SMALL.STAGE1.NUM_BLOCKS = [2]
21
+ HRT_SMALL.STAGE1.NUM_CHANNELS = [64]
22
+ HRT_SMALL.STAGE1.NUM_HEADS = [2]
23
+ HRT_SMALL.STAGE1.NUM_MLP_RATIOS = [4]
24
+ HRT_SMALL.STAGE1.NUM_RESOLUTIONS = [[56, 56]]
25
+ HRT_SMALL.STAGE1.BLOCK = "BOTTLENECK"
26
+
27
+ HRT_SMALL.STAGE2 = CN()
28
+ HRT_SMALL.STAGE2.NUM_MODULES = 1
29
+ HRT_SMALL.STAGE2.NUM_BRANCHES = 2
30
+ HRT_SMALL.STAGE2.NUM_BLOCKS = [2, 2]
31
+ HRT_SMALL.STAGE2.NUM_CHANNELS = [32, 64]
32
+ HRT_SMALL.STAGE2.NUM_HEADS = [1, 2]
33
+ HRT_SMALL.STAGE2.NUM_MLP_RATIOS = [4, 4]
34
+ HRT_SMALL.STAGE2.NUM_RESOLUTIONS = [[56, 56], [28, 28]]
35
+ HRT_SMALL.STAGE2.NUM_WINDOW_SIZES = [7, 7]
36
+ HRT_SMALL.STAGE2.BLOCK = "TRANSFORMER_BLOCK"
37
+
38
+ HRT_SMALL.STAGE3 = CN()
39
+ HRT_SMALL.STAGE3.NUM_MODULES = 4
40
+ HRT_SMALL.STAGE3.NUM_BRANCHES = 3
41
+ HRT_SMALL.STAGE3.NUM_BLOCKS = [2, 2, 2]
42
+ HRT_SMALL.STAGE3.NUM_CHANNELS = [32, 64, 128]
43
+ HRT_SMALL.STAGE3.NUM_HEADS = [1, 2, 4]
44
+ HRT_SMALL.STAGE3.NUM_MLP_RATIOS = [4, 4, 4]
45
+ HRT_SMALL.STAGE3.NUM_RESOLUTIONS = [[56, 56], [28, 28], [14, 14]]
46
+ HRT_SMALL.STAGE3.NUM_WINDOW_SIZES = [7, 7, 7]
47
+ HRT_SMALL.STAGE3.BLOCK = "TRANSFORMER_BLOCK"
48
+
49
+ HRT_SMALL.STAGE4 = CN()
50
+ HRT_SMALL.STAGE4.NUM_MODULES = 2
51
+ HRT_SMALL.STAGE4.NUM_BRANCHES = 4
52
+ HRT_SMALL.STAGE4.NUM_BLOCKS = [2, 2, 2, 2]
53
+ HRT_SMALL.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
54
+ HRT_SMALL.STAGE4.NUM_HEADS = [1, 2, 4, 8]
55
+ HRT_SMALL.STAGE4.NUM_MLP_RATIOS = [4, 4, 4, 4]
56
+ HRT_SMALL.STAGE4.NUM_RESOLUTIONS = [[56, 56], [28, 28], [14, 14], [7, 7]]
57
+ HRT_SMALL.STAGE4.NUM_WINDOW_SIZES = [7, 7, 7, 7]
58
+ HRT_SMALL.STAGE4.BLOCK = "TRANSFORMER_BLOCK"
59
+
60
+ # configs for HRT_BASE
61
+ HRT_BASE = CN()
62
+ HRT_BASE.DROP_PATH_RATE = 0.2
63
+
64
+ HRT_BASE.STAGE1 = CN()
65
+ HRT_BASE.STAGE1.NUM_MODULES = 1
66
+ HRT_BASE.STAGE1.NUM_BRANCHES = 1
67
+ HRT_BASE.STAGE1.NUM_BLOCKS = [2]
68
+ HRT_BASE.STAGE1.NUM_CHANNELS = [64]
69
+ HRT_BASE.STAGE1.NUM_HEADS = [2]
70
+ HRT_BASE.STAGE1.NUM_MLP_RATIOS = [4]
71
+ HRT_BASE.STAGE1.NUM_RESOLUTIONS = [[56, 56]]
72
+ HRT_BASE.STAGE1.BLOCK = "BOTTLENECK"
73
+
74
+ HRT_BASE.STAGE2 = CN()
75
+ HRT_BASE.STAGE2.NUM_MODULES = 1
76
+ HRT_BASE.STAGE2.NUM_BRANCHES = 2
77
+ HRT_BASE.STAGE2.NUM_BLOCKS = [2, 2]
78
+ HRT_BASE.STAGE2.NUM_CHANNELS = [78, 156]
79
+ HRT_BASE.STAGE2.NUM_HEADS = [2, 4]
80
+ HRT_BASE.STAGE2.NUM_MLP_RATIOS = [4, 4]
81
+ HRT_BASE.STAGE2.NUM_RESOLUTIONS = [[56, 56], [28, 28]]
82
+ HRT_BASE.STAGE2.NUM_WINDOW_SIZES = [7, 7]
83
+ HRT_BASE.STAGE2.BLOCK = "TRANSFORMER_BLOCK"
84
+
85
+ HRT_BASE.STAGE3 = CN()
86
+ HRT_BASE.STAGE3.NUM_MODULES = 4
87
+ HRT_BASE.STAGE3.NUM_BRANCHES = 3
88
+ HRT_BASE.STAGE3.NUM_BLOCKS = [2, 2, 2]
89
+ HRT_BASE.STAGE3.NUM_CHANNELS = [78, 156, 312]
90
+ HRT_BASE.STAGE3.NUM_HEADS = [2, 4, 8]
91
+ HRT_BASE.STAGE3.NUM_MLP_RATIOS = [4, 4, 4]
92
+ HRT_BASE.STAGE3.NUM_RESOLUTIONS = [[56, 56], [28, 28], [14, 14]]
93
+ HRT_BASE.STAGE3.NUM_WINDOW_SIZES = [7, 7, 7]
94
+ HRT_BASE.STAGE3.BLOCK = "TRANSFORMER_BLOCK"
95
+
96
+ HRT_BASE.STAGE4 = CN()
97
+ HRT_BASE.STAGE4.NUM_MODULES = 2
98
+ HRT_BASE.STAGE4.NUM_BRANCHES = 4
99
+ HRT_BASE.STAGE4.NUM_BLOCKS = [2, 2, 2, 2]
100
+ HRT_BASE.STAGE4.NUM_CHANNELS = [78, 156, 312, 624]
101
+ HRT_BASE.STAGE4.NUM_HEADS = [2, 4, 8, 16]
102
+ HRT_BASE.STAGE4.NUM_MLP_RATIOS = [4, 4, 4, 4]
103
+ HRT_BASE.STAGE4.NUM_RESOLUTIONS = [[56, 56], [28, 28], [14, 14], [7, 7]]
104
+ HRT_BASE.STAGE4.NUM_WINDOW_SIZES = [7, 7, 7, 7]
105
+ HRT_BASE.STAGE4.BLOCK = "TRANSFORMER_BLOCK"
106
+
107
+ HRT_BASE_WIN_13 = HRT_BASE.clone()
108
+ HRT_BASE_WIN_13.STAGE2.NUM_WINDOW_SIZES = [13, 13]
109
+ HRT_BASE_WIN_13.STAGE3.NUM_WINDOW_SIZES = [13, 13, 13]
110
+ HRT_BASE_WIN_13.STAGE4.NUM_WINDOW_SIZES = [13, 13, 13, 13]
111
+
112
+
113
+ HRT_BASE_WIN_15 = HRT_BASE.clone()
114
+ HRT_BASE_WIN_15.STAGE2.NUM_WINDOW_SIZES = [15, 15]
115
+ HRT_BASE_WIN_15.STAGE3.NUM_WINDOW_SIZES = [15, 15, 15]
116
+ HRT_BASE_WIN_15.STAGE4.NUM_WINDOW_SIZES = [15, 15, 15, 15]
117
+
118
+ MODEL_CONFIGS = {
119
+ "hrt_small": HRT_SMALL,
120
+ "hrt_base": HRT_BASE,
121
+ "hrt_base_win13": HRT_BASE_WIN_13,
122
+ "hrt_base_win15": HRT_BASE_WIN_15,
123
+ }
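
MODEL_CONFIGS maps architecture names to plain yacs CfgNodes, so the per-stage settings can be inspected directly. A minimal sketch, assuming yacs is installed and the isegm package is importable:

from isegm.model.modeling.hrformer_helper.hrt.hrt_config import MODEL_CONFIGS

cfg = MODEL_CONFIGS["hrt_small"]
for stage in ("STAGE1", "STAGE2", "STAGE3", "STAGE4"):
    print(stage, "channels:", cfg[stage]["NUM_CHANNELS"], "heads:", cfg[stage]["NUM_HEADS"])
# STAGE1 channels: [64] heads: [2]
# ...
# STAGE4 channels: [32, 64, 128, 256] heads: [1, 2, 4, 8]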
isegm/model/modeling/hrformer_helper/hrt/logger.py ADDED
@@ -0,0 +1,205 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding:utf-8 -*-
3
+ # Author: Donny You([email protected])
4
+ # Logging tool implemented with the python Package logging.
5
+
6
+
7
+ from __future__ import absolute_import
8
+ from __future__ import division
9
+ from __future__ import print_function
10
+
11
+ import argparse
12
+ import logging
13
+ import os
14
+ import sys
15
+
16
+
17
+ DEFAULT_LOGFILE_LEVEL = 'debug'
18
+ DEFAULT_STDOUT_LEVEL = 'info'
19
+ DEFAULT_LOG_FILE = './default.log'
20
+ DEFAULT_LOG_FORMAT = '%(asctime)s %(levelname)-7s %(message)s'
21
+
22
+ LOG_LEVEL_DICT = {
23
+ 'debug': logging.DEBUG,
24
+ 'info': logging.INFO,
25
+ 'warning': logging.WARNING,
26
+ 'error': logging.ERROR,
27
+ 'critical': logging.CRITICAL
28
+ }
29
+
30
+
31
+ class Logger(object):
32
+ """
33
+ Args:
34
+ Log level: CRITICAL>ERROR>WARNING>INFO>DEBUG.
35
+ Log file: The file that stores the logging info.
36
+ rewrite: Clear the log file.
37
+ log format: The format of log messages.
38
+ stdout level: The log level to print on the screen.
39
+ """
40
+ logfile_level = None
41
+ log_file = None
42
+ log_format = None
43
+ rewrite = None
44
+ stdout_level = None
45
+ logger = None
46
+
47
+ _caches = {}
48
+
49
+ @staticmethod
50
+ def init(logfile_level=DEFAULT_LOGFILE_LEVEL,
51
+ log_file=DEFAULT_LOG_FILE,
52
+ log_format=DEFAULT_LOG_FORMAT,
53
+ rewrite=False,
54
+ stdout_level=None):
55
+ Logger.logfile_level = logfile_level
56
+ Logger.log_file = log_file
57
+ Logger.log_format = log_format
58
+ Logger.rewrite = rewrite
59
+ Logger.stdout_level = stdout_level
60
+
61
+ Logger.logger = logging.getLogger()
62
+ Logger.logger.handlers = []
63
+ fmt = logging.Formatter(Logger.log_format)
64
+
65
+ if Logger.logfile_level is not None:
66
+ filemode = 'w'
67
+ if not Logger.rewrite:
68
+ filemode = 'a'
69
+
70
+ dir_name = os.path.dirname(os.path.abspath(Logger.log_file))
71
+ if not os.path.exists(dir_name):
72
+ os.makedirs(dir_name)
73
+
74
+ if Logger.logfile_level not in LOG_LEVEL_DICT:
75
+ print('Invalid logging level: {}'.format(Logger.logfile_level))
76
+ Logger.logfile_level = DEFAULT_LOGFILE_LEVEL
77
+
78
+ Logger.logger.setLevel(LOG_LEVEL_DICT[Logger.logfile_level])
79
+
80
+ fh = logging.FileHandler(Logger.log_file, mode=filemode)
81
+ fh.setFormatter(fmt)
82
+ fh.setLevel(LOG_LEVEL_DICT[Logger.logfile_level])
83
+
84
+ Logger.logger.addHandler(fh)
85
+
86
+ if stdout_level is not None:
87
+ if Logger.logfile_level is None:
88
+ Logger.logger.setLevel(LOG_LEVEL_DICT[Logger.stdout_level])
89
+
90
+ console = logging.StreamHandler()
91
+ if Logger.stdout_level not in LOG_LEVEL_DICT:
92
+ print('Invalid logging level: {}'.format(Logger.stdout_level))
93
+ return
94
+
95
+ console.setLevel(LOG_LEVEL_DICT[Logger.stdout_level])
96
+ console.setFormatter(fmt)
97
+ Logger.logger.addHandler(console)
98
+
99
+ @staticmethod
100
+ def set_log_file(file_path):
101
+ Logger.log_file = file_path
102
+ Logger.init(log_file=file_path)
103
+
104
+ @staticmethod
105
+ def set_logfile_level(log_level):
106
+ if log_level not in LOG_LEVEL_DICT:
107
+ print('Invalid logging level: {}'.format(log_level))
108
+ return
109
+
110
+ Logger.init(logfile_level=log_level)
111
+
112
+ @staticmethod
113
+ def clear_log_file():
114
+ Logger.rewrite = True
115
+ Logger.init(rewrite=True)
116
+
117
+ @staticmethod
118
+ def check_logger():
119
+ if Logger.logger is None:
120
+ Logger.init(logfile_level=None, stdout_level=DEFAULT_STDOUT_LEVEL)
121
+
122
+ @staticmethod
123
+ def set_stdout_level(log_level):
124
+ if log_level not in LOG_LEVEL_DICT:
125
+ print('Invalid logging level: {}'.format(log_level))
126
+ return
127
+
128
+ Logger.init(stdout_level=log_level)
129
+
130
+ @staticmethod
131
+ def debug(message):
132
+ Logger.check_logger()
133
+ filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
134
+ lineno = sys._getframe().f_back.f_lineno
135
+ prefix = '[{}, {}]'.format(filename,lineno)
136
+ Logger.logger.debug('{} {}'.format(prefix, message))
137
+
138
+ @staticmethod
139
+ def info(message):
140
+ Logger.check_logger()
141
+ filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
142
+ lineno = sys._getframe().f_back.f_lineno
143
+ prefix = '[{}, {}]'.format(filename,lineno)
144
+ Logger.logger.info('{} {}'.format(prefix, message))
145
+
146
+ @staticmethod
147
+ def info_once(message):
148
+ Logger.check_logger()
149
+ filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
150
+ lineno = sys._getframe().f_back.f_lineno
151
+ prefix = '[{}, {}]'.format(filename, lineno)
152
+
153
+ if Logger._caches.get((prefix, message)) is not None:
154
+ return
155
+
156
+ Logger.logger.info('{} {}'.format(prefix, message))
157
+ Logger._caches[(prefix, message)] = True
158
+
159
+ @staticmethod
160
+ def warn(message):
161
+ Logger.check_logger()
162
+ filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
163
+ lineno = sys._getframe().f_back.f_lineno
164
+ prefix = '[{}, {}]'.format(filename,lineno)
165
+ Logger.logger.warn('{} {}'.format(prefix, message))
166
+
167
+ @staticmethod
168
+ def error(message):
169
+ Logger.check_logger()
170
+ filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
171
+ lineno = sys._getframe().f_back.f_lineno
172
+ prefix = '[{}, {}]'.format(filename,lineno)
173
+ Logger.logger.error('{} {}'.format(prefix, message))
174
+
175
+ @staticmethod
176
+ def critical(message):
177
+ Logger.check_logger()
178
+ filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
179
+ lineno = sys._getframe().f_back.f_lineno
180
+ prefix = '[{}, {}]'.format(filename,lineno)
181
+ Logger.logger.critical('{} {}'.format(prefix, message))
182
+
183
+
184
+ if __name__ == "__main__":
185
+ parser = argparse.ArgumentParser()
186
+ parser.add_argument('--logfile_level', default="debug", type=str,
187
+ dest='logfile_level', help='To set the log level to files.')
188
+ parser.add_argument('--stdout_level', default=None, type=str,
189
+ dest='stdout_level', help='To set the level to print to screen.')
190
+ parser.add_argument('--log_file', default="./default.log", type=str,
191
+ dest='log_file', help='The path of log files.')
192
+ parser.add_argument('--log_format', default="%(asctime)s %(levelname)-7s %(message)s",
193
+ type=str, dest='log_format', help='The format of log messages.')
194
+ parser.add_argument('--rewrite', default=False, type=bool,
195
+ dest='rewrite', help='Clear the log files existed.')
196
+
197
+ args = parser.parse_args()
198
+ Logger.init(logfile_level=args.logfile_level, stdout_level=args.stdout_level,
199
+ log_file=args.log_file, log_format=args.log_format, rewrite=args.rewrite)
200
+
201
+ Logger.info("info test.")
202
+ Logger.debug("debug test.")
203
+ Logger.warn("warn test.")
204
+ Logger.error("error test.")
205
+ Logger.debug("debug test.")
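
Outside the command-line entry point above, the Logger is meant to be initialized once and then used through its static methods. A minimal sketch (the log path here is an arbitrary example, not something the repository mandates):

from isegm.model.modeling.hrformer_helper.hrt.logger import Logger as Log

Log.init(logfile_level='info', stdout_level='info', log_file='./hrt_debug.log')
Log.info('backbone construction started')
Log.warn('no pretrained checkpoint provided; training from scratch')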
isegm/model/modeling/hrformer_helper/hrt/module_helper.py ADDED
@@ -0,0 +1,310 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding:utf-8 -*-
3
+ # Author: Donny You ([email protected])
4
+
5
+
6
+ from __future__ import absolute_import
7
+ from __future__ import division
8
+ from __future__ import print_function
9
+
10
+ import functools
11
+ import os
12
+ import pdb
13
+ import math
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ try:
19
+ from urllib import urlretrieve
20
+ except ImportError:
21
+ from urllib.request import urlretrieve
22
+
23
+ from .logger import Logger as Log
24
+
25
+
26
+ class ModuleHelper(object):
27
+ @staticmethod
28
+ def BNReLU(num_features, bn_type=None, **kwargs):
29
+ if bn_type == "torchbn":
30
+ return nn.Sequential(nn.BatchNorm2d(num_features, **kwargs), nn.ReLU())
31
+ elif bn_type == "torchsyncbn":
32
+ return nn.Sequential(nn.SyncBatchNorm(num_features, **kwargs), nn.ReLU())
33
+ elif bn_type == "syncbn":
34
+ from lib.extensions.syncbn.module import BatchNorm2d
35
+
36
+ return nn.Sequential(BatchNorm2d(num_features, **kwargs), nn.ReLU())
37
+ elif bn_type == "sn":
38
+ from lib.extensions.switchablenorms.switchable_norm import SwitchNorm2d
39
+
40
+ return nn.Sequential(SwitchNorm2d(num_features, **kwargs), nn.ReLU())
41
+ elif bn_type == "gn":
42
+ return nn.Sequential(
43
+ nn.GroupNorm(num_groups=8, num_channels=num_features, **kwargs),
44
+ nn.ReLU(),
45
+ )
46
+ elif bn_type == "fn":
47
+ Log.error("Not support Filter-Response-Normalization: {}.".format(bn_type))
48
+ exit(1)
49
+ elif bn_type == "inplace_abn":
50
+ torch_ver = torch.__version__[:3]
51
+ # Log.info('Pytorch Version: {}'.format(torch_ver))
52
+ if torch_ver == "0.4":
53
+ from lib.extensions.inplace_abn.bn import InPlaceABNSync
54
+
55
+ return InPlaceABNSync(num_features, **kwargs)
56
+ elif torch_ver in ("1.0", "1.1"):
57
+ from lib.extensions.inplace_abn_1.bn import InPlaceABNSync
58
+
59
+ return InPlaceABNSync(num_features, **kwargs)
60
+ elif torch_ver == "1.2":
61
+ from inplace_abn import InPlaceABNSync
62
+
63
+ return InPlaceABNSync(num_features, **kwargs)
64
+
65
+ else:
66
+ Log.error("Not support BN type: {}.".format(bn_type))
67
+ exit(1)
68
+
69
+ @staticmethod
70
+ def BatchNorm2d(bn_type="torch", ret_cls=False):
71
+ if bn_type == "torchbn":
72
+ return nn.BatchNorm2d
73
+
74
+ elif bn_type == "torchsyncbn":
75
+ return nn.SyncBatchNorm
76
+
77
+ elif bn_type == "syncbn":
78
+ from lib.extensions.syncbn.module import BatchNorm2d
79
+
80
+ return BatchNorm2d
81
+
82
+ elif bn_type == "sn":
83
+ from lib.extensions.switchablenorms.switchable_norm import SwitchNorm2d
84
+
85
+ return SwitchNorm2d
86
+
87
+ elif bn_type == "gn":
88
+ return functools.partial(nn.GroupNorm, num_groups=32)
89
+
90
+ elif bn_type == "inplace_abn":
91
+ torch_ver = torch.__version__[:3]
92
+ if torch_ver == "0.4":
93
+ from lib.extensions.inplace_abn.bn import InPlaceABNSync
94
+
95
+ if ret_cls:
96
+ return InPlaceABNSync
97
+ return functools.partial(InPlaceABNSync, activation="none")
98
+
99
+ elif torch_ver in ("1.0", "1.1"):
100
+ from lib.extensions.inplace_abn_1.bn import InPlaceABNSync
101
+
102
+ if ret_cls:
103
+ return InPlaceABNSync
104
+ return functools.partial(InPlaceABNSync, activation="none")
105
+
106
+ elif torch_ver == "1.2":
107
+ from inplace_abn import InPlaceABNSync
108
+
109
+ if ret_cls:
110
+ return InPlaceABNSync
111
+ return functools.partial(InPlaceABNSync, activation="identity")
112
+
113
+ else:
114
+ Log.error("Not support BN type: {}.".format(bn_type))
115
+ exit(1)
116
+
117
+ @staticmethod
118
+ def load_model(model, pretrained=None, all_match=True, network="resnet101"):
119
+ if pretrained is None:
120
+ return model
121
+
122
+ if all_match:
123
+ Log.info("Loading pretrained model:{}".format(pretrained))
124
+ pretrained_dict = torch.load(pretrained)
125
+ model_dict = model.state_dict()
126
+ load_dict = dict()
127
+ for k, v in pretrained_dict.items():
128
+ if "resinit.{}".format(k) in model_dict:
129
+ load_dict["resinit.{}".format(k)] = v
130
+ else:
131
+ load_dict[k] = v
132
+ model.load_state_dict(load_dict)
133
+
134
+ else:
135
+ Log.info("Loading pretrained model:{}".format(pretrained))
136
+ pretrained_dict = torch.load(pretrained)
137
+
138
+ # settings for network == "wide_resnet38" or network == "resnet152"
139
+ if network == "wide_resnet":
140
+ pretrained_dict = pretrained_dict["state_dict"]
141
+
142
+ model_dict = model.state_dict()
143
+
144
+ if network == "hrnet_plus":
145
+ # pretrained_dict['conv1_full_res.weight'] = pretrained_dict['conv1.weight']
146
+ # pretrained_dict['conv2_full_res.weight'] = pretrained_dict['conv2.weight']
147
+ load_dict = {
148
+ k: v for k, v in pretrained_dict.items() if k in model_dict.keys()
149
+ }
150
+
151
+ elif network == "hrt_window":
152
+ pretrained_dict = pretrained_dict["model"]
153
+ for name, m in model.named_parameters():
154
+ if "relative_position_bias_table" in name and "embed" not in name:
155
+ target_size = int(math.sqrt(m.shape[0]))
156
+ head_num = m.shape[-1]
157
+ ckpt_size = int(math.sqrt(pretrained_dict[name].shape[0]))
158
+ if target_size != ckpt_size:
159
+ Log.info(
160
+ f"Interpolate from size {pretrained_dict[name ].shape} to {m.shape}."
161
+ )
162
+ reshape_ckpt = (
163
+ pretrained_dict[name]
164
+ .permute(1, 0)
165
+ .reshape(1, head_num, ckpt_size, ckpt_size)
166
+ )
167
+ inter_ckpt = (
168
+ torch.nn.functional.interpolate(
169
+ reshape_ckpt,
170
+ size=(target_size, target_size),
171
+ mode="bilinear",
172
+ )
173
+ .reshape(head_num, -1)
174
+ .permute(1, 0)
175
+ )
176
+ scale = 1
177
+ inter_ckpt *= scale
178
+ pretrained_dict[name] = inter_ckpt
179
+ for name, m in list(pretrained_dict.items()):
180
+ if "relative_position_index" in name:
181
+ Log.info(f"Remove {name}.")
182
+ pretrained_dict.pop(name)
183
+ load_dict = {
184
+ k: v for k, v in pretrained_dict.items() if k in model_dict.keys()
185
+ }
186
+ Log.info(
187
+ "Missing keys: {}".format(list(set(model_dict) - set(load_dict)))
188
+ )
189
+
190
+ elif network == "hrt":
191
+ pretrained_dict = pretrained_dict["model"]
192
+ load_dict = {
193
+ k: v for k, v in pretrained_dict.items() if k in model_dict.keys()
194
+ }
195
+ Log.info(
196
+ "Missing keys: {}".format(list(set(model_dict) - set(load_dict)))
197
+ )
198
+
199
+ elif network == "swin":
200
+ pretrained_dict = pretrained_dict["model"]
201
+ # TODO fix the mis-match between the dict keys and the checkpoint keys.
202
+ pretrained_dict = {
203
+ k.replace(".attn.", ".attn.attn."): v
204
+ for k, v in pretrained_dict.items()
205
+ }
206
+ load_dict = {
207
+ k: v for k, v in pretrained_dict.items() if k in model_dict.keys()
208
+ }
209
+ Log.info(
210
+ "Missing keys: {}".format(list(set(model_dict) - set(load_dict)))
211
+ )
212
+
213
+ elif network == "hrnet" or network == "xception" or network == "resnest":
214
+ load_dict = {
215
+ k: v for k, v in pretrained_dict.items() if k in model_dict.keys()
216
+ }
217
+ Log.info(
218
+ "Missing keys: {}".format(list(set(model_dict) - set(load_dict)))
219
+ )
220
+
221
+ elif network == "dcnet" or network == "resnext":
222
+ load_dict = dict()
223
+ for k, v in pretrained_dict.items():
224
+ if "resinit.{}".format(k) in model_dict:
225
+ load_dict["resinit.{}".format(k)] = v
226
+ else:
227
+ if k in model_dict:
228
+ load_dict[k] = v
229
+ else:
230
+ pass
231
+
232
+ elif network == "wide_resnet":
233
+ load_dict = {
234
+ ".".join(k.split(".")[1:]): v
235
+ for k, v in pretrained_dict.items()
236
+ if ".".join(k.split(".")[1:]) in model_dict
237
+ }
238
+ else:
239
+ load_dict = {
240
+ ".".join(k.split(".")[1:]): v
241
+ for k, v in pretrained_dict.items()
242
+ if ".".join(k.split(".")[1:]) in model_dict
243
+ }
244
+
245
+ # used to debug
246
+ if int(os.environ.get("debug_load_model", 0)):
247
+ Log.info("Matched Keys List:")
248
+ for key in load_dict.keys():
249
+ Log.info("{}".format(key))
250
+ model_dict.update(load_dict)
251
+ model.load_state_dict(model_dict)
252
+
253
+ return model
254
+
255
+ @staticmethod
256
+ def load_url(url, map_location=None):
257
+ model_dir = os.path.join("~", ".PyTorchCV", "models")
258
+ if not os.path.exists(model_dir):
259
+ os.makedirs(model_dir)
260
+
261
+ filename = url.split("/")[-1]
262
+ cached_file = os.path.join(model_dir, filename)
263
+ if not os.path.exists(cached_file):
264
+ Log.info('Downloading: "{}" to {}\n'.format(url, cached_file))
265
+ urlretrieve(url, cached_file)
266
+
267
+ Log.info("Loading pretrained model:{}".format(cached_file))
268
+ return torch.load(cached_file, map_location=map_location)
269
+
270
+ @staticmethod
271
+ def constant_init(module, val, bias=0):
272
+ nn.init.constant_(module.weight, val)
273
+ if hasattr(module, "bias") and module.bias is not None:
274
+ nn.init.constant_(module.bias, bias)
275
+
276
+ @staticmethod
277
+ def xavier_init(module, gain=1, bias=0, distribution="normal"):
278
+ assert distribution in ["uniform", "normal"]
279
+ if distribution == "uniform":
280
+ nn.init.xavier_uniform_(module.weight, gain=gain)
281
+ else:
282
+ nn.init.xavier_normal_(module.weight, gain=gain)
283
+ if hasattr(module, "bias") and module.bias is not None:
284
+ nn.init.constant_(module.bias, bias)
285
+
286
+ @staticmethod
287
+ def normal_init(module, mean=0, std=1, bias=0):
288
+ nn.init.normal_(module.weight, mean, std)
289
+ if hasattr(module, "bias") and module.bias is not None:
290
+ nn.init.constant_(module.bias, bias)
291
+
292
+ @staticmethod
293
+ def uniform_init(module, a=0, b=1, bias=0):
294
+ nn.init.uniform_(module.weight, a, b)
295
+ if hasattr(module, "bias") and module.bias is not None:
296
+ nn.init.constant_(module.bias, bias)
297
+
298
+ @staticmethod
299
+ def kaiming_init(
300
+ module, mode="fan_in", nonlinearity="leaky_relu", bias=0, distribution="normal"
301
+ ):
302
+ assert distribution in ["uniform", "normal"]
303
+ if distribution == "uniform":
304
+ nn.init.kaiming_uniform_(
305
+ module.weight, mode=mode, nonlinearity=nonlinearity
306
+ )
307
+ else:
308
+ nn.init.kaiming_normal_(module.weight, mode=mode, nonlinearity=nonlinearity)
309
+ if hasattr(module, "bias") and module.bias is not None:
310
+ nn.init.constant_(module.bias, bias)
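
The initialization helpers operate on any module that exposes a weight (and optionally a bias) tensor, and BNReLU builds a normalization + activation pair from a string key. A minimal sketch on a plain convolution, with arbitrary layer sizes, assuming only torch is installed:

import torch.nn as nn
from isegm.model.modeling.hrformer_helper.hrt.module_helper import ModuleHelper

conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
ModuleHelper.kaiming_init(conv, mode="fan_out", nonlinearity="relu")

head = ModuleHelper.BNReLU(64, bn_type="torchbn")  # plain BatchNorm2d followed by ReLU
print(head)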
isegm/model/modeling/hrformer_helper/hrt/modules/__init__.py ADDED
File without changes
isegm/model/modeling/hrformer_helper/hrt/modules/bottleneck_block.py ADDED
@@ -0,0 +1,128 @@
1
+ import os
2
+ import pdb
3
+ import logging
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ # from torchvision.models.utils import load_state_dict_from_url
7
+ # from timm.models.registry import register_model
8
+ from functools import partial
9
+
10
+ BN_MOMENTUM = 0.1
11
+
12
+
13
+ class Bottleneck(nn.Module):
14
+ expansion = 4
15
+
16
+ def __init__(
17
+ self,
18
+ inplanes,
19
+ planes,
20
+ stride=1,
21
+ downsample=None,
22
+ mhsa_flag=False,
23
+ num_heads=1,
24
+ num_halo_block=1,
25
+ num_mlp_ratio=4,
26
+ num_sr_ratio=1,
27
+ num_resolution=None,
28
+ with_rpe=False,
29
+ with_ffn=True,
30
+ ):
31
+ super(Bottleneck, self).__init__()
32
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
33
+ self.bn1 = nn.SyncBatchNorm(planes)
34
+ self.conv2 = nn.Conv2d(
35
+ planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
36
+ )
37
+ self.bn2 = nn.SyncBatchNorm(planes)
38
+ self.conv3 = nn.Conv2d(
39
+ planes, planes * self.expansion, kernel_size=1, bias=False
40
+ )
41
+ self.bn3 = nn.SyncBatchNorm(planes * self.expansion)
42
+ self.relu = nn.ReLU(inplace=True)
43
+ self.downsample = downsample
44
+ self.stride = stride
45
+
46
+ def forward(self, x):
47
+ residual = x
48
+
49
+ out = self.conv1(x)
50
+ out = self.bn1(out)
51
+ out = self.relu(out)
52
+
53
+ out = self.conv2(out)
54
+ out = self.bn2(out)
55
+ out = self.relu(out)
56
+
57
+ out = self.conv3(out)
58
+ out = self.bn3(out)
59
+
60
+ if self.downsample is not None:
61
+ residual = self.downsample(x)
62
+
63
+ out += residual
64
+ out = self.relu(out)
65
+
66
+ return out
67
+
68
+
69
+ class BottleneckDWP(nn.Module):
70
+ expansion = 4
71
+
72
+ def __init__(
73
+ self,
74
+ inplanes,
75
+ planes,
76
+ stride=1,
77
+ downsample=None,
78
+ mhsa_flag=False,
79
+ num_heads=1,
80
+ num_halo_block=1,
81
+ num_mlp_ratio=4,
82
+ num_sr_ratio=1,
83
+ num_resolution=None,
84
+ with_rpe=False,
85
+ with_ffn=True,
86
+ ):
87
+ super(BottleneckDWP, self).__init__()
88
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
89
+ self.bn1 = nn.SyncBatchNorm(planes, momentum=BN_MOMENTUM)
90
+ self.conv2 = nn.Conv2d(
91
+ planes,
92
+ planes,
93
+ kernel_size=3,
94
+ stride=stride,
95
+ padding=1,
96
+ bias=False,
97
+ groups=planes,
98
+ )
99
+ self.bn2 = nn.SyncBatchNorm(planes, momentum=BN_MOMENTUM)
100
+ self.conv3 = nn.Conv2d(
101
+ planes, planes * self.expansion, kernel_size=1, bias=False
102
+ )
103
+ self.bn3 = nn.SyncBatchNorm(planes * self.expansion, momentum=BN_MOMENTUM)
104
+ self.relu = nn.ReLU(inplace=True)
105
+ self.downsample = downsample
106
+ self.stride = stride
107
+
108
+ def forward(self, x):
109
+ residual = x
110
+
111
+ out = self.conv1(x)
112
+ out = self.bn1(out)
113
+ out = self.relu(out)
114
+
115
+ out = self.conv2(out)
116
+ out = self.bn2(out)
117
+ out = self.relu(out)
118
+
119
+ out = self.conv3(out)
120
+ out = self.bn3(out)
121
+
122
+ if self.downsample is not None:
123
+ residual = self.downsample(x)
124
+
125
+ out += residual
126
+ out = self.relu(out)
127
+
128
+ return out
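
Both residual blocks above widen their output by expansion = 4, so planes is a quarter of the block's output channels. A minimal sketch that checks this bookkeeping by construction only (the blocks use nn.SyncBatchNorm, whose forward pass may require a GPU/distributed setup depending on the PyTorch version):

from isegm.model.modeling.hrformer_helper.hrt.modules.bottleneck_block import Bottleneck

block = Bottleneck(inplanes=64, planes=16)  # 16 * Bottleneck.expansion == 64, so no downsample is needed
print(block.conv3.out_channels)             # 64
print(sum(p.numel() for p in block.parameters()))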
isegm/model/modeling/hrformer_helper/hrt/modules/ffn_block.py ADDED
@@ -0,0 +1,287 @@
1
+ import pdb
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class Mlp(nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_features,
10
+ hidden_features=None,
11
+ out_features=None,
12
+ act_layer=nn.GELU,
13
+ drop=0.0,
14
+ ):
15
+ super().__init__()
16
+ out_features = out_features or in_features
17
+ hidden_features = hidden_features or in_features
18
+ self.fc1 = nn.Linear(in_features, hidden_features)
19
+ self.act = act_layer()
20
+ self.fc2 = nn.Linear(hidden_features, out_features)
21
+ self.drop = nn.Dropout(drop)
22
+
23
+ def forward(self, x, H, W):
24
+ x = self.fc1(x)
25
+ x = self.act(x)
26
+ x = self.drop(x)
27
+ x = self.fc2(x)
28
+ x = self.drop(x)
29
+ return x
30
+
31
+
32
+ class MlpLight(nn.Module):
33
+ def __init__(
34
+ self,
35
+ in_features,
36
+ hidden_features=None,
37
+ out_features=None,
38
+ act_layer=nn.GELU,
39
+ drop=0.0,
40
+ ):
41
+ super().__init__()
42
+ self.fc1 = nn.Linear(in_features, in_features)
43
+ self.act = act_layer()
44
+ self.drop = nn.Dropout(drop)
45
+
46
+ def forward(self, x):
47
+ x = self.fc1(x)
48
+ x = self.act(x)
49
+ x = self.drop(x)
50
+ return x
51
+
52
+
53
+ class MlpDW(nn.Module):
54
+ def __init__(
55
+ self,
56
+ in_features,
57
+ hidden_features=None,
58
+ out_features=None,
59
+ act_layer=nn.GELU,
60
+ dw_act_layer=nn.GELU,
61
+ drop=0.0,
62
+ ):
63
+ super().__init__()
64
+ out_features = out_features or in_features
65
+ hidden_features = hidden_features or in_features
66
+ self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
67
+ self.act1 = act_layer()
68
+ self.dw3x3 = nn.Conv2d(
69
+ hidden_features,
70
+ hidden_features,
71
+ kernel_size=3,
72
+ stride=1,
73
+ groups=hidden_features,
74
+ padding=1,
75
+ )
76
+ self.act2 = dw_act_layer()
77
+ self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
78
+ self.drop = nn.Dropout(drop)
79
+
80
+ def forward(self, x, H, W):
81
+ B, N, C = x.shape
82
+
83
+ if N == (H * W + 1):
84
+ cls_tokens = x[:, 0, :]
85
+ x_ = x[:, 1:, :].permute(0, 2, 1).reshape(B, C, H, W)
86
+ else:
87
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
88
+
89
+ x_ = self.fc1(x_)
90
+ x_ = self.act1(x_)
91
+ x_ = self.dw3x3(x_)
92
+ x_ = self.act2(x_)
93
+ x_ = self.drop(x_)
94
+ x_ = self.fc2(x_)
95
+ x_ = self.drop(x_)
96
+ x_ = x_.reshape(B, C, -1).permute(0, 2, 1)
97
+
98
+ if N == (H * W + 1):
99
+ x = torch.cat((cls_tokens.unsqueeze(1), x_), dim=1)
100
+ else:
101
+ x = x_
102
+
103
+ return x
104
+
105
+
106
+ class MlpDWBN(nn.Module):
107
+ def __init__(
108
+ self,
109
+ in_features,
110
+ hidden_features=None,
111
+ out_features=None,
112
+ act_layer=nn.GELU,
113
+ dw_act_layer=nn.GELU,
114
+ drop=0.0,
115
+ ):
116
+ super().__init__()
117
+ out_features = out_features or in_features
118
+ hidden_features = hidden_features or in_features
119
+ self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
120
+ self.act1 = act_layer()
121
+ self.norm1 = nn.SyncBatchNorm(hidden_features)
122
+ self.dw3x3 = nn.Conv2d(
123
+ hidden_features,
124
+ hidden_features,
125
+ kernel_size=3,
126
+ stride=1,
127
+ groups=hidden_features,
128
+ padding=1,
129
+ )
130
+ self.act2 = dw_act_layer()
131
+ self.norm2 = nn.SyncBatchNorm(hidden_features)
132
+ self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
133
+ self.act3 = act_layer()
134
+ self.norm3 = nn.SyncBatchNorm(out_features)
135
+ # self.drop = nn.Dropout(drop, inplace=True)
136
+
137
+ def forward(self, x, H, W):
138
+ if len(x.shape) == 3:
139
+ B, N, C = x.shape
140
+ if N == (H * W + 1):
141
+ cls_tokens = x[:, 0, :]
142
+ x_ = x[:, 1:, :].permute(0, 2, 1).reshape(B, C, H, W)
143
+ else:
144
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
145
+
146
+ x_ = self.fc1(x_)
147
+ x_ = self.norm1(x_)
148
+ x_ = self.act1(x_)
149
+ x_ = self.dw3x3(x_)
150
+ x_ = self.norm2(x_)
151
+ x_ = self.act2(x_)
152
+ # x_ = self.drop(x_)
153
+ x_ = self.fc2(x_)
154
+ x_ = self.norm3(x_)
155
+ x_ = self.act3(x_)
156
+ # x_ = self.drop(x_)
157
+ x_ = x_.reshape(B, C, -1).permute(0, 2, 1)
158
+ if N == (H * W + 1):
159
+ x = torch.cat((cls_tokens.unsqueeze(1), x_), dim=1)
160
+ else:
161
+ x = x_
162
+ return x
163
+
164
+ elif len(x.shape) == 4:
165
+ x = self.fc1(x)
166
+ x = self.norm1(x)
167
+ x = self.act1(x)
168
+ x = self.dw3x3(x)
169
+ x = self.norm2(x)
170
+ x = self.act2(x)
171
+ # x = self.drop(x)
172
+ x = self.fc2(x)
173
+ x = self.norm3(x)
174
+ x = self.act3(x)
175
+ # x = self.drop(x)
176
+ return x
177
+
178
+ else:
179
+ raise RuntimeError("Unsupported input shape: {}".format(x.shape))
180
+
181
+
182
+ class MlpConvBN(nn.Module):
183
+ def __init__(
184
+ self,
185
+ in_features,
186
+ hidden_features=None,
187
+ out_features=None,
188
+ act_layer=nn.GELU,
189
+ drop=0.0,
190
+ ):
191
+ super().__init__()
192
+ out_features = out_features or in_features
193
+ hidden_features = hidden_features or in_features
194
+ self.fc1 = nn.Sequential(
195
+ nn.Conv1d(
196
+ in_channels=in_features,
197
+ out_channels=hidden_features,
198
+ kernel_size=1,
199
+ stride=1,
200
+ padding=0,
201
+ ),
202
+ nn.BatchNorm1d(hidden_features),
203
+ )
204
+ self.act = act_layer()
205
+ self.fc2 = nn.Sequential(
206
+ nn.Conv1d(
207
+ in_channels=hidden_features,
208
+ out_channels=out_features,
209
+ kernel_size=1,
210
+ stride=1,
211
+ padding=0,
212
+ ),
213
+ nn.BatchNorm1d(out_features),
214
+ )
215
+ self.drop = nn.Dropout(drop)
216
+
217
+ def forward(self, x):
218
+ x = x.transpose(1, 2)
219
+ x = self.fc1(x)
220
+ x = self.act(x)
221
+ x = self.drop(x)
222
+ x = self.fc2(x)
223
+ x = x.transpose(1, 2)
224
+ x = self.drop(x)
225
+ return x
226
+
227
+
228
+ class MlpWODWBN(nn.Module):
229
+ def __init__(
230
+ self,
231
+ in_features,
232
+ hidden_features=None,
233
+ out_features=None,
234
+ act_layer=nn.GELU,
235
+ dw_act_layer=nn.GELU,
236
+ drop=0.0,
237
+ ):
238
+ super().__init__()
239
+ out_features = out_features or in_features
240
+ hidden_features = hidden_features or in_features
241
+ self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
242
+ self.act1 = act_layer()
243
+ self.norm1 = nn.SyncBatchNorm(hidden_features)
244
+ self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
245
+ self.act3 = act_layer()
246
+ self.norm3 = nn.SyncBatchNorm(out_features)
247
+ self.drop = nn.Dropout(drop)
248
+
249
+ def forward(self, x, H, W):
250
+ if len(x.shape) == 3:
251
+ B, N, C = x.shape
252
+ if N == (H * W + 1):
253
+ cls_tokens = x[:, 0, :]
254
+ x_ = x[:, 1:, :].permute(0, 2, 1).reshape(B, C, H, W)
255
+ else:
256
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
257
+
258
+ x_ = self.fc1(x_)
259
+ x_ = self.norm1(x_)
260
+ x_ = self.act1(x_)
261
+ x_ = self.fc2(x_)
262
+ x_ = self.norm3(x_)
263
+ x_ = self.act3(x_)
264
+ x_ = self.drop(x_)
265
+ x_ = x_.reshape(B, C, -1).permute(0, 2, 1)
266
+ if N == (H * W + 1):
267
+ x = torch.cat((cls_tokens.unsqueeze(1), x_), dim=1)
268
+ else:
269
+ x = x_
270
+ return x
271
+
272
+ elif len(x.shape) == 4:
273
+ x = self.fc1(x)
274
+ x = self.norm1(x)
275
+ x = self.act1(x)
276
+ x = self.dw3x3(x)
277
+ x = self.norm2(x)
278
+ x = self.act2(x)
279
+ x = self.drop(x)
280
+ x = self.fc2(x)
281
+ x = self.norm3(x)
282
+ x = self.act3(x)
283
+ x = self.drop(x)
284
+ return x
285
+
286
+ else:
287
+ raise RuntimeError("Unsupported input shape: {}".format(x.shape))
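
Most of the FFN variants above take a token sequence plus the spatial size (H, W) so that tokens can be folded back into a feature map when needed. The plain Mlp is the easiest one to exercise on CPU because it contains no SyncBatchNorm layers; a minimal sketch with arbitrary sizes:

import torch
from isegm.model.modeling.hrformer_helper.hrt.modules.ffn_block import Mlp

B, H, W, C = 2, 14, 14, 96
tokens = torch.randn(B, H * W, C)
ffn = Mlp(in_features=C, hidden_features=4 * C)
out = ffn(tokens, H, W)   # H and W are accepted but unused by this variant
print(out.shape)          # torch.Size([2, 196, 96])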
isegm/model/modeling/hrformer_helper/hrt/modules/multihead_attention.py ADDED
@@ -0,0 +1,342 @@
1
+ import copy
2
+ import warnings
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn, Tensor
7
+ from torch.nn.modules.module import Module
8
+ from torch._jit_internal import Optional, Tuple
9
+ from torch.nn.functional import linear, pad, softmax, dropout
10
+ from torch.overrides import has_torch_function, handle_torch_function
11
+
12
+
13
+
14
+ class MultiheadAttention(Module):
15
+ bias_k: Optional[torch.Tensor]
16
+ bias_v: Optional[torch.Tensor]
17
+
18
+ def __init__(
19
+ self,
20
+ embed_dim,
21
+ num_heads,
22
+ dropout=0.0,
23
+ bias=True,
24
+ add_bias_kv=False,
25
+ add_zero_attn=False,
26
+ kdim=None,
27
+ vdim=None,
28
+ ):
29
+ super(MultiheadAttention, self).__init__()
30
+ self.embed_dim = embed_dim
31
+ self.kdim = kdim if kdim is not None else embed_dim
32
+ self.vdim = vdim if vdim is not None else embed_dim
33
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
34
+
35
+ self.num_heads = num_heads
36
+ self.dropout = dropout
37
+ self.head_dim = embed_dim // num_heads
38
+ assert (
39
+ self.head_dim * num_heads == self.embed_dim
40
+ ), "embed_dim must be divisible by num_heads"
41
+
42
+ self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
43
+ self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
44
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
45
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
46
+
47
+ self.in_proj_bias = None
48
+ self.in_proj_weight = None
49
+ self.bias_k = self.bias_v = None
50
+ self.q_proj_weight = None
51
+ self.k_proj_weight = None
52
+ self.v_proj_weight = None
53
+ self.add_zero_attn = add_zero_attn
54
+
55
+ def __setstate__(self, state):
56
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
57
+ if "_qkv_same_embed_dim" not in state:
58
+ state["_qkv_same_embed_dim"] = True
59
+
60
+ super(MultiheadAttention, self).__setstate__(state)
61
+
62
+ def forward(
63
+ self,
64
+ query,
65
+ key,
66
+ value,
67
+ key_padding_mask=None,
68
+ need_weights=False,
69
+ attn_mask=None,
70
+ residual_attn=None,
71
+ ):
72
+ if not self._qkv_same_embed_dim:
73
+ return self.multi_head_attention_forward(
74
+ query,
75
+ key,
76
+ value,
77
+ self.embed_dim,
78
+ self.num_heads,
79
+ self.in_proj_weight,
80
+ self.in_proj_bias,
81
+ self.bias_k,
82
+ self.bias_v,
83
+ self.add_zero_attn,
84
+ self.dropout,
85
+ self.out_proj.weight,
86
+ self.out_proj.bias,
87
+ training=self.training,
88
+ key_padding_mask=key_padding_mask,
89
+ need_weights=need_weights,
90
+ attn_mask=attn_mask,
91
+ use_separate_proj_weight=True,
92
+ q_proj_weight=self.q_proj_weight,
93
+ k_proj_weight=self.k_proj_weight,
94
+ v_proj_weight=self.v_proj_weight,
95
+ out_dim=self.vdim,
96
+ residual_attn=residual_attn,
97
+ )
98
+ else:
99
+ return self.multi_head_attention_forward(
100
+ query,
101
+ key,
102
+ value,
103
+ self.embed_dim,
104
+ self.num_heads,
105
+ self.in_proj_weight,
106
+ self.in_proj_bias,
107
+ self.bias_k,
108
+ self.bias_v,
109
+ self.add_zero_attn,
110
+ self.dropout,
111
+ self.out_proj.weight,
112
+ self.out_proj.bias,
113
+ training=self.training,
114
+ key_padding_mask=key_padding_mask,
115
+ need_weights=need_weights,
116
+ attn_mask=attn_mask,
117
+ out_dim=self.vdim,
118
+ residual_attn=residual_attn,
119
+ )
120
+
121
+ def multi_head_attention_forward(
122
+ self,
123
+ query: Tensor,
124
+ key: Tensor,
125
+ value: Tensor,
126
+ embed_dim_to_check: int,
127
+ num_heads: int,
128
+ in_proj_weight: Tensor,
129
+ in_proj_bias: Tensor,
130
+ bias_k: Optional[Tensor],
131
+ bias_v: Optional[Tensor],
132
+ add_zero_attn: bool,
133
+ dropout_p: float,
134
+ out_proj_weight: Tensor,
135
+ out_proj_bias: Tensor,
136
+ training: bool = True,
137
+ key_padding_mask: Optional[Tensor] = None,
138
+ need_weights: bool = False,
139
+ attn_mask: Optional[Tensor] = None,
140
+ use_separate_proj_weight: bool = False,
141
+ q_proj_weight: Optional[Tensor] = None,
142
+ k_proj_weight: Optional[Tensor] = None,
143
+ v_proj_weight: Optional[Tensor] = None,
144
+ static_k: Optional[Tensor] = None,
145
+ static_v: Optional[Tensor] = None,
146
+ out_dim: Optional[Tensor] = None,
147
+ residual_attn: Optional[Tensor] = None,
148
+ ) -> Tuple[Tensor, Optional[Tensor]]:
149
+ if not torch.jit.is_scripting():
150
+ tens_ops = (
151
+ query,
152
+ key,
153
+ value,
154
+ in_proj_weight,
155
+ in_proj_bias,
156
+ bias_k,
157
+ bias_v,
158
+ out_proj_weight,
159
+ out_proj_bias,
160
+ )
161
+ if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(
162
+ tens_ops
163
+ ):
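+ # Defer to the __torch_function__ override path (assumes multi_head_attention_forward,
+ # e.g. from torch.nn.functional, is imported earlier in this module).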
164
+ return handle_torch_function(
165
+ multi_head_attention_forward,
166
+ tens_ops,
167
+ query,
168
+ key,
169
+ value,
170
+ embed_dim_to_check,
171
+ num_heads,
172
+ in_proj_weight,
173
+ in_proj_bias,
174
+ bias_k,
175
+ bias_v,
176
+ add_zero_attn,
177
+ dropout_p,
178
+ out_proj_weight,
179
+ out_proj_bias,
180
+ training=training,
181
+ key_padding_mask=key_padding_mask,
182
+ need_weights=need_weights,
183
+ attn_mask=attn_mask,
184
+ use_separate_proj_weight=use_separate_proj_weight,
185
+ q_proj_weight=q_proj_weight,
186
+ k_proj_weight=k_proj_weight,
187
+ v_proj_weight=v_proj_weight,
188
+ static_k=static_k,
189
+ static_v=static_v,
190
+ )
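+ # Inputs are expected as (seq_len, batch, embed_dim); None key/value fall back to self-attention.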
191
+ tgt_len, bsz, embed_dim = query.size()
192
+ key = query if key is None else key
193
+ value = query if value is None else value
194
+
195
+ assert embed_dim == embed_dim_to_check
196
+ # allow MHA to have different sizes for the feature dimension
197
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
198
+
199
+ head_dim = embed_dim // num_heads
200
+ v_head_dim = out_dim // num_heads
201
+ assert (
202
+ head_dim * num_heads == embed_dim
203
+ ), "embed_dim must be divisible by num_heads"
204
+ scaling = float(head_dim) ** -0.5
205
+
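+ # Project q/k/v with the module's own Linear layers; queries are pre-scaled by 1/sqrt(head_dim).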
206
+ q = self.q_proj(query) * scaling
207
+ k = self.k_proj(key)
208
+ v = self.v_proj(value)
209
+
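+ # Validate attn_mask: only float/byte/bool dtypes are accepted, byte masks are converted to
+ # bool, and 2D masks are unsqueezed so they broadcast over (batch * heads).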
210
+ if attn_mask is not None:
211
+ assert (
212
+ attn_mask.dtype == torch.float32
213
+ or attn_mask.dtype == torch.float64
214
+ or attn_mask.dtype == torch.float16
215
+ or attn_mask.dtype == torch.uint8
216
+ or attn_mask.dtype == torch.bool
217
+ ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
218
+ attn_mask.dtype
219
+ )
220
+ if attn_mask.dtype == torch.uint8:
221
+ warnings.warn(
222
+ "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
223
+ )
224
+ attn_mask = attn_mask.to(torch.bool)
225
+
226
+ if attn_mask.dim() == 2:
227
+ attn_mask = attn_mask.unsqueeze(0)
228
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
229
+ raise RuntimeError("The size of the 2D attn_mask is not correct.")
230
+ elif attn_mask.dim() == 3:
231
+ if list(attn_mask.size()) != [
232
+ bsz * num_heads,
233
+ query.size(0),
234
+ key.size(0),
235
+ ]:
236
+ raise RuntimeError("The size of the 3D attn_mask is not correct.")
237
+ else:
238
+ raise RuntimeError(
239
+ "attn_mask's dimension {} is not supported".format(attn_mask.dim())
240
+ )
241
+
242
+ # convert ByteTensor key_padding_mask to bool
243
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
244
+ warnings.warn(
245
+ "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
246
+ )
247
+ key_padding_mask = key_padding_mask.to(torch.bool)
248
+
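+ # Fold heads into the batch dimension: (len, bsz, dim) -> (bsz * num_heads, len, head_dim);
+ # values use v_head_dim = out_dim // num_heads, which may differ from head_dim.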
249
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
250
+ if k is not None:
251
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
252
+ if v is not None:
253
+ v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)
254
+
255
+ src_len = k.size(1)
256
+
257
+ if key_padding_mask is not None:
258
+ assert key_padding_mask.size(0) == bsz
259
+ assert key_padding_mask.size(1) == src_len
260
+
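+ # Optionally append an all-zero key/value slot so every query has at least one position to
+ # attend to; attn_mask and key_padding_mask are padded to match the new src_len.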
261
+ if add_zero_attn:
262
+ src_len += 1
263
+ k = torch.cat(
264
+ [
265
+ k,
266
+ torch.zeros(
267
+ (k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device
268
+ ),
269
+ ],
270
+ dim=1,
271
+ )
272
+ v = torch.cat(
273
+ [
274
+ v,
275
+ torch.zeros(
276
+ (v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device
277
+ ),
278
+ ],
279
+ dim=1,
280
+ )
281
+ if attn_mask is not None:
282
+ attn_mask = pad(attn_mask, (0, 1))
283
+ if key_padding_mask is not None:
284
+ key_padding_mask = pad(key_padding_mask, (0, 1))
285
+
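+ # Scaled dot-product scores between queries and keys.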
286
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
287
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
288
+
289
+ """
290
+ Attention weight for the invalid region is -inf
291
+ """
292
+ if attn_mask is not None:
293
+ if attn_mask.dtype == torch.bool:
294
+ attn_output_weights.masked_fill_(attn_mask, float("-inf"))
295
+ else:
296
+ attn_output_weights += attn_mask
297
+
298
+ if key_padding_mask is not None:
299
+ attn_output_weights = attn_output_weights.view(
300
+ bsz, num_heads, tgt_len, src_len
301
+ )
302
+ attn_output_weights = attn_output_weights.masked_fill(
303
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
304
+ float("-inf"),
305
+ )
306
+ attn_output_weights = attn_output_weights.view(
307
+ bsz * num_heads, tgt_len, src_len
308
+ )
309
+
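+ # Optionally add externally supplied attention logits (one map per head, broadcast over the
+ # batch) before the softmax.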
310
+ if residual_attn is not None:
311
+ attn_output_weights = attn_output_weights.view(
312
+ bsz, num_heads, tgt_len, src_len
313
+ )
314
+ attn_output_weights += residual_attn.unsqueeze(0)
315
+ attn_output_weights = attn_output_weights.view(
316
+ bsz * num_heads, tgt_len, src_len
317
+ )
318
+
319
+ """
320
+ Reweight the attention map before softmax().
321
+ attn_output_weights: (b*n_head, n, hw)
322
+ """
323
+ attn_output_weights = softmax(attn_output_weights, dim=-1)
324
+ attn_output_weights = dropout(
325
+ attn_output_weights, p=dropout_p, training=training
326
+ )
327
+
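+ # Weighted sum over the value vectors, then restore (tgt_len, bsz, out_dim) and apply the
+ # output projection.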
328
+ attn_output = torch.bmm(attn_output_weights, v)
329
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
330
+ attn_output = (
331
+ attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim)
332
+ )
333
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
334
+
335
+ if need_weights:
336
+ # average attention weights over heads
337
+ attn_output_weights = attn_output_weights.view(
338
+ bsz, num_heads, tgt_len, src_len
339
+ )
340
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
341
+ else:
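+ # NOTE: this branch returns a bare Tensor even though the annotation declares a Tuple.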
342
+ return attn_output
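+
+ # Minimal usage sketch (illustrative only; the constructor arguments are assumed to mirror
+ # torch.nn.MultiheadAttention and are not all visible in this diff):
+ #   mha = MultiheadAttention(embed_dim=256, num_heads=8, dropout=0.1)
+ #   x = torch.randn(50, 4, 256)   # (seq_len, batch, embed_dim)
+ #   out = mha(x, x, x)            # self-attention; a (50, 4, 256) tensor when vdim defaults to embed_dim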