XiaRho committed on
Commit
8b4c6c7
1 Parent(s): 816ed23
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. app.py +124 -0
  2. configs/SEMat_HQ-SAM.py +48 -0
  3. configs/SEMat_SAM.py +51 -0
  4. configs/SEMat_SAM2.py +57 -0
  5. configs/common/optimizer.py +26 -0
  6. configs/common/scheduler.py +13 -0
  7. configs/common/train.py +17 -0
  8. configs/semantic_enhanced_matting/dataloader.py +62 -0
  9. configs/semantic_enhanced_matting/model.py +35 -0
  10. data/__init__.py +1 -0
  11. data/__pycache__/__init__.cpython-38.pyc +0 -0
  12. data/__pycache__/dim_dataset.cpython-38.pyc +0 -0
  13. data/__pycache__/evaluate.cpython-38.pyc +0 -0
  14. data/__pycache__/rand_augment.cpython-38.pyc +0 -0
  15. data/coconut_dataset.py +377 -0
  16. data/dim_dataset.py +1476 -0
  17. data/evaluate.py +102 -0
  18. data/p3m10k_dataset.py +325 -0
  19. data/rand_augment.py +196 -0
  20. data/refmatte_dataset.py +418 -0
  21. engine/__init__.py +1 -0
  22. engine/hooks.py +52 -0
  23. engine/mattingtrainer.py +171 -0
  24. modeling/__init__.py +5 -0
  25. modeling/__pycache__/__init__.cpython-38.pyc +0 -0
  26. modeling/backbone/__init__.py +2 -0
  27. modeling/backbone/__pycache__/__init__.cpython-38.pyc +0 -0
  28. modeling/backbone/__pycache__/backbone.cpython-38.pyc +0 -0
  29. modeling/backbone/__pycache__/utils.cpython-38.pyc +0 -0
  30. modeling/backbone/__pycache__/vit.cpython-38.pyc +0 -0
  31. modeling/backbone/backbone.py +74 -0
  32. modeling/backbone/utils.py +186 -0
  33. modeling/backbone/vit.py +404 -0
  34. modeling/criterion/__init__.py +1 -0
  35. modeling/criterion/__pycache__/__init__.cpython-38.pyc +0 -0
  36. modeling/criterion/__pycache__/matting_criterion.cpython-38.pyc +0 -0
  37. modeling/criterion/matting_criterion.py +271 -0
  38. modeling/decoder/__init__.py +1 -0
  39. modeling/decoder/__pycache__/__init__.cpython-38.pyc +0 -0
  40. modeling/decoder/__pycache__/detail_capture.cpython-38.pyc +0 -0
  41. modeling/decoder/__pycache__/unet_detail_capture.cpython-38.pyc +0 -0
  42. modeling/decoder/detail_capture.py +185 -0
  43. modeling/decoder/unet_detail_capture.py +429 -0
  44. modeling/meta_arch/__init__.py +1 -0
  45. modeling/meta_arch/__pycache__/__init__.cpython-38.pyc +0 -0
  46. modeling/meta_arch/__pycache__/sam_hq_matting.cpython-38.pyc +0 -0
  47. modeling/meta_arch/sam_hq_matting.py +671 -0
  48. modeling/semantic_enhanced_matting/__init__.py +17 -0
  49. modeling/semantic_enhanced_matting/__pycache__/__init__.cpython-38.pyc +0 -0
  50. modeling/semantic_enhanced_matting/__pycache__/automatic_mask_generator.cpython-38.pyc +0 -0
app.py ADDED
@@ -0,0 +1,124 @@
+ import gradio as gr
+ from gradio_image_prompter import ImagePrompter
+ from detectron2.config import LazyConfig, instantiate
+ from detectron2.checkpoint import DetectionCheckpointer
+ import cv2
+ import numpy as np
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model_choice = {
+     'SAM': None,
+     'HQ-SAM': None,
+     'SAM2': None
+ }
+
+ for model_type in model_choice.keys():
+     model_choice[model_type] = hf_hub_download(repo_id="XiaRho/SEMat", filename=f"SEMat_{model_type}.pth", repo_type="model")
+
+ def load_model(model_type='SAM2'):
+     assert model_type in model_choice.keys()
+     config_path = './configs/SEMat_{}.py'.format(model_type)
+     cfg = LazyConfig.load(config_path)
+
+     if hasattr(cfg.model.sam_model, 'ckpt_path'):
+         cfg.model.sam_model.ckpt_path = None
+     else:
+         cfg.model.sam_model.checkpoint = None
+     model = instantiate(cfg.model)
+     if model.lora_rank is not None:
+         model.init_lora()
+     model.to(DEVICE)
+     DetectionCheckpointer(model).load(model_choice[model_type])
+     model.eval()
+     return model, model_type
+
+ def transform_image_bbox(prompts):
+     if len(prompts["points"]) != 1:
+         raise gr.Error("Please input only one BBox.", duration=5)
+     [[x1, y1, idx_3, x2, y2, idx_6]] = prompts["points"]
+     if idx_3 != 2 or idx_6 != 3:
+         raise gr.Error("Please input BBox instead of point.", duration=5)
+     x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+
+     img = prompts["image"]
+     ori_H, ori_W, _ = img.shape
+
+     scale = 1024 * 1.0 / max(ori_H, ori_W)
+     new_H, new_W = ori_H * scale, ori_W * scale
+     new_W = int(new_W + 0.5)
+     new_H = int(new_H + 0.5)
+
+     img = cv2.resize(img, (new_W, new_H), interpolation=cv2.INTER_LINEAR)
+     padding = np.zeros([1024, 1024, 3], dtype=img.dtype)
+     padding[: new_H, : new_W, :] = img
+     img = padding
+     # img = img[:, :, ::-1].transpose((2, 0, 1)).astype(np.float32) / 255.0
+     img = img.transpose((2, 0, 1)).astype(np.float32) / 255.0
+
+     [[x1, y1, _, x2, y2, _]] = prompts["points"]
+     x1, y1, x2, y2 = int(x1 * scale + 0.5), int(y1 * scale + 0.5), int(x2 * scale + 0.5), int(y2 * scale + 0.5)
+     bbox = np.clip(np.array([[x1, y1, x2, y2]]) * 1.0, 0, 1023.0)
+
+     return img, bbox, (ori_H, ori_W), (new_H, new_W)
+
+ if __name__ == '__main__':
+
+     model, model_type = load_model()
+
+     def inference_image(prompts, input_model_type):
+
+         global model_type
+         global model
+
+         if input_model_type != model_type:
+             gr.Info('Loading SEMat of {} version.'.format(input_model_type), duration=5)
+             _model, _ = load_model(input_model_type)
+             model_type = input_model_type
+             model = _model
+
+         image, bbox, ori_H_W, pad_H_W = transform_image_bbox(prompts)
+         input_data = {
+             'image': torch.from_numpy(image)[None].to(model.device),
+             'bbox': torch.from_numpy(bbox)[None].to(model.device),
+         }
+
+         with torch.no_grad():
+             inputs = model.preprocess_inputs(input_data)
+             images, bbox, gt_alpha, trimap, condition = inputs['images'], inputs['bbox'], inputs['alpha'], inputs['trimap'], inputs['condition']
+
+             if model.backbone_condition:
+                 condition_proj = model.condition_embedding(condition)
+             elif model.backbone_bbox_prompt is not None or model.bbox_prompt_all_block is not None:
+                 condition_proj = bbox
+             else:
+                 condition_proj = None
+
+             low_res_masks, pred_alphas, pred_trimap, sam_hq_matting_token = model.forward_samhq_and_matting_decoder(images, bbox, condition_proj)
+
+         output_alpha = np.uint8(pred_alphas[0, 0][:pad_H_W[0], :pad_H_W[1], None].repeat(1, 1, 3).cpu().numpy() * 255)
+
+         return output_alpha
+
+     with gr.Blocks() as demo:
+
+         with gr.Row():
+             with gr.Column(scale=45):
+                 img_in = ImagePrompter(type='numpy', show_label=False, label="query image")
+
+             with gr.Column(scale=45):
+                 img_out = gr.Image(type='pil', label="output")
+
+         with gr.Row():
+             with gr.Column(scale=45):
+                 input_model_type = gr.Dropdown(list(model_choice.keys()), value='SAM2', label="Trained SEMat Version")
+
+             with gr.Column(scale=45):
+                 bt = gr.Button()
+
+         bt.click(inference_image, inputs=[img_in, input_model_type], outputs=[img_out])
+
+     demo.launch()
+
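Note on the prompt format: the bounding-box check in transform_image_bbox relies on the encoding used by gradio_image_prompter, where a drawn box arrives as [x1, y1, 2, x2, y2, 3]; the 2/3 flags are what distinguish a box from a clicked point, which is why anything else is rejected. Below is a hypothetical example of the prompts dict that inference_image receives (dummy shapes and coordinates, not taken from the commit):

    import numpy as np

    prompts = {
        'image': np.zeros((768, 512, 3), dtype=np.uint8),   # H x W x 3 query image
        'points': [[100.0, 50.0, 2, 400.0, 700.0, 3]],      # exactly one bounding box
    }
    img, bbox, ori_H_W, pad_H_W = transform_image_bbox(prompts)
    # scale = 1024 / max(768, 512), so the image is resized to 1024 x 683,
    # zero-padded to 1024 x 1024, and the box corners are rescaled by the
    # same factor and clipped to [0, 1023].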
configs/SEMat_HQ-SAM.py ADDED
@@ -0,0 +1,48 @@
+ from .common.train import train
+ from .semantic_enhanced_matting.model import model
+ from .common.optimizer import optimizer
+ from .common.scheduler import lr_multiplier
+ from .semantic_enhanced_matting.dataloader import dataloader
+ from modeling.decoder.unet_detail_capture import MattingDetailDecoder
+ from detectron2.config import LazyCall as L
+
+ model.sam_model.model_type = 'vit_l'
+ model.sam_model.checkpoint = None
+ model.vis_period = 200
+ model.output_dir = '?'
+
+ train.max_iter = 60000
+ train.eval_period = int(train.max_iter * 1 / 10)
+ train.checkpointer.period = int(train.max_iter * 1 / 10)
+ train.checkpointer.max_to_keep = 1
+
+ optimizer.lr = 5e-5
+
+ lr_multiplier.scheduler.values = [1.0, 0.5, 0.2]
+ lr_multiplier.scheduler.milestones = [0.5, 0.75]
+ lr_multiplier.scheduler.num_updates = train.max_iter
+ lr_multiplier.warmup_length = 250 / train.max_iter
+
+ train.output_dir = './work_dirs/SEMat_HQ-SAM'
+
+ model.lora_rank = 16
+ model.lora_alpha = 16
+ model.matting_decoder = L(MattingDetailDecoder)(
+     vit_intern_feat_in = 1024,
+     vit_intern_feat_index = [0, 1, 2, 3],
+     norm_type = 'SyncBN',
+     block_num = 2,
+     img_feat_in = 6,
+     norm_mask_logits = 6.5
+ )
+ model.backbone_bbox_prompt = 'bbox'
+ model.backbone_bbox_prompt_loc = [2, 3]
+ model.backbone_bbox_prompt_loss_weight = 1.0
+ model.matting_token = True
+ model.sam_model.matting_token = 3
+ model.sam_model.frozen_decoder = True
+ model.sam_hq_token_reg = 0.2
+ model.reg_w_bce_loss = True
+ model.matting_token_sup = 'trimap'
+ model.matting_token_sup_loss_weight = 0.05
+ model.trimap_loss_type = 'NGHM'
configs/SEMat_SAM.py ADDED
@@ -0,0 +1,51 @@
+ from .common.train import train
+ from .semantic_enhanced_matting.model import model
+ from .common.optimizer import optimizer
+ from .common.scheduler import lr_multiplier
+ from .semantic_enhanced_matting.dataloader import dataloader
+ from modeling.decoder.unet_detail_capture import MattingDetailDecoder
+ from detectron2.config import LazyCall as L
+
+ model.sam_model.model_type = 'vit_l'
+ model.sam_model.checkpoint = None
+ model.vis_period = 200
+ model.output_dir = '?'
+
+ train.max_iter = 60000
+ train.eval_period = int(train.max_iter * 1 / 10)
+ train.checkpointer.period = int(train.max_iter * 1 / 10)
+ train.checkpointer.max_to_keep = 1
+
+ optimizer.lr = 5e-5
+
+ lr_multiplier.scheduler.values = [1.0, 0.5, 0.2]
+ lr_multiplier.scheduler.milestones = [0.5, 0.75]
+ lr_multiplier.scheduler.num_updates = train.max_iter
+ lr_multiplier.warmup_length = 250 / train.max_iter
+
+ train.output_dir = './work_dirs/SEMat_SAM'
+
+ model.lora_rank = 16
+ model.lora_alpha = 16
+ model.matting_decoder = L(MattingDetailDecoder)(
+     vit_intern_feat_in = 1024,
+     vit_intern_feat_index = [0, 1, 2, 3],
+     norm_type = 'SyncBN',
+     block_num = 2,
+     img_feat_in = 6,
+     norm_mask_logits = 6.5
+ )
+ model.backbone_bbox_prompt = 'bbox'
+ model.backbone_bbox_prompt_loc = [2, 3]
+ model.backbone_bbox_prompt_loss_weight = 1.0
+ model.matting_token = True
+ model.sam_model.matting_token = 3
+ model.sam_model.frozen_decoder = True
+ model.sam_hq_token_reg = 0.2
+ model.reg_on_sam_logits = True
+ model.reg_w_bce_loss = True
+ model.matting_token_sup = 'trimap'
+ model.matting_token_sup_loss_weight = 0.05
+ model.trimap_loss_type = 'NGHM'
+ model.sam_model.wo_hq = True
+ model.sam_model.mask_matting_res_add = False
configs/SEMat_SAM2.py ADDED
@@ -0,0 +1,57 @@
+ from .common.train import train
+ from .semantic_enhanced_matting.model import model
+ from .common.optimizer import optimizer
+ from .common.scheduler import lr_multiplier
+ from .semantic_enhanced_matting.dataloader import dataloader
+ from modeling.decoder.unet_detail_capture import MattingDetailDecoder
+ from detectron2.config import LazyCall as L
+ from sam2.build_sam import build_sam2
+
+ model.sam_model.model_type = 'vit_l'
+ model.sam_model.checkpoint = None
+ model.vis_period = 200
+ model.output_dir = '?'
+
+ train.max_iter = 60000
+ train.eval_period = int(train.max_iter * 1 / 10)
+ train.checkpointer.period = int(train.max_iter * 1 / 10)
+ train.checkpointer.max_to_keep = 1
+
+ optimizer.lr = 5e-5
+
+ lr_multiplier.scheduler.values = [1.0, 0.5, 0.2]
+ lr_multiplier.scheduler.milestones = [0.5, 0.75]
+ lr_multiplier.scheduler.num_updates = train.max_iter
+ lr_multiplier.warmup_length = 250 / train.max_iter
+
+ train.output_dir = './work_dirs/SEMat_SAM2'
+
+ model.sam2 = True
+ model.sam_model = L(build_sam2)(
+     config_file = 'sam2_hiera_l.yaml',
+     ckpt_path = None,
+     device = "cuda",
+     bbox_mask_matting_token = True,
+     mode="train",
+     upscaled_embedding_res_add = False
+ )
+ model.lora_rank = 16
+ model.lora_alpha = 16
+ model.matting_decoder = L(MattingDetailDecoder)(
+     vit_intern_feat_in = 1024,
+     vit_intern_feat_index = [0, 1, 2, 3],
+     norm_type = 'SyncBN',
+     block_num = 2,
+     img_feat_in = 6,
+     norm_mask_logits = 6.5,
+     sam2_multi_scale_feates = True
+ )
+ model.backbone_bbox_prompt = 'bbox'
+ model.backbone_bbox_prompt_loc = [2, 3]
+ model.backbone_bbox_prompt_loss_weight = 1.0
+ model.matting_token = True
+ model.sam_hq_token_reg = 0.2
+ model.reg_w_bce_loss = True
+ model.matting_token_sup = 'trimap'
+ model.matting_token_sup_loss_weight = 0.05
+ model.trimap_loss_type = 'NGHM'
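These configs all follow detectron2's LazyConfig workflow, which is also how app.py above consumes them: the Python file is loaded into a config tree and the model node is built lazily. A minimal sketch of that round trip, mirroring load_model() in app.py (the checkpoint filename is a placeholder, not part of the commit):

    from detectron2.config import LazyConfig, instantiate
    from detectron2.checkpoint import DetectionCheckpointer

    cfg = LazyConfig.load('./configs/SEMat_SAM2.py')      # build the config tree
    model = instantiate(cfg.model)                        # construct the SamHqMatte node lazily
    if model.lora_rank is not None:
        model.init_lora()                                 # attach LoRA layers before loading weights
    DetectionCheckpointer(model).load('SEMat_SAM2.pth')   # placeholder checkpoint path
    model.eval()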
configs/common/optimizer.py ADDED
@@ -0,0 +1,26 @@
+ from detectron2 import model_zoo
+ from functools import partial
+
+ def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
+     """
+     Calculate lr decay rate for different ViT blocks.
+     Args:
+         name (string): parameter name.
+         lr_decay_rate (float): base lr decay rate.
+         num_layers (int): number of ViT blocks.
+
+     Returns:
+         lr decay rate for the given parameter.
+     """
+     layer_id = num_layers + 1
+     if name.startswith("backbone"):
+         if ".pos_embed" in name or ".patch_embed" in name:
+             layer_id = 0
+         elif ".blocks." in name and ".residual." not in name:
+             layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+     return lr_decay_rate ** (num_layers + 1 - layer_id)
+
+ # Optimizer
+ optimizer = model_zoo.get_config("common/optim.py").AdamW
+ optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.65)
+ optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
configs/common/scheduler.py ADDED
@@ -0,0 +1,13 @@
+ from detectron2.config import LazyCall as L
+ from detectron2.solver import WarmupParamScheduler
+ from fvcore.common.param_scheduler import MultiStepParamScheduler
+
+ lr_multiplier = L(WarmupParamScheduler)(
+     scheduler=L(MultiStepParamScheduler)(
+         values=[1.0, 0.1, 0.01],
+         milestones=[96778, 103579],
+         num_updates=100,
+     ),
+     warmup_length=250 / 100,
+     warmup_factor=0.001,
+ )
configs/common/train.py ADDED
@@ -0,0 +1,17 @@
+ train = dict(
+     output_dir="./output",
+     init_checkpoint="",
+     max_iter=90000,
+     amp=dict(enabled=False),  # options for Automatic Mixed Precision
+     ddp=dict(  # options for DistributedDataParallel
+         broadcast_buffers=True,
+         find_unused_parameters=False,
+         fp16_compression=True,
+     ),
+     checkpointer=dict(period=5000, max_to_keep=100),  # options for PeriodicCheckpointer
+     eval_period=5000,
+     log_period=20,
+     device="cuda",
+     seed=42
+     # ...
+ )
configs/semantic_enhanced_matting/dataloader.py ADDED
@@ -0,0 +1,62 @@
+ from omegaconf import OmegaConf
+ from torch.utils.data import ConcatDataset
+ from detectron2.config import LazyCall as L
+
+ from data.dim_dataset import build_d2_test_dataloader, AdobeCompositionEvaluator, adobe_composition_collate_fn, RW100Test, AIM500Test, AM2KTest, P3M500Test, RWP636Test, SIMTest
+
+ AIM500_PATH = '/path/to/datasets/AIM-500'
+ RW100_PATH = '/path/to/datasets/RefMatte_RW_100'
+ AM2K_PATH = '/path/to/datasets/AM-2K'
+ P3M500_PATH = '/path/to/datasets/P3M-10k/validation/P3M-500-NP'
+ RWP636_PATH = '/path/to/datasets/RealWorldPortrait-636'
+ SIM_PATH = '/path/to/datasets/SIMD/generated_testset'
+
+ dataloader = OmegaConf.create()
+ test_dataset = L(ConcatDataset)(
+     datasets = [
+         L(AIM500Test)(
+             data_dir = AIM500_PATH,
+             target_size = 1024,
+             multi_fg = True,
+         ),
+         L(RW100Test)(
+             data_dir = RW100_PATH,
+             target_size = 1024,
+             multi_fg = True,
+         ),
+         L(AM2KTest)(
+             data_dir = AM2K_PATH,
+             target_size = 1024,
+             multi_fg = True,
+         ),
+         L(P3M500Test)(
+             data_dir = P3M500_PATH,
+             target_size = 1024,
+             multi_fg = True,
+         ),
+         L(RWP636Test)(
+             data_dir = RWP636_PATH,
+             target_size = 1024,
+             multi_fg = True
+         ),
+         L(SIMTest)(
+             data_dir = SIM_PATH,
+             target_size = 1024,
+             multi_fg = True
+         )
+     ]
+ )
+
+ dataloader.test = L(build_d2_test_dataloader)(
+     dataset = test_dataset,
+     local_batch_size = 1,
+     num_workers = 4,
+     collate_fn = adobe_composition_collate_fn
+ )
+
+ dataloader.evaluator = L(AdobeCompositionEvaluator)(
+     save_eval_results_step = 10,
+     output_dir = None,  # modify in EvalHook (do_test)
+     eval_dataset_type = ['RW100', 'AIM500', 'AM2K', 'P3M500', 'RWP636', 'SIM'],
+     distributed = True,
+ ),
configs/semantic_enhanced_matting/model.py ADDED
@@ -0,0 +1,35 @@
+ from detectron2.config import LazyCall as L
+
+ from modeling import Detail_Capture, MattingCriterion
+ from modeling.meta_arch import SamHqMatte
+ from modeling.semantic_enhanced_matting.build_sam import sam_model_registry_def
+ # from modeling.sam_hq_matting.predictor import SamPredictor
+ from modeling.semantic_enhanced_matting import MaskDecoderMatting
+
+ mask_token_only = False
+
+ model = L(SamHqMatte)(
+
+     # original sam_hq
+     sam_model = L(sam_model_registry_def)(
+         model_type = 'vit_b',
+         checkpoint = None,
+     ),
+     hq_token_only = True,
+     hq_features_type = 'Final',
+     multimask_output = True,
+
+     # loss function
+     criterion=L(MattingCriterion)(
+         losses = ['unknown_l1_loss', 'known_l1_loss', 'loss_pha_laplacian', 'loss_gradient_penalty']
+     ),
+
+     # other params.
+     pixel_mean = [123.675 / 255., 116.280 / 255., 103.530 / 255.],
+     pixel_std = [58.395 / 255., 57.120 / 255., 57.375 / 255.],
+
+     lora_rank = None,
+     lora_alpha = None,
+     w_dora = False,
+     w_rslora = False,
+ )
data/__init__.py ADDED
@@ -0,0 +1 @@
+ from .dim_dataset import *
data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (173 Bytes).
data/__pycache__/dim_dataset.cpython-38.pyc ADDED
Binary file (42.1 kB).
data/__pycache__/evaluate.cpython-38.pyc ADDED
Binary file (3.17 kB).
data/__pycache__/rand_augment.cpython-38.pyc ADDED
Binary file (4.75 kB).
data/coconut_dataset.py ADDED
@@ -0,0 +1,377 @@
+ import os
+ import time
+ import json
+ import torch
+ import numpy as np
+ import cv2
+ from torch.utils.data import Dataset, DistributedSampler, Sampler
+ from torchvision import transforms
+ from detectron2.utils.logger import setup_logger
+ from typing import Optional
+ from operator import itemgetter
+ from collections import defaultdict
+
+ from data.dim_dataset import GenBBox
+
+
+ def random_interp():
+     return np.random.choice([cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4])
+
+
+ class SplitConcatImage(object):
+
+     def __init__(self, concat_num=4, wo_mask_to_mattes=False):
+         self.concat_num = concat_num
+         self.wo_mask_to_mattes = wo_mask_to_mattes
+         if self.wo_mask_to_mattes:
+             assert self.concat_num == 5
+
+     def __call__(self, concat_image):
+         if isinstance(concat_image, list):
+             concat_image, image_path = concat_image[0], concat_image[1]
+         else:
+             image_path = None
+         H, W, _ = concat_image.shape
+
+         concat_num = self.concat_num
+         if image_path is not None:
+             if '06-14' in image_path:
+                 concat_num = 4
+             elif 'ori_mask' in image_path or 'SEMat' in image_path:
+                 concat_num = 3
+             else:
+                 concat_num = 5
+
+         assert W % concat_num == 0
+         W = W // concat_num
+
+         image = concat_image[:H, :W]
+         if self.concat_num != 3:
+             trimap = concat_image[:H, (concat_num - 2) * W: (concat_num - 1) * W]
+             if self.wo_mask_to_mattes:
+                 alpha = concat_image[:H, 2 * W: 3 * W]
+             else:
+                 alpha = concat_image[:H, (concat_num - 1) * W: concat_num * W]
+         else:
+             trimap = concat_image[:H, (concat_num - 1) * W: concat_num * W]
+             alpha = concat_image[:H, (concat_num - 2) * W: (concat_num - 1) * W]
+
+         return {'image': image, 'trimap': trimap, 'alpha': alpha}
+
+
+ class RandomHorizontalFlip(object):
+
+     def __init__(self, prob=0.5):
+         self.prob = prob
+
+     def __call__(self, sample):
+         if np.random.uniform(0, 1) < self.prob:
+             for key in sample.keys():
+                 sample[key] = cv2.flip(sample[key], 1)
+         return sample
+
+ class EmptyAug(object):
+     def __call__(self, sample):
+         return sample
+
+ class RandomReszieCrop(object):
+
+     def __init__(self, output_size=1024, aug_scale_min=0.5, aug_scale_max=1.5):
+         self.desired_size = output_size
+         self.aug_scale_min = aug_scale_min
+         self.aug_scale_max = aug_scale_max
+
+     def __call__(self, sample):
+         H, W, _ = sample['image'].shape
+
+         if self.aug_scale_min == 1.0 and self.aug_scale_max == 1.0:
+             crop_H, crop_W = H, W
+             crop_y1, crop_y2 = 0, crop_H
+             crop_x1, crop_x2 = 0, crop_W
+             scale_W, scaled_H = W, H
+         elif self.aug_scale_min == -1.0 and self.aug_scale_max == -1.0:
+             scale = min(self.desired_size / H, self.desired_size / W)
+             scaled_H, scale_W = round(H * scale), round(W * scale)
+             crop_H, crop_W = scaled_H, scale_W
+             crop_y1, crop_y2 = 0, crop_H
+             crop_x1, crop_x2 = 0, crop_W
+         else:
+             # random size
+             random_scale = np.random.uniform(0, 1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min  # random_val: 0.5 ~ 1.5
+             scaled_size = round(random_scale * self.desired_size)
+
+             scale = min(scaled_size / H, scaled_size / W)
+             scaled_H, scale_W = round(H * scale), round(W * scale)
+
+             # random crop
+             crop_H, crop_W = min(self.desired_size, scaled_H), min(self.desired_size, scale_W)  # crop_size
+             margin_H, margin_W = max(scaled_H - crop_H, 0), max(scale_W - crop_W, 0)
+             offset_H, offset_W = np.random.randint(0, margin_H + 1), np.random.randint(0, margin_W + 1)
+             crop_y1, crop_y2 = offset_H, offset_H + crop_H
+             crop_x1, crop_x2 = offset_W, offset_W + crop_W
+
+         for key in sample.keys():
+             sample[key] = cv2.resize(sample[key], (scale_W, scaled_H), interpolation=random_interp())[crop_y1: crop_y2, crop_x1: crop_x2, :]  # resize and crop
+             padding = np.zeros(shape=(self.desired_size, self.desired_size, 3), dtype=sample[key].dtype)  # pad to desired_size
+             padding[: crop_H, : crop_W, :] = sample[key]
+             sample[key] = padding
+
+         return sample
+
+
+ class RandomJitter(object):
+     """
+     Random change the hue of the image
+     """
+
+     def __call__(self, sample):
+
+         image = sample['image']
+
+         # convert to HSV space, convert to float32 image to keep precision during space conversion.
+         image = cv2.cvtColor(image.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV)
+         # Hue noise
+         hue_jitter = np.random.randint(-40, 40)
+         image[:, :, 0] = np.remainder(image[:, :, 0].astype(np.float32) + hue_jitter, 360)
+         # Saturation noise
+         sat_bar = image[:, :, 1].mean()
+
+         sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10
+         sat = image[:, :, 1]
+         sat = np.abs(sat + sat_jitter)
+         sat[sat>1] = 2 - sat[sat>1]
+         image[:, :, 1] = sat
+         # Value noise
+         val_bar = image[:, :, 2].mean()
+
+         val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10
+         val = image[:, :, 2]
+         val = np.abs(val + val_jitter)
+         val[val>1] = 2 - val[val>1]
+         image[:, :, 2] = val
+         # convert back to BGR space
+         image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
+         sample['image'] = image * 255
+
+         return sample
+
+
+ class ToTensor(object):
+
+     def __call__(self, sample):
+         image, alpha, trimap = sample['image'][:, :, ::-1], sample['alpha'], sample['trimap']
+
+         # image
+         image = image.transpose((2, 0, 1)) / 255.
+         sample['image'] = torch.from_numpy(image).float()
+
+         # alpha
+         alpha = alpha.transpose((2, 0, 1))[0: 1] / 255.
+         alpha[alpha < 0 ] = 0
+         alpha[alpha > 1] = 1
+         sample['alpha'] = torch.from_numpy(alpha).float()
+
+         # trimap
+         trimap = trimap.transpose((2, 0, 1))[0: 1] / 1.
+         sample['trimap'] = torch.from_numpy(trimap).float()
+         sample['trimap'][sample['trimap'] < 85] = 0
+         sample['trimap'][sample['trimap'] >= 170] = 1
+         sample['trimap'][sample['trimap'] >= 85] = 0.5
+
+         return sample
+
+
+ class COCONutData(Dataset):
+     def __init__(
+         self,
+         json_path,
+         data_root_path,
+         output_size = 512,
+         aug_scale_min = 0.5,
+         aug_scale_max = 1.5,
+         with_bbox = False,
+         bbox_offset_factor = None,
+         phase = "train",
+         min_miou = 95,
+         miou_json = '',
+         remove_coco_transparent = False,
+         coconut_num_ratio = None,
+         return_multi_fg_info = False,
+         wo_accessory_fusion = False,
+         wo_mask_to_mattes = False,
+         return_image_name = False,
+     ):
+
+         self.data_root_path = data_root_path
+         self.output_size = output_size
+         self.aug_scale_min = aug_scale_min
+         self.aug_scale_max = aug_scale_max
+         self.with_bbox = with_bbox
+         self.bbox_offset_factor = bbox_offset_factor
+         self.phase = phase
+         self.min_miou = min_miou
+         self.miou_json = miou_json
+         self.remove_coco_transparent = remove_coco_transparent
+         self.coconut_num_ratio = coconut_num_ratio
+         self.return_multi_fg_info = return_multi_fg_info
+         self.wo_accessory_fusion = wo_accessory_fusion  # TODO
+         self.wo_mask_to_mattes = wo_mask_to_mattes
+         self.return_image_name = return_image_name
+         assert self.wo_accessory_fusion + self.wo_mask_to_mattes <= 1
+         assert self.phase == 'train'
+
+         self.data_path = []
+         with open(json_path, "r") as file:
+             coconut_matting_info = json.load(file)
+
+         if self.miou_json != '':
+             name_2_miou_dict = defaultdict(int)
+             with open(self.miou_json, "r") as file:
+                 coconut_matting_miou = json.load(file)
+             for miou, name in coconut_matting_miou:
+                 name_2_miou_dict[name] = miou
+             for i in coconut_matting_info:
+                 if 'accessory' in i['save_path']:
+                     self.data_path.append(i['save_path'])
+                 elif name_2_miou_dict[i['save_path'].split('/')[-1]] >= self.min_miou:
+                     if not (self.remove_coco_transparent and 'glass' in i['save_path']):
+                         self.data_path.append(i['save_path'])
+         else:
+             for i in coconut_matting_info:
+                 self.data_path.append(i['save_path'])
+
+         if 'accessory' in json_path:
+             concat_num = 5
+         elif 'ori_mask' in json_path:
+             concat_num = 3
+         else:
+             concat_num = 4
+
+         train_trans = [
+             SplitConcatImage(concat_num, wo_mask_to_mattes = self.wo_mask_to_mattes),
+             RandomHorizontalFlip(prob=0 if hasattr(self, 'return_image_name') and self.return_image_name else 0.5),
+             RandomReszieCrop(self.output_size, self.aug_scale_min, self.aug_scale_max),
+             EmptyAug() if hasattr(self, 'return_image_name') and self.return_image_name else RandomJitter(),
+             ToTensor(),
+             GenBBox(bbox_offset_factor=self.bbox_offset_factor)
+         ]
+         self.transform = transforms.Compose(train_trans)
+         print('coconut num: ', len(self.data_path) * self.coconut_num_ratio if self.coconut_num_ratio is not None else len(self.data_path))
+
+     def __getitem__(self, idx):
+         if self.coconut_num_ratio is not None:
+             if self.coconut_num_ratio < 1.0 or idx >= len(self.data_path):
+                 idx = np.random.randint(0, len(self.data_path))
+         concat_image = cv2.imread(os.path.join(self.data_root_path, self.data_path[idx]))
+         sample = self.transform([concat_image, self.data_path[idx]])
+         sample['dataset_name'] = 'COCONut'
+         if self.return_multi_fg_info:
+             sample['multi_fg'] = False
+         if hasattr(self, 'return_image_name') and self.return_image_name:
+             sample['image_name'] = self.data_path[idx]
+         return sample
+
+     def __len__(self):
+         if self.coconut_num_ratio is not None:
+             return int(len(self.data_path) * self.coconut_num_ratio)
+         else:
+             return len(self.data_path)
+
+
+ class DatasetFromSampler(Dataset):
+     """Dataset to create indexes from `Sampler`.
+
+     Args:
+         sampler: PyTorch sampler
+     """
+
+     def __init__(self, sampler: Sampler):
+         """Initialisation for DatasetFromSampler."""
+         self.sampler = sampler
+         self.sampler_list = None
+
+     def __getitem__(self, index: int):
+         """Gets element of the dataset.
+
+         Args:
+             index: index of the element in the dataset
+
+         Returns:
+             Single element by index
+         """
+         if self.sampler_list is None:
+             self.sampler_list = list(self.sampler)
+         return self.sampler_list[index]
+
+     def __len__(self) -> int:
+         """
+         Returns:
+             int: length of the dataset
+         """
+         return len(self.sampler)
+
+
+ class DistributedSamplerWrapper(DistributedSampler):
+     """
+     Wrapper over `Sampler` for distributed training.
+     Allows you to use any sampler in distributed mode.
+     It is especially useful in conjunction with
+     `torch.nn.parallel.DistributedDataParallel`. In such case, each
+     process can pass a DistributedSamplerWrapper instance as a DataLoader
+     sampler, and load a subset of subsampled data of the original dataset
+     that is exclusive to it.
+     .. note::
+         Sampler is assumed to be of constant size.
+     """
+
+     def __init__(
+         self,
+         sampler,
+         num_replicas: Optional[int] = None,
+         rank: Optional[int] = None,
+         shuffle: bool = True,
+     ):
+         """
+         Args:
+             sampler: Sampler used for subsampling
+             num_replicas (int, optional): Number of processes participating in
+                 distributed training
+             rank (int, optional): Rank of the current process
+                 within ``num_replicas``
+             shuffle (bool, optional): If true (default),
+                 sampler will shuffle the indices
+         """
+         super(DistributedSamplerWrapper, self).__init__(
+             DatasetFromSampler(sampler),
+             num_replicas=num_replicas,
+             rank=rank,
+             shuffle=shuffle,
+         )
+         self.sampler = sampler
+
+     def __iter__(self):
+         """@TODO: Docs. Contribution is welcome."""
+         self.dataset = DatasetFromSampler(self.sampler)
+         indexes_of_indexes = super().__iter__()
+         subsampler_indexes = self.dataset
+         return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
+
+
+ if __name__ == '__main__':
+
+     dataset = COCONutData(
+         json_path = '/root/data/my_path/Matting/DiffMatte-main/24-06-14_coco-nut_matting.json',
+         data_root_path = '/root/data/my_path/Matting/DiffMatte-main',
+         output_size = 1024,
+         aug_scale_min = 0.5,
+         aug_scale_max = 1.5,
+         with_bbox = True,
+         bbox_offset_factor = 0.1,
+         phase = "train"
+     )
+     data = dataset[0]
+
+     for key, val in data.items():
+         print(key, val.shape, torch.min(val), torch.max(val))
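SplitConcatImage assumes each COCONut sample is stored as one horizontal strip of concat_num equally wide panels, with the RGB image first, the trimap second-to-last, and the alpha matte last (the 3-panel layout swaps the last two, and the wo_mask_to_mattes variant reads the alpha from the third of five panels). A small synthetic sanity check of that layout (dummy data, not taken from the dataset):

    import numpy as np
    from data.coconut_dataset import SplitConcatImage

    # Four 64x64 panels filled with distinct values: image | extra | trimap | alpha.
    panels = [np.full((64, 64, 3), v, dtype=np.uint8) for v in (10, 20, 30, 40)]
    strip = np.concatenate(panels, axis=1)            # shape (64, 256, 3)

    sample = SplitConcatImage(concat_num=4)(strip)
    assert sample['image'][0, 0, 0] == 10             # first panel
    assert sample['trimap'][0, 0, 0] == 30            # panel concat_num - 1
    assert sample['alpha'][0, 0, 0] == 40             # last panel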
data/dim_dataset.py ADDED
@@ -0,0 +1,1476 @@
1
+ '''
2
+ Dataloader to process Adobe Image Matting Dataset.
3
+
4
+ From GCA_Matting(https://github.com/Yaoyi-Li/GCA-Matting/tree/master/dataloader)
5
+ '''
6
+ import os
7
+ import glob
8
+ import logging
9
+ import os.path as osp
10
+ import functools
11
+ import numpy as np
12
+ import torch
13
+ import cv2
14
+ import math
15
+ import numbers
16
+ import random
17
+ import pickle
18
+ from torch.utils.data import Dataset, DataLoader
19
+ from torch.nn import functional as F
20
+ from torchvision import transforms
21
+ from easydict import EasyDict
22
+ from detectron2.utils.logger import setup_logger
23
+ from detectron2.utils import comm
24
+ from detectron2.data import build_detection_test_loader
25
+ import torchvision.transforms.functional
26
+
27
+ import json
28
+ from PIL import Image
29
+ from detectron2.evaluation.evaluator import DatasetEvaluator
30
+ from collections import defaultdict
31
+
32
+ from data.evaluate import compute_sad_loss, compute_mse_loss, compute_mad_loss, compute_gradient_loss, compute_connectivity_error
33
+
34
+ # Base default config
35
+ CONFIG = EasyDict({})
36
+
37
+ # Model config
38
+ CONFIG.model = EasyDict({})
39
+ # one-hot or class, choice: [3, 1]
40
+ CONFIG.model.trimap_channel = 1
41
+
42
+ # Dataloader config
43
+ CONFIG.data = EasyDict({})
44
+ # feed forward image size (untested)
45
+ CONFIG.data.crop_size = 512
46
+ # composition of two foregrounds, affine transform, crop and HSV jitter
47
+ CONFIG.data.cutmask_prob = 0.25
48
+ CONFIG.data.augmentation = True
49
+ CONFIG.data.random_interp = True
50
+
51
+ class Prefetcher():
52
+ """
53
+ Modified from the data_prefetcher in https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
54
+ """
55
+ def __init__(self, loader):
56
+ self.orig_loader = loader
57
+ self.stream = torch.cuda.Stream()
58
+ self.next_sample = None
59
+
60
+ def preload(self):
61
+ try:
62
+ self.next_sample = next(self.loader)
63
+ except StopIteration:
64
+ self.next_sample = None
65
+ return
66
+
67
+ with torch.cuda.stream(self.stream):
68
+ for key, value in self.next_sample.items():
69
+ if isinstance(value, torch.Tensor):
70
+ self.next_sample[key] = value.cuda(non_blocking=True)
71
+
72
+ def __next__(self):
73
+ torch.cuda.current_stream().wait_stream(self.stream)
74
+ sample = self.next_sample
75
+ if sample is not None:
76
+ for key, value in sample.items():
77
+ if isinstance(value, torch.Tensor):
78
+ sample[key].record_stream(torch.cuda.current_stream())
79
+ self.preload()
80
+ else:
81
+ # throw stop exception if there is no more data to perform as a default dataloader
82
+ raise StopIteration("No samples in loader. example: `iterator = iter(Prefetcher(loader)); "
83
+ "data = next(iterator)`")
84
+ return sample
85
+
86
+ def __iter__(self):
87
+ self.loader = iter(self.orig_loader)
88
+ self.preload()
89
+ return self
90
+
91
+
92
+ class ImageFile(object):
93
+ def __init__(self, phase='train'):
94
+ self.phase = phase
95
+ self.rng = np.random.RandomState(0)
96
+
97
+ def _get_valid_names(self, *dirs, shuffle=True):
98
+ name_sets = [self._get_name_set(d) for d in dirs]
99
+
100
+ def _join_and(a, b):
101
+ return a & b
102
+
103
+ valid_names = list(functools.reduce(_join_and, name_sets))
104
+ if shuffle:
105
+ self.rng.shuffle(valid_names)
106
+
107
+ return valid_names
108
+
109
+ @staticmethod
110
+ def _get_name_set(dir_name):
111
+ path_list = glob.glob(os.path.join(dir_name, '*'))
112
+ name_set = set()
113
+ for path in path_list:
114
+ name = os.path.basename(path)
115
+ name = os.path.splitext(name)[0]
116
+ name_set.add(name)
117
+ return name_set
118
+
119
+ @staticmethod
120
+ def _list_abspath(data_dir, ext, data_list):
121
+ return [os.path.join(data_dir, name + ext)
122
+ for name in data_list]
123
+
124
+ class ImageFileTrain(ImageFile):
125
+ def __init__(
126
+ self,
127
+ alpha_dir="train_alpha",
128
+ fg_dir="train_fg",
129
+ bg_dir="train_bg",
130
+ alpha_ext=".jpg",
131
+ fg_ext=".jpg",
132
+ bg_ext=".jpg",
133
+ fg_have_bg_num=None,
134
+ alpha_ratio_json = None,
135
+ alpha_min_ratio = None,
136
+ key_sample_ratio = None,
137
+ ):
138
+ super(ImageFileTrain, self).__init__(phase="train")
139
+
140
+ self.alpha_dir = alpha_dir
141
+ self.fg_dir = fg_dir
142
+ self.bg_dir = bg_dir
143
+ self.alpha_ext = alpha_ext
144
+ self.fg_ext = fg_ext
145
+ self.bg_ext = bg_ext
146
+ logger = setup_logger(name=__name__)
147
+
148
+ if not isinstance(self.alpha_dir, str):
149
+ assert len(self.alpha_dir) == len(self.fg_dir) == len(alpha_ext) == len(fg_ext)
150
+ self.valid_fg_list = []
151
+ self.alpha = []
152
+ self.fg = []
153
+ self.key_alpha = []
154
+ self.key_fg = []
155
+ for i in range(len(self.alpha_dir)):
156
+ valid_fg_list = self._get_valid_names(self.fg_dir[i], self.alpha_dir[i])
157
+ valid_fg_list.sort()
158
+ alpha = self._list_abspath(self.alpha_dir[i], self.alpha_ext[i], valid_fg_list)
159
+ fg = self._list_abspath(self.fg_dir[i], self.fg_ext[i], valid_fg_list)
160
+ self.valid_fg_list += valid_fg_list
161
+
162
+ self.alpha += alpha * fg_have_bg_num[i]
163
+ self.fg += fg * fg_have_bg_num[i]
164
+
165
+ if alpha_ratio_json[i] is not None:
166
+ tmp_key_alpha = []
167
+ tmp_key_fg = []
168
+ name_to_alpha_path = dict()
169
+ for name in alpha:
170
+ name_to_alpha_path[name.split('/')[-1].split('.')[0]] = name
171
+ name_to_fg_path = dict()
172
+ for name in fg:
173
+ name_to_fg_path[name.split('/')[-1].split('.')[0]] = name
174
+
175
+ with open(alpha_ratio_json[i], 'r') as file:
176
+ alpha_ratio_list = json.load(file)
177
+ for ratio, name in alpha_ratio_list:
178
+ if ratio < alpha_min_ratio[i]:
179
+ break
180
+ tmp_key_alpha.append(name_to_alpha_path[name.split('.')[0]])
181
+ tmp_key_fg.append(name_to_fg_path[name.split('.')[0]])
182
+
183
+ self.key_alpha.extend(tmp_key_alpha * fg_have_bg_num[i])
184
+ self.key_fg.extend(tmp_key_fg * fg_have_bg_num[i])
185
+
186
+ if len(self.key_alpha) != 0 and key_sample_ratio > 0:
187
+ repeat_num = key_sample_ratio * (len(self.alpha) - len(self.key_alpha)) / len(self.key_alpha) / (1 - key_sample_ratio) - 1
188
+ print('key sample num:', len(self.key_alpha), ', repeat num: ', repeat_num)
189
+ for i in range(math.ceil(repeat_num)):
190
+ self.alpha += self.key_alpha
191
+ self.fg += self.key_fg
192
+
193
+ else:
194
+ self.valid_fg_list = self._get_valid_names(self.fg_dir, self.alpha_dir)
195
+ self.valid_fg_list.sort()
196
+ self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_fg_list)
197
+ self.fg = self._list_abspath(self.fg_dir, self.fg_ext, self.valid_fg_list)
198
+
199
+ self.valid_bg_list = [os.path.splitext(name)[0] for name in os.listdir(self.bg_dir)]
200
+ self.valid_bg_list.sort()
201
+
202
+ if fg_have_bg_num is not None:
203
+ # assert fg_have_bg_num * len(self.valid_fg_list) <= len(self.valid_bg_list)
204
+ # self.valid_bg_list = self.valid_bg_list[: fg_have_bg_num * len(self.valid_fg_list)]
205
+ assert len(self.alpha) <= len(self.valid_bg_list)
206
+ self.valid_bg_list = self.valid_bg_list[: len(self.alpha)]
207
+
208
+ self.bg = self._list_abspath(self.bg_dir, self.bg_ext, self.valid_bg_list)
209
+
210
+ def __len__(self):
211
+ return len(self.alpha)
212
+
213
+ class ImageFileTest(ImageFile):
214
+ def __init__(self,
215
+ alpha_dir="test_alpha",
216
+ merged_dir="test_merged",
217
+ trimap_dir="test_trimap",
218
+ alpha_ext=".png",
219
+ merged_ext=".png",
220
+ trimap_ext=".png"):
221
+ super(ImageFileTest, self).__init__(phase="test")
222
+
223
+ self.alpha_dir = alpha_dir
224
+ self.merged_dir = merged_dir
225
+ self.trimap_dir = trimap_dir
226
+ self.alpha_ext = alpha_ext
227
+ self.merged_ext = merged_ext
228
+ self.trimap_ext = trimap_ext
229
+
230
+ self.valid_image_list = self._get_valid_names(self.alpha_dir, self.merged_dir, self.trimap_dir, shuffle=False)
231
+
232
+ self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_image_list)
233
+ self.merged = self._list_abspath(self.merged_dir, self.merged_ext, self.valid_image_list)
234
+ self.trimap = self._list_abspath(self.trimap_dir, self.trimap_ext, self.valid_image_list)
235
+
236
+ def __len__(self):
237
+ return len(self.alpha)
238
+
239
+ interp_list = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]
240
+
241
+
242
+ def maybe_random_interp(cv2_interp):
243
+ if CONFIG.data.random_interp:
244
+ return np.random.choice(interp_list)
245
+ else:
246
+ return cv2_interp
247
+
248
+
249
+ class ToTensor(object):
250
+ """
251
+ Convert ndarrays in sample to Tensors with normalization.
252
+ """
253
+ def __init__(self, phase="test"):
254
+ self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
255
+ self.std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
256
+ self.phase = phase
257
+
258
+ def __call__(self, sample):
259
+ image, alpha, trimap, mask = sample['image'][:,:,::-1], sample['alpha'], sample['trimap'], sample['mask']
260
+
261
+ alpha[alpha < 0 ] = 0
262
+ alpha[alpha > 1] = 1
263
+
264
+ image = image.transpose((2, 0, 1)).astype(np.float32)
265
+ alpha = np.expand_dims(alpha.astype(np.float32), axis=0)
266
+
267
+ mask = np.expand_dims(mask.astype(np.float32), axis=0)
268
+
269
+ image /= 255.
270
+
271
+ if self.phase == "train":
272
+ fg = sample['fg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
273
+ sample['fg'] = torch.from_numpy(fg)
274
+ bg = sample['bg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
275
+ sample['bg'] = torch.from_numpy(bg)
276
+
277
+ sample['image'], sample['alpha'], sample['trimap'] = \
278
+ torch.from_numpy(image), torch.from_numpy(alpha), torch.from_numpy(trimap).to(torch.long)
279
+ sample['image'] = sample['image']
280
+
281
+ if CONFIG.model.trimap_channel == 3:
282
+ sample['trimap'] = F.one_hot(sample['trimap'], num_classes=3).permute(2,0,1).float()
283
+ elif CONFIG.model.trimap_channel == 1:
284
+ sample['trimap'] = sample['trimap'][None,...].float()
285
+ else:
286
+ raise NotImplementedError("CONFIG.model.trimap_channel can only be 3 or 1")
287
+ sample['trimap'][sample['trimap'] < 85] = 0
288
+ sample['trimap'][sample['trimap'] >= 170] = 1
289
+ sample['trimap'][sample['trimap'] >= 85] = 0.5
290
+
291
+ sample['mask'] = torch.from_numpy(mask).float()
292
+
293
+ return sample
294
+
295
+
296
+ class RandomAffine(object):
297
+ """
298
+ Random affine translation
299
+ """
300
+ def __init__(self, degrees, translate=None, scale=None, shear=None, flip=None, resample=False, fillcolor=0):
301
+ if isinstance(degrees, numbers.Number):
302
+ if degrees < 0:
303
+ raise ValueError("If degrees is a single number, it must be positive.")
304
+ self.degrees = (-degrees, degrees)
305
+ else:
306
+ assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
307
+ "degrees should be a list or tuple and it must be of length 2."
308
+ self.degrees = degrees
309
+
310
+ if translate is not None:
311
+ assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
312
+ "translate should be a list or tuple and it must be of length 2."
313
+ for t in translate:
314
+ if not (0.0 <= t <= 1.0):
315
+ raise ValueError("translation values should be between 0 and 1")
316
+ self.translate = translate
317
+
318
+ if scale is not None:
319
+ assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
320
+ "scale should be a list or tuple and it must be of length 2."
321
+ for s in scale:
322
+ if s <= 0:
323
+ raise ValueError("scale values should be positive")
324
+ self.scale = scale
325
+
326
+ if shear is not None:
327
+ if isinstance(shear, numbers.Number):
328
+ if shear < 0:
329
+ raise ValueError("If shear is a single number, it must be positive.")
330
+ self.shear = (-shear, shear)
331
+ else:
332
+ assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
333
+ "shear should be a list or tuple and it must be of length 2."
334
+ self.shear = shear
335
+ else:
336
+ self.shear = shear
337
+
338
+ self.resample = resample
339
+ self.fillcolor = fillcolor
340
+ self.flip = flip
341
+
342
+ @staticmethod
343
+ def get_params(degrees, translate, scale_ranges, shears, flip, img_size):
344
+ """Get parameters for affine transformation
345
+
346
+ Returns:
347
+ sequence: params to be passed to the affine transformation
348
+ """
349
+ angle = random.uniform(degrees[0], degrees[1])
350
+ if translate is not None:
351
+ max_dx = translate[0] * img_size[0]
352
+ max_dy = translate[1] * img_size[1]
353
+ translations = (np.round(random.uniform(-max_dx, max_dx)),
354
+ np.round(random.uniform(-max_dy, max_dy)))
355
+ else:
356
+ translations = (0, 0)
357
+
358
+ if scale_ranges is not None:
359
+ scale = (random.uniform(scale_ranges[0], scale_ranges[1]),
360
+ random.uniform(scale_ranges[0], scale_ranges[1]))
361
+ else:
362
+ scale = (1.0, 1.0)
363
+
364
+ if shears is not None:
365
+ shear = random.uniform(shears[0], shears[1])
366
+ else:
367
+ shear = 0.0
368
+
369
+ if flip is not None:
370
+ flip = (np.random.rand(2) < flip).astype(np.int32) * 2 - 1
371
+
372
+ return angle, translations, scale, shear, flip
373
+
374
+ def __call__(self, sample):
375
+ fg, alpha = sample['fg'], sample['alpha']
376
+ rows, cols, ch = fg.shape
377
+ if np.maximum(rows, cols) < 1024:
378
+ params = self.get_params((0, 0), self.translate, self.scale, self.shear, self.flip, fg.size)
379
+ else:
380
+ params = self.get_params(self.degrees, self.translate, self.scale, self.shear, self.flip, fg.size)
381
+
382
+ center = (cols * 0.5 + 0.5, rows * 0.5 + 0.5)
383
+ M = self._get_inverse_affine_matrix(center, *params)
384
+ M = np.array(M).reshape((2, 3))
385
+
386
+ fg = cv2.warpAffine(fg, M, (cols, rows),
387
+ flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP)
388
+ alpha = cv2.warpAffine(alpha, M, (cols, rows),
389
+ flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP)
390
+
391
+ sample['fg'], sample['alpha'] = fg, alpha
392
+
393
+ return sample
394
+
395
+
396
+ @ staticmethod
397
+ def _get_inverse_affine_matrix(center, angle, translate, scale, shear, flip):
398
+
399
+ angle = math.radians(angle)
400
+ shear = math.radians(shear)
401
+ scale_x = 1.0 / scale[0] * flip[0]
402
+ scale_y = 1.0 / scale[1] * flip[1]
403
+
404
+ # Inverted rotation matrix with scale and shear
405
+ d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
406
+ matrix = [
407
+ math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0,
408
+ -math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0
409
+ ]
410
+ matrix = [m / d for m in matrix]
411
+
412
+ # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
413
+ matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
414
+ matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])
415
+
416
+ # Apply center translation: C * RSS^-1 * C^-1 * T^-1
417
+ matrix[2] += center[0]
418
+ matrix[5] += center[1]
419
+
420
+ return matrix
421
+
422
+
423
+ class RandomJitter(object):
424
+ """
425
+ Random change the hue of the image
426
+ """
427
+
428
+ def __call__(self, sample):
429
+ sample_ori = sample.copy()
430
+ fg, alpha = sample['fg'], sample['alpha']
431
+ # if alpha is all 0 skip
432
+ if np.all(alpha==0):
433
+ return sample_ori
434
+ # convert to HSV space, convert to float32 image to keep precision during space conversion.
435
+ fg = cv2.cvtColor(fg.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV)
436
+ # Hue noise
437
+ hue_jitter = np.random.randint(-40, 40)
438
+ fg[:, :, 0] = np.remainder(fg[:, :, 0].astype(np.float32) + hue_jitter, 360)
439
+ # Saturation noise
440
+ sat_bar = fg[:, :, 1][alpha > 0].mean()
441
+ if np.isnan(sat_bar):
442
+ return sample_ori
443
+ sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10
444
+ sat = fg[:, :, 1]
445
+ sat = np.abs(sat + sat_jitter)
446
+ sat[sat>1] = 2 - sat[sat>1]
447
+ fg[:, :, 1] = sat
448
+ # Value noise
449
+ val_bar = fg[:, :, 2][alpha > 0].mean()
450
+ if np.isnan(val_bar):
451
+ return sample_ori
452
+ val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10
453
+ val = fg[:, :, 2]
454
+ val = np.abs(val + val_jitter)
455
+ val[val>1] = 2 - val[val>1]
456
+ fg[:, :, 2] = val
457
+ # convert back to BGR space
458
+ fg = cv2.cvtColor(fg, cv2.COLOR_HSV2BGR)
459
+ sample['fg'] = fg*255
460
+
461
+ return sample
462
+
463
+
464
+ class RandomHorizontalFlip(object):
465
+ """
466
+ Random flip image and label horizontally
467
+ """
468
+ def __init__(self, prob=0.5):
469
+ self.prob = prob
470
+ def __call__(self, sample):
471
+ fg, alpha = sample['fg'], sample['alpha']
472
+ if np.random.uniform(0, 1) < self.prob:
473
+ fg = cv2.flip(fg, 1)
474
+ alpha = cv2.flip(alpha, 1)
475
+ sample['fg'], sample['alpha'] = fg, alpha
476
+
477
+ return sample
478
+
479
+
480
+ class RandomCrop(object):
481
+ """
482
+ Crop randomly the image in a sample, retain the center 1/4 images, and resize to 'output_size'
483
+
484
+ :param output_size (tuple or int): Desired output size. If int, square crop
485
+ is made.
486
+ """
487
+
488
+ def __init__(self, output_size=( CONFIG.data.crop_size, CONFIG.data.crop_size)):
489
+ assert isinstance(output_size, (int, tuple))
490
+ if isinstance(output_size, int):
491
+ self.output_size = (output_size, output_size)
492
+ else:
493
+ assert len(output_size) == 2
494
+ self.output_size = output_size
495
+ self.margin = output_size[0] // 2
496
+ self.logger = logging.getLogger("Logger")
497
+
498
+ def __call__(self, sample):
499
+ fg, alpha, trimap, mask, name = sample['fg'], sample['alpha'], sample['trimap'], sample['mask'], sample['image_name']
500
+ bg = sample['bg']
501
+ h, w = trimap.shape
502
+ bg = cv2.resize(bg, (w, h), interpolation=maybe_random_interp(cv2.INTER_CUBIC))
503
+ if w < self.output_size[0]+1 or h < self.output_size[1]+1:
504
+ ratio = 1.1*self.output_size[0]/h if h < w else 1.1*self.output_size[1]/w
505
+ # self.logger.warning("Size of {} is {}.".format(name, (h, w)))
506
+ while h < self.output_size[0]+1 or w < self.output_size[1]+1:
507
+ fg = cv2.resize(fg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
508
+ alpha = cv2.resize(alpha, (int(w*ratio), int(h*ratio)),
509
+ interpolation=maybe_random_interp(cv2.INTER_NEAREST))
510
+ trimap = cv2.resize(trimap, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST)
511
+ bg = cv2.resize(bg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_CUBIC))
512
+ mask = cv2.resize(mask, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST)
513
+ h, w = trimap.shape
514
+ small_trimap = cv2.resize(trimap, (w//4, h//4), interpolation=cv2.INTER_NEAREST)
515
+ unknown_list = list(zip(*np.where(small_trimap[self.margin//4:(h-self.margin)//4,
516
+ self.margin//4:(w-self.margin)//4] == 128)))
517
+ unknown_num = len(unknown_list)
518
+ if len(unknown_list) < 10:
519
+ left_top = (np.random.randint(0, h-self.output_size[0]+1), np.random.randint(0, w-self.output_size[1]+1))
520
+ else:
521
+ idx = np.random.randint(unknown_num)
522
+ left_top = (unknown_list[idx][0]*4, unknown_list[idx][1]*4)
523
+
524
+ fg_crop = fg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
525
+ alpha_crop = alpha[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
526
+ bg_crop = bg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
527
+ trimap_crop = trimap[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
528
+ mask_crop = mask[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
529
+
530
+ if len(np.where(trimap==128)[0]) == 0:
531
+ self.logger.error("{} does not have enough unknown area for crop. Resized to target size."
532
+ "left_top: {}".format(name, left_top))
533
+ fg_crop = cv2.resize(fg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST))
534
+ alpha_crop = cv2.resize(alpha, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST))
535
+ trimap_crop = cv2.resize(trimap, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
536
+ bg_crop = cv2.resize(bg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_CUBIC))
537
+ mask_crop = cv2.resize(mask, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
538
+
539
+ sample.update({'fg': fg_crop, 'alpha': alpha_crop, 'trimap': trimap_crop, 'mask': mask_crop, 'bg': bg_crop})
540
+ return sample
541
+
542
+
543
+ class OriginScale(object):
544
+ def __call__(self, sample):
545
+ h, w = sample["alpha_shape"]
546
+
547
+ if h % 32 == 0 and w % 32 == 0:
548
+ return sample
549
+
550
+ target_h = 32 * ((h - 1) // 32 + 1)
551
+ target_w = 32 * ((w - 1) // 32 + 1)
552
+ pad_h = target_h - h
553
+ pad_w = target_w - w
554
+
555
+ padded_image = np.pad(sample['image'], ((0,pad_h), (0, pad_w), (0,0)), mode="reflect")
556
+ padded_trimap = np.pad(sample['trimap'], ((0,pad_h), (0, pad_w)), mode="reflect")
557
+ padded_mask = np.pad(sample['mask'], ((0,pad_h), (0, pad_w)), mode="reflect")
558
+
559
+ sample['image'] = padded_image
560
+ sample['trimap'] = padded_trimap
561
+ sample['mask'] = padded_mask
562
+
563
+ return sample
564
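OriginScale only touches validation samples whose sides are not already multiples of 32: it reflect-pads the image, trimap and mask on the bottom/right so that the network stride divides both dimensions. A minimal sketch of the same rounding, using a hypothetical 700x1000 input:

```python
import numpy as np

def pad_to_multiple(arr, multiple=32):
    # Round each spatial size up to the next multiple (e.g. 700 -> 704, 1000 -> 1024).
    h, w = arr.shape[:2]
    target_h = multiple * ((h - 1) // multiple + 1)
    target_w = multiple * ((w - 1) // multiple + 1)
    pad = ((0, target_h - h), (0, target_w - w)) + ((0, 0),) * (arr.ndim - 2)
    return np.pad(arr, pad, mode="reflect")

print(pad_to_multiple(np.zeros((700, 1000, 3))).shape)  # (704, 1024, 3)
```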
+
565
+
566
+ class GenMask(object):
567
+ def __init__(self):
568
+ self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,30)]
569
+
570
+ def __call__(self, sample):
571
+ alpha_ori = sample['alpha']
572
+ h, w = alpha_ori.shape
573
+
574
+ max_kernel_size = 30
575
+ alpha = cv2.resize(alpha_ori, (640,640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
576
+
577
+ ### generate trimap
578
+ fg_mask = (alpha + 1e-5).astype(np.int32).astype(np.uint8)
579
+ bg_mask = (1 - alpha + 1e-5).astype(np.int32).astype(np.uint8)
580
+ fg_mask = cv2.erode(fg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
581
+ bg_mask = cv2.erode(bg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
582
+
583
+ fg_width = np.random.randint(1, 30)
584
+ bg_width = np.random.randint(1, 30)
585
+ fg_mask = (alpha + 1e-5).astype(np.int32).astype(np.uint8)
586
+ bg_mask = (1 - alpha + 1e-5).astype(np.int32).astype(np.uint8)
587
+ fg_mask = cv2.erode(fg_mask, self.erosion_kernels[fg_width])
588
+ bg_mask = cv2.erode(bg_mask, self.erosion_kernels[bg_width])
589
+
590
+ trimap = np.ones_like(alpha) * 128
591
+ trimap[fg_mask == 1] = 255
592
+ trimap[bg_mask == 1] = 0
593
+
594
+ trimap = cv2.resize(trimap, (w,h), interpolation=cv2.INTER_NEAREST)
595
+ sample['trimap'] = trimap
596
+
597
+ ### generate mask
598
+ low = 0.01
599
+ high = 1.0
600
+ thres = random.random() * (high - low) + low
601
+ seg_mask = (alpha >= thres).astype(np.int32).astype(np.uint8)
602
+ random_num = random.randint(0,3)
603
+ if random_num == 0:
604
+ seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
605
+ elif random_num == 1:
606
+ seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
607
+ elif random_num == 2:
608
+ seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
609
+ seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
610
+ elif random_num == 3:
611
+ seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
612
+ seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
613
+
614
+ seg_mask = cv2.resize(seg_mask, (w,h), interpolation=cv2.INTER_NEAREST)
615
+ sample['mask'] = seg_mask
616
+
617
+ return sample
618
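GenMask synthesises supervision from the alpha alone: it erodes the definite-foreground and definite-background regions with a randomly sized elliptical kernel to carve out an unknown band (the trimap), and separately thresholds and erodes/dilates the alpha to fake a coarse segmentation mask. A condensed sketch of the trimap half, with a fixed kernel size instead of the random one used above:

```python
import cv2
import numpy as np

def quick_trimap(alpha, kernel_size=15):
    # alpha in [0, 1]; erode the definite-fg and definite-bg masks and mark
    # the eroded-away band around the boundary as unknown (128).
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    fg = cv2.erode((alpha >= 1.0).astype(np.uint8), kernel)
    bg = cv2.erode((alpha <= 0.0).astype(np.uint8), kernel)
    trimap = np.full_like(alpha, 128, dtype=np.float32)
    trimap[fg == 1] = 255
    trimap[bg == 1] = 0
    return trimap
```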
+
619
+
620
+ class Composite(object):
621
+ def __call__(self, sample):
622
+ fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha']
623
+ alpha[alpha < 0 ] = 0
624
+ alpha[alpha > 1] = 1
625
+ fg[fg < 0 ] = 0
626
+ fg[fg > 255] = 255
627
+ bg[bg < 0 ] = 0
628
+ bg[bg > 255] = 255
629
+
630
+ image = fg * alpha[:, :, None] + bg * (1 - alpha[:, :, None])
631
+ sample['image'] = image
632
+ return sample
633
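Composite applies the standard matting equation I = alpha * F + (1 - alpha) * B per pixel, after clipping fg/bg to [0, 255] and alpha to [0, 1]. A tiny numeric check with hypothetical constant layers:

```python
import numpy as np

fg = np.full((2, 2, 3), 200.0)   # hypothetical foreground
bg = np.full((2, 2, 3), 50.0)    # hypothetical background
alpha = np.full((2, 2), 0.25)

image = fg * alpha[:, :, None] + bg * (1 - alpha[:, :, None])
print(image[0, 0])  # [87.5 87.5 87.5] -> 0.25 * 200 + 0.75 * 50
```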
+
634
+
635
+ class CutMask(object):
636
+ def __init__(self, perturb_prob = 0):
637
+ self.perturb_prob = perturb_prob
638
+
639
+ def __call__(self, sample):
640
+ if np.random.rand() < self.perturb_prob:
641
+ return sample
642
+
643
+ mask = sample['mask'] # H x W, trimap 0--255, segmask 0--1, alpha 0--1
644
+ h, w = mask.shape
645
+ perturb_size_h, perturb_size_w = random.randint(h // 4, h // 2), random.randint(w // 4, w // 2)
646
+ x = random.randint(0, h - perturb_size_h)
647
+ y = random.randint(0, w - perturb_size_w)
648
+ x1 = random.randint(0, h - perturb_size_h)
649
+ y1 = random.randint(0, w - perturb_size_w)
650
+
651
+ mask[x:x+perturb_size_h, y:y+perturb_size_w] = mask[x1:x1+perturb_size_h, y1:y1+perturb_size_w].copy()
652
+
653
+ sample['mask'] = mask
654
+ return sample
655
+
656
+
657
+ class ScaleFg(object):
658
+ def __init__(self, min_scale_fg_scale=0.5, max_scale_fg_scale=1.0):
659
+ self.min_scale_fg_scale = min_scale_fg_scale
660
+ self.max_scale_fg_scale = max_scale_fg_scale
661
+
662
+ def __call__(self, sample):
663
+ scale_factor = np.random.uniform(low=self.min_scale_fg_scale, high=self.max_scale_fg_scale)
664
+
665
+ fg, alpha = sample['fg'], sample['alpha'] # np.array(): [H, W, 3] 0 ~ 255 , [H, W] 0.0 ~ 1.0
666
+ h, w = alpha.shape
667
+ scale_h, scale_w = int(h * scale_factor), int(w * scale_factor)
668
+
669
+ new_fg, new_alpha = np.zeros_like(fg), np.zeros_like(alpha)
670
+ fg = cv2.resize(fg, (scale_w, scale_h), interpolation=cv2.INTER_LINEAR)
671
+ alpha = cv2.resize(alpha, (scale_w, scale_h), interpolation=cv2.INTER_LINEAR)
672
+
673
+ if scale_factor <= 1:
674
+ offset_h, offset_w = np.random.randint(h - scale_h + 1), np.random.randint(w - scale_w + 1)
675
+ new_fg[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w, :] = fg
676
+ new_alpha[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w] = alpha
677
+ else:
678
+ offset_h, offset_w = np.random.randint(scale_h - h + 1), np.random.randint(scale_w - w + 1)
679
+ new_fg = fg[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w, :]
680
+ new_alpha = alpha[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w]
681
+
682
+ sample['fg'], sample['alpha'] = new_fg, new_alpha
683
+ return sample
684
+
685
+ class GenBBox(object):
686
+ def __init__(self, bbox_offset_factor = 0.1, random_crop_bbox = None, train_or_test = 'train', dataset_type = None, random_auto_matting=None):
687
+ self.bbox_offset_factor = bbox_offset_factor
688
+ self.random_crop_bbox = random_crop_bbox
689
+ self.train_or_test = train_or_test
690
+ self.dataset_type = dataset_type
691
+ self.random_auto_matting = random_auto_matting
692
+
693
+ def __call__(self, sample):
694
+
695
+ alpha = sample['alpha'] # [1, H, W] 0.0 ~ 1.0
696
+ indices = torch.nonzero(alpha[0], as_tuple=True)
697
+
698
+ if len(indices[0]) > 0:
699
+
700
+ min_x, min_y = torch.min(indices[1]), torch.min(indices[0])
701
+ max_x, max_y = torch.max(indices[1]), torch.max(indices[0])
702
+
703
+ if self.random_crop_bbox is not None and np.random.uniform(0, 1) < self.random_crop_bbox:
704
+ ori_h_w = (sample['alpha'].shape[-2], sample['alpha'].shape[-1])
705
+ sample['alpha'] = F.interpolate(sample['alpha'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0]
706
+ sample['image'] = F.interpolate(sample['image'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0]
707
+ sample['trimap'] = F.interpolate(sample['trimap'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='nearest')[0]
708
+ bbox = torch.tensor([[0, 0, ori_h_w[1] - 1, ori_h_w[0] - 1]])
709
+
710
+ elif self.bbox_offset_factor != 0:
711
+ bbox_w = max(1, max_x - min_x)
712
+ bbox_h = max(1, max_y - min_y)
713
+ offset_w = math.ceil(self.bbox_offset_factor * bbox_w)
714
+ offset_h = math.ceil(self.bbox_offset_factor * bbox_h)
715
+
716
+ min_x = max(0, min_x + np.random.randint(-offset_w, offset_w))
717
+ max_x = min(alpha.shape[2] - 1, max_x + np.random.randint(-offset_w, offset_w))
718
+ min_y = max(0, min_y + np.random.randint(-offset_h, offset_h))
719
+ max_y = min(alpha.shape[1] - 1, max_y + np.random.randint(-offset_h, offset_h))
720
+ bbox = torch.tensor([[min_x, min_y, max_x, max_y]])
721
+ else:
722
+ bbox = torch.tensor([[min_x, min_y, max_x, max_y]])
723
+
724
+ if self.random_auto_matting is not None and np.random.uniform(0, 1) < self.random_auto_matting:
725
+ bbox = torch.tensor([[0, 0, alpha.shape[2] - 1, alpha.shape[1] - 1]])
726
+
727
+ else:
728
+ bbox = torch.zeros(1, 4)
729
+
730
+ sample['bbox'] = bbox.float()
731
+ return sample
732
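GenBBox turns the ground-truth alpha into the box prompt fed to the SAM branch: the tight box around all non-zero alpha pixels, jittered by up to `bbox_offset_factor` of the box size per side, with occasional whole-image boxes for automatic matting. A condensed sketch of the default jittered-box branch only (no random crop, no auto-matting):

```python
import math
import numpy as np
import torch

def alpha_to_jittered_bbox(alpha, offset_factor=0.1):
    # alpha: [1, H, W] in [0, 1]; returns a [1, 4] (x1, y1, x2, y2) box around the
    # non-zero region, randomly shifted by up to offset_factor of the box size per side.
    ys, xs = torch.nonzero(alpha[0], as_tuple=True)
    if len(ys) == 0:
        return torch.zeros(1, 4)
    min_x, max_x, min_y, max_y = int(xs.min()), int(xs.max()), int(ys.min()), int(ys.max())
    off_w = math.ceil(offset_factor * max(1, max_x - min_x))
    off_h = math.ceil(offset_factor * max(1, max_y - min_y))
    min_x = max(0, min_x + np.random.randint(-off_w, off_w))
    max_x = min(alpha.shape[2] - 1, max_x + np.random.randint(-off_w, off_w))
    min_y = max(0, min_y + np.random.randint(-off_h, off_h))
    max_y = min(alpha.shape[1] - 1, max_y + np.random.randint(-off_h, off_h))
    return torch.tensor([[min_x, min_y, max_x, max_y]], dtype=torch.float32)
```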
+
733
+ class DataGenerator(Dataset):
734
+ def __init__(
735
+ self,
736
+ data,
737
+ phase="train",
738
+ crop_size=512,
739
+ remove_multi_fg=False,
740
+ min_scale_fg_scale=None,
741
+ max_scale_fg_scale=None,
742
+ with_bbox = False,
743
+ bbox_offset_factor = None,
744
+ return_keys = None,
745
+ random_crop_bbox = None,
746
+ dataset_name = None,
747
+ random_auto_matting = None,
748
+ ):
749
+ self.phase = phase
750
+ # self.crop_size = CONFIG.data.crop_size
751
+ self.crop_size = crop_size
752
+ self.remove_multi_fg = remove_multi_fg
753
+ self.with_bbox = with_bbox
754
+ self.bbox_offset_factor = bbox_offset_factor
755
+ self.alpha = data.alpha
756
+ self.return_keys = return_keys
757
+ self.random_crop_bbox = random_crop_bbox
758
+ self.dataset_name = dataset_name
759
+ self.random_auto_matting = random_auto_matting
760
+
761
+ if self.phase == "train":
762
+ self.fg = data.fg
763
+ self.bg = data.bg
764
+ self.merged = []
765
+ self.trimap = []
766
+ else:
767
+ self.fg = []
768
+ self.bg = []
769
+ self.merged = data.merged
770
+ self.trimap = data.trimap
771
+
772
+ train_trans = [
773
+ RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5),
774
+ GenMask(),
775
+ CutMask(perturb_prob=CONFIG.data.cutmask_prob),
776
+ RandomCrop((self.crop_size, self.crop_size)),
777
+ RandomJitter(),
778
+ Composite(),
779
+ ToTensor(phase="train")
780
+ ]
781
+ if min_scale_fg_scale is not None:
782
+ train_trans.insert(0, ScaleFg(min_scale_fg_scale, max_scale_fg_scale))
783
+ if self.with_bbox:
784
+ train_trans.append(GenBBox(bbox_offset_factor=self.bbox_offset_factor, random_crop_bbox=self.random_crop_bbox, random_auto_matting=self.random_auto_matting))
785
+
786
+ test_trans = [ OriginScale(), ToTensor() ]
787
+
788
+ self.transform = {
789
+ 'train':
790
+ transforms.Compose(train_trans),
791
+ 'val':
792
+ transforms.Compose([
793
+ OriginScale(),
794
+ ToTensor()
795
+ ]),
796
+ 'test':
797
+ transforms.Compose(test_trans)
798
+ }[phase]
799
+
800
+ self.fg_num = len(self.fg)
801
+
802
+ def select_keys(self, sample):
803
+ new_sample = {}
804
+ for key, val in sample.items():
805
+ if key in self.return_keys:
806
+ new_sample[key] = val
807
+ return new_sample
808
+
809
+ def __getitem__(self, idx):
810
+ if self.phase == "train":
811
+ fg = cv2.imread(self.fg[idx % self.fg_num])
812
+ alpha = cv2.imread(self.alpha[idx % self.fg_num], 0).astype(np.float32)/255
813
+ bg = cv2.imread(self.bg[idx], 1)
814
+
815
+ if not self.remove_multi_fg:
816
+ fg, alpha, multi_fg = self._composite_fg(fg, alpha, idx)
817
+ else:
818
+ multi_fg = False
819
+ image_name = os.path.split(self.fg[idx % self.fg_num])[-1]
820
+ sample = {'fg': fg, 'alpha': alpha, 'bg': bg, 'image_name': image_name, 'multi_fg': multi_fg}
821
+
822
+ else:
823
+ image = cv2.imread(self.merged[idx])
824
+ alpha = cv2.imread(self.alpha[idx], 0)/255.
825
+ trimap = cv2.imread(self.trimap[idx], 0)
826
+ mask = (trimap >= 170).astype(np.float32)
827
+ image_name = os.path.split(self.merged[idx])[-1]
828
+
829
+ sample = {'image': image, 'alpha': alpha, 'trimap': trimap, 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape}
830
+
831
+ sample = self.transform(sample)
832
+
833
+ if self.return_keys is not None:
834
+ sample = self.select_keys(sample)
835
+ if self.dataset_name is not None:
836
+ sample['dataset_name'] = self.dataset_name
837
+ return sample
838
+
839
+ def _composite_fg(self, fg, alpha, idx):
840
+
841
+ multi_fg = False
842
+ if np.random.rand() < 0.5:
843
+ idx2 = np.random.randint(self.fg_num) + idx
844
+ fg2 = cv2.imread(self.fg[idx2 % self.fg_num])
845
+ alpha2 = cv2.imread(self.alpha[idx2 % self.fg_num], 0).astype(np.float32)/255.
846
+ h, w = alpha.shape
847
+ fg2 = cv2.resize(fg2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
848
+ alpha2 = cv2.resize(alpha2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
849
+
850
+ alpha_tmp = 1 - (1 - alpha) * (1 - alpha2)
851
+ if np.any(alpha_tmp < 1):
852
+ fg = fg.astype(np.float32) * alpha[:,:,None] + fg2.astype(np.float32) * (1 - alpha[:,:,None])
853
+ # where two 50%-transparent foregrounds overlap, the composite should be 25% transparent

854
+ alpha = alpha_tmp
855
+ fg = fg.astype(np.uint8)
856
+ multi_fg = True
857
+
858
+ if np.random.rand() < 0.25:
859
+ # fg = cv2.resize(fg, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
860
+ # alpha = cv2.resize(alpha, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
861
+ fg = cv2.resize(fg, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
862
+ alpha = cv2.resize(alpha, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
863
+
864
+ return fg, alpha, multi_fg
865
+
866
+ def __len__(self):
867
+ if self.phase == "train":
868
+ return len(self.bg)
869
+ else:
870
+ return len(self.alpha)
871
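A hedged usage sketch of the training path: `DataGenerator` expects an object exposing `fg` / `bg` / `alpha` file lists (plus `merged` / `trimap` for testing), and the train pipeline additionally reads the module-level `CONFIG.data.cutmask_prob`, so that config must be populated. The paths and the `SimpleNamespace` stand-in below are illustrative only.

```python
from types import SimpleNamespace
from torch.utils.data import DataLoader

# Hypothetical file lists; in the repo these come from the dataset bookkeeping object.
data = SimpleNamespace(
    fg=['fg/0001.png'], bg=['bg/0001.jpg'], alpha=['alpha/0001.png'],
    merged=[], trimap=[],
)
train_set = DataGenerator(
    data, phase='train', crop_size=1024,
    with_bbox=True, bbox_offset_factor=0.1,
    return_keys=['image', 'alpha', 'trimap', 'mask', 'bbox'],
)
loader = DataLoader(train_set, batch_size=2, shuffle=True, num_workers=4)
```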
+
872
+
873
+ class ResziePad(object):
874
+
875
+ def __init__(self, target_size=1024):
876
+ self.target_size = target_size
877
+
878
+ def __call__(self, sample):
879
+ _, H, W = sample['image'].shape
880
+
881
+ scale = self.target_size * 1.0 / max(H, W)
882
+ new_H, new_W = H * scale, W * scale
883
+ new_W = int(new_W + 0.5)
884
+ new_H = int(new_H + 0.5)
885
+
886
+ choice = {'image', 'trimap', 'alpha'} if 'trimap' in sample.keys() else {'image', 'alpha'}
887
+ for key in choice:
888
+ if key in {'image', 'trimap'}:
889
+ sample[key] = F.interpolate(sample[key][None], size=(new_H, new_W), mode='bilinear', align_corners=False)[0]
890
+ else:
891
+ # sample[key] = F.interpolate(sample[key][None], size=(new_H, new_W), mode='nearest')[0]
892
+ sample[key] = F.interpolate(sample[key][None], size=(new_H, new_W), mode='bilinear', align_corners=False)[0]
893
+ padding = torch.zeros([sample[key].shape[0], self.target_size, self.target_size], dtype=sample[key].dtype, device=sample[key].device)
894
+ padding[:, : new_H, : new_W] = sample[key]
895
+ sample[key] = padding
896
+
897
+ return sample
898
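`ResziePad` implements the SAM-style preprocessing used by the test datasets below: scale so the longer side equals `target_size` (1024 by default), then zero-pad the bottom/right into a square canvas; `Cv2ResziePad` further down is the cv2-based equivalent. A standalone sketch of the geometry:

```python
import torch
import torch.nn.functional as F

def resize_longest_and_pad(img, target=1024):
    # img: [C, H, W]; scale so the longer side equals `target`, then zero-pad
    # the bottom/right to a square target x target canvas.
    _, h, w = img.shape
    scale = target / max(h, w)
    new_h, new_w = int(h * scale + 0.5), int(w * scale + 0.5)
    img = F.interpolate(img[None], size=(new_h, new_w), mode='bilinear', align_corners=False)[0]
    canvas = torch.zeros(img.shape[0], target, target, dtype=img.dtype)
    canvas[:, :new_h, :new_w] = img
    return canvas

print(resize_longest_and_pad(torch.rand(3, 600, 800)).shape)  # torch.Size([3, 1024, 1024])
```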
+
899
+
900
+ class Cv2ResziePad(object):
901
+
902
+ def __init__(self, target_size=1024):
903
+ self.target_size = target_size
904
+
905
+ def __call__(self, sample):
906
+ H, W, _ = sample['image'].shape
907
+
908
+ scale = self.target_size * 1.0 / max(H, W)
909
+ new_H, new_W = H * scale, W * scale
910
+ new_W = int(new_W + 0.5)
911
+ new_H = int(new_H + 0.5)
912
+
913
+ choice = {'image', 'trimap', 'alpha'} if 'trimap' in sample.keys() and sample['trimap'] is not None else {'image', 'alpha'}
914
+ for key in choice:
915
+ sample[key] = cv2.resize(sample[key], (new_W, new_H), interpolation=cv2.INTER_LINEAR) # cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC
916
+
917
+ if key == 'image':
918
+ padding = np.zeros([self.target_size, self.target_size, sample[key].shape[-1]], dtype=sample[key].dtype)
919
+ padding[: new_H, : new_W, :] = sample[key]
920
+ sample[key] = padding
921
+ sample[key] = sample[key][:, :, ::-1].transpose((2, 0, 1)).astype(np.float32) #/ 255.0
922
+ else:
923
+ padding = np.zeros([self.target_size, self.target_size], dtype=sample[key].dtype)
924
+ padding[: new_H, : new_W] = sample[key]
925
+ sample[key] = padding
926
+ sample[key] = sample[key][None].astype(np.float32)
927
+ sample[key] = torch.from_numpy(sample[key])
928
+
929
+ return sample
930
+
931
+
932
+ class AdobeCompositionTest(Dataset):
933
+ def __init__(self, data_dir, target_size=1024, multi_fg=None):
934
+ self.data_dir = data_dir
935
+ self.file_names = sorted(os.listdir(os.path.join(self.data_dir, 'merged')))
936
+
937
+ test_trans = [
938
+ ResziePad(target_size=target_size),
939
+ GenBBox(bbox_offset_factor=0)
940
+ ]
941
+ self.transform = transforms.Compose(test_trans)
942
+ self.multi_fg = multi_fg
943
+
944
+ def __len__(self): # 1000
945
+ return len(self.file_names)
946
+
947
+ def __getitem__(self, idx):
948
+ phas = Image.open(os.path.join(self.data_dir, 'alpha_copy', self.file_names[idx])).convert('L')
949
+ tris = Image.open(os.path.join(self.data_dir, 'trimaps', self.file_names[idx]))
950
+ imgs = Image.open(os.path.join(self.data_dir, 'merged', self.file_names[idx]))
951
+ sample = {
952
+ 'ori_h_w': (imgs.size[1], imgs.size[0]),
953
+ 'data_type': 'Adobe'
954
+ }
955
+
956
+ sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
957
+ sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
958
+ sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
959
+ sample['image_name'] = 'Adobe_' + self.file_names[idx]
960
+
961
+ sample = self.transform(sample)
962
+ sample['trimap'][sample['trimap'] < 85] = 0
963
+ sample['trimap'][sample['trimap'] >= 170] = 1
964
+ sample['trimap'][sample['trimap'] >= 85] = 0.5
965
+
966
+ if self.multi_fg is not None:
967
+ sample['multi_fg'] = torch.tensor(self.multi_fg)
968
+
969
+ return sample
970
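The three in-place trimap thresholds above only work because they run in exactly this order: once pixels `>= 170` have been rewritten to 1, the final `>= 85` assignment can no longer touch them, so it maps only the remaining 85-169 band to 0.5. An order-independent equivalent of the same quantisation:

```python
import torch

def quantize_trimap(trimap):
    # Background (<85) -> 0, foreground (>=170) -> 1, unknown band in between -> 0.5.
    out = torch.full_like(trimap, 0.5)
    out[trimap < 85] = 0.0
    out[trimap >= 170] = 1.0
    return out
```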
+
971
+
972
+ class SIMTest(Dataset):
973
+ def __init__(self, data_dir, target_size=1024, multi_fg=None):
974
+ self.data_dir = data_dir
975
+ self.file_names = sorted(glob.glob(os.path.join(*[data_dir, '*', 'alpha', '*']))) # [: 10]
976
+ test_trans = [
977
+ ResziePad(target_size=target_size),
978
+ GenBBox(bbox_offset_factor=0)
979
+ ]
980
+ self.transform = transforms.Compose(test_trans)
981
+ self.multi_fg = multi_fg
982
+
983
+ def __len__(self): # 1000
984
+ return len(self.file_names)
985
+
986
+ def __getitem__(self, idx):
987
+ phas = Image.open(self.file_names[idx]).convert('L')
988
+ # tris = Image.open(self.file_names[idx].replace('alpha', 'trimap'))
989
+ imgs = Image.open(self.file_names[idx].replace('alpha', 'merged'))
990
+ sample = {
991
+ 'ori_h_w': (imgs.size[1], imgs.size[0]),
992
+ 'data_type': 'SIM'
993
+ }
994
+
995
+ sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
996
+ # sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
997
+ sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
998
+ sample['image_name'] = 'SIM_{}_{}'.format(self.file_names[idx].split('/')[-3], self.file_names[idx].split('/')[-1])
999
+
1000
+ sample = self.transform(sample)
1001
+ # sample['trimap'][sample['trimap'] < 85] = 0
1002
+ # sample['trimap'][sample['trimap'] >= 170] = 1
1003
+ # sample['trimap'][sample['trimap'] >= 85] = 0.5
1004
+
1005
+ if self.multi_fg is not None:
1006
+ sample['multi_fg'] = torch.tensor(self.multi_fg)
1007
+
1008
+ return sample
1009
+
1010
+
1011
+ class RW100Test(Dataset):
1012
+ def __init__(self, data_dir, target_size=1024, multi_fg=None):
1013
+ self.data_dir = data_dir
1014
+ self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'mask', '*'])))
1015
+
1016
+ self.name_to_idx = dict()
1017
+ for idx, file_name in enumerate(self.file_names):
1018
+ self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
1019
+
1020
+ test_trans = [
1021
+ ResziePad(target_size=target_size),
1022
+ GenBBox(bbox_offset_factor=0, train_or_test='test', dataset_type='RW100')
1023
+ ]
1024
+ self.transform = transforms.Compose(test_trans)
1025
+ self.multi_fg = multi_fg
1026
+
1027
+ def __len__(self): # 1000
1028
+ return len(self.file_names)
1029
+
1030
+ def __getitem__(self, idx):
1031
+ phas = Image.open(self.file_names[idx]).convert('L')
1032
+ imgs = Image.open(self.file_names[idx].replace('mask', 'image')[:-6] + '.jpg')
1033
+ sample = {
1034
+ 'ori_h_w': (imgs.size[1], imgs.size[0]),
1035
+ 'data_type': 'RW100'
1036
+ }
1037
+
1038
+ sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
1039
+ sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
1040
+ sample['image_name'] = 'RW100_' + self.file_names[idx].split('/')[-1]
1041
+
1042
+ sample = self.transform(sample)
1043
+
1044
+ if self.multi_fg is not None:
1045
+ sample['multi_fg'] = torch.tensor(self.multi_fg)
1046
+
1047
+ return sample
1048
+
1049
+
1050
+ class AIM500Test(Dataset):
1051
+ def __init__(self, data_dir, target_size=1024, multi_fg=None):
1052
+ self.data_dir = data_dir
1053
+ self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'original', '*'])))
1054
+
1055
+ self.name_to_idx = dict()
1056
+ for idx, file_name in enumerate(self.file_names):
1057
+ self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
1058
+
1059
+ test_trans = [
1060
+ ResziePad(target_size=target_size),
1061
+ GenBBox(bbox_offset_factor=0)
1062
+ ]
1063
+ self.transform = transforms.Compose(test_trans)
1064
+ self.multi_fg = multi_fg
1065
+
1066
+ def __len__(self): # 1000
1067
+ return len(self.file_names)
1068
+
1069
+ def __getitem__(self, idx):
1070
+ phas = Image.open(self.file_names[idx].replace('original', 'mask').replace('jpg', 'png')).convert('L')
1071
+ # tris = Image.open(self.file_names[idx].replace('original', 'trimap').replace('jpg', 'png')).convert('L')
1072
+ imgs = Image.open(self.file_names[idx])
1073
+ sample = {
1074
+ 'ori_h_w': (imgs.size[1], imgs.size[0]),
1075
+ 'data_type': 'AIM500'
1076
+ }
1077
+
1078
+ sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
1079
+ # sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
1080
+ sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
1081
+ sample['image_name'] = 'AIM500_' + self.file_names[idx].split('/')[-1]
1082
+
1083
+ sample = self.transform(sample)
1084
+ # sample['trimap'][sample['trimap'] < 85] = 0
1085
+ # sample['trimap'][sample['trimap'] >= 170] = 1
1086
+ # sample['trimap'][sample['trimap'] >= 85] = 0.5
1087
+
1088
+ if self.multi_fg is not None:
1089
+ sample['multi_fg'] = torch.tensor(self.multi_fg)
1090
+
1091
+ return sample
1092
+
1093
+
1094
+ class RWP636Test(Dataset):
1095
+ def __init__(self, data_dir, target_size=1024, multi_fg=None):
1096
+ self.data_dir = data_dir
1097
+ self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'image', '*'])))
1098
+
1099
+ self.name_to_idx = dict()
1100
+ for idx, file_name in enumerate(self.file_names):
1101
+ self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
1102
+
1103
+ test_trans = [
1104
+ ResziePad(target_size=target_size),
1105
+ GenBBox(bbox_offset_factor=0)
1106
+ ]
1107
+ self.transform = transforms.Compose(test_trans)
1108
+ self.multi_fg = multi_fg
1109
+
1110
+ def __len__(self): # 1000
1111
+ return len(self.file_names)
1112
+
1113
+ def __getitem__(self, idx):
1114
+ phas = Image.open(self.file_names[idx].replace('image', 'alpha').replace('jpg', 'png')).convert('L')
1115
+ imgs = Image.open(self.file_names[idx])
1116
+ sample = {
1117
+ 'ori_h_w': (imgs.size[1], imgs.size[0]),
1118
+ 'data_type': 'RWP636'
1119
+ }
1120
+
1121
+ sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
1122
+ sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
1123
+ sample['image_name'] = 'RWP636_' + self.file_names[idx].split('/')[-1]
1124
+
1125
+ sample = self.transform(sample)
1126
+
1127
+ if self.multi_fg is not None:
1128
+ sample['multi_fg'] = torch.tensor(self.multi_fg)
1129
+
1130
+ return sample
1131
+
1132
+
1133
+ class AM2KTest(Dataset):
1134
+ def __init__(self, data_dir, target_size=1024, multi_fg=None):
1135
+ self.data_dir = data_dir
1136
+ self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'validation/original', '*'])))
1137
+ test_trans = [
1138
+ ResziePad(target_size=target_size),
1139
+ GenBBox(bbox_offset_factor=0)
1140
+ ]
1141
+ self.transform = transforms.Compose(test_trans)
1142
+ self.multi_fg = multi_fg
1143
+
1144
+ def __len__(self): # 1000
1145
+ return len(self.file_names)
1146
+
1147
+ def __getitem__(self, idx):
1148
+ phas = Image.open(self.file_names[idx].replace('original', 'mask').replace('jpg', 'png')).convert('L')
1149
+ # tris = Image.open(self.file_names[idx].replace('original', 'trimap').replace('jpg', 'png')).convert('L')
1150
+ imgs = Image.open(self.file_names[idx])
1151
+ sample = {
1152
+ 'ori_h_w': (imgs.size[1], imgs.size[0]),
1153
+ 'data_type': 'AM2K'
1154
+ }
1155
+
1156
+ sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
1157
+ # sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
1158
+ sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
1159
+ sample['image_name'] = 'AM2K_' + self.file_names[idx].split('/')[-1]
1160
+
1161
+ sample = self.transform(sample)
1162
+ # sample['trimap'][sample['trimap'] < 85] = 0
1163
+ # sample['trimap'][sample['trimap'] >= 170] = 1
1164
+ # sample['trimap'][sample['trimap'] >= 85] = 0.5
1165
+
1166
+ if self.multi_fg is not None:
1167
+ sample['multi_fg'] = torch.tensor(self.multi_fg)
1168
+
1169
+ return sample
1170
+
1171
+
1172
+ class P3M500Test(Dataset):
1173
+ def __init__(self, data_dir, target_size=1024, multi_fg=None):
1174
+ self.data_dir = data_dir
1175
+ self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'original_image', '*'])))
1176
+
1177
+ self.name_to_idx = dict()
1178
+ for idx, file_name in enumerate(self.file_names):
1179
+ self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
1180
+
1181
+ test_trans = [
1182
+ ResziePad(target_size=target_size),
1183
+ GenBBox(bbox_offset_factor=0)
1184
+ ]
1185
+ self.transform = transforms.Compose(test_trans)
1186
+ self.multi_fg = multi_fg
1187
+
1188
+ def __len__(self): # 1000
1189
+ return len(self.file_names)
1190
+
1191
+ def __getitem__(self, idx):
1192
+ phas = Image.open(self.file_names[idx].replace('original_image', 'mask').replace('jpg', 'png')).convert('L')
1193
+ # tris = Image.open(self.file_names[idx].replace('original_image', 'trimap').replace('jpg', 'png')).convert('L')
1194
+ imgs = Image.open(self.file_names[idx])
1195
+ sample = {
1196
+ 'ori_h_w': (imgs.size[1], imgs.size[0]),
1197
+ 'data_type': 'P3M500'
1198
+ }
1199
+
1200
+ sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
1201
+ # sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
1202
+ sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
1203
+ sample['image_name'] = 'P3M500_' + self.file_names[idx].split('/')[-1]
1204
+
1205
+ sample = self.transform(sample)
1206
+ # sample['trimap'][sample['trimap'] < 85] = 0
1207
+ # sample['trimap'][sample['trimap'] >= 170] = 1
1208
+ # sample['trimap'][sample['trimap'] >= 85] = 0.5
1209
+
1210
+ if self.multi_fg is not None:
1211
+ sample['multi_fg'] = torch.tensor(self.multi_fg)
1212
+
1213
+ return sample
1214
+
1215
+
1216
+ class MattingTest(Dataset):
1217
+ def __init__(
1218
+ self,
1219
+ data_type,
1220
+ data_dir,
1221
+ image_sub_path,
1222
+ alpha_sub_path,
1223
+ trimpa_sub_path=None,
1224
+ target_size=1024,
1225
+ multi_fg=None,
1226
+ ):
1227
+ self.data_type = data_type
1228
+ self.data_dir = data_dir
1229
+
1230
+ self.image_paths = sorted(glob.glob(os.path.join(*[data_dir, image_sub_path])))
1231
+ self.alpha_paths = sorted(glob.glob(os.path.join(*[data_dir, alpha_sub_path])))
1232
+ self.trimpa_paths = sorted(glob.glob(os.path.join(*[data_dir, trimpa_sub_path]))) if trimpa_sub_path is not None else None
1233
+
1234
+ self.name_to_idx = dict()
1235
+ for idx, file_name in enumerate(self.image_paths):
1236
+ self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
1237
+
1238
+ test_trans = [
1239
+ Cv2ResziePad(target_size=target_size),
1240
+ GenBBox(bbox_offset_factor=0)
1241
+ ]
1242
+ self.transform = transforms.Compose(test_trans)
1243
+ self.multi_fg = multi_fg
1244
+
1245
+ def __len__(self): # 1000
1246
+ return len(self.image_paths)
1247
+
1248
+ def __getitem__(self, idx):
1249
+
1250
+ img = cv2.imread(self.image_paths[idx])
1251
+ sample = {
1252
+ 'image': img.astype(np.float32) / 255,
1253
+ 'alpha': cv2.imread(self.alpha_paths[idx], 0).astype(np.float32) / 255,
1254
+ 'trimap': cv2.imread(self.trimpa_paths[idx], 0) if self.trimpa_paths is not None else None,
1255
+ 'ori_h_w': (img.shape[0], img.shape[1]),
1256
+ 'data_type': self.data_type,
1257
+ 'image_name': self.data_type + '_' + self.image_paths[idx].split('/')[-1]
1258
+ }
1259
+
1260
+ sample = self.transform(sample)
1261
+ if self.trimpa_paths is not None:
1262
+ sample['trimap'][sample['trimap'] < 85] = 0
1263
+ sample['trimap'][sample['trimap'] >= 170] = 1
1264
+ sample['trimap'][sample['trimap'] >= 85] = 0.5
1265
+ else:
1266
+ del sample['trimap']
1267
+
1268
+ if self.multi_fg is not None:
1269
+ sample['multi_fg'] = torch.tensor(self.multi_fg)
1270
+
1271
+ return sample
1272
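`MattingTest` generalises the per-benchmark classes above: it globs images, alphas and (optionally) trimaps under `data_dir` and pairs them by sorted order, so any benchmark with that layout can be evaluated without a dedicated class. A usage sketch with hypothetical paths (the `trimpa_sub_path` spelling follows the signature above):

```python
# Placeholder directory layout for illustration only.
test_set = MattingTest(
    data_type='AIM500',
    data_dir='/path/to/AIM-500',
    image_sub_path='original/*',
    alpha_sub_path='mask/*',
    trimpa_sub_path='trimap/*',   # optional; omit to skip trimap-based metrics
    target_size=1024,
)
sample = test_set[0]
print(sample['image'].shape, sample['alpha'].shape, sample['bbox'])
```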
+
1273
+
1274
+ def adobe_composition_collate_fn(batch):
1275
+ new_batch = defaultdict(list)
1276
+ for sub_batch in batch:
1277
+ for key in sub_batch.keys():
1278
+ new_batch[key].append(sub_batch[key])
1279
+ for key in new_batch:
1280
+ if isinstance(new_batch[key][0], torch.Tensor):
1281
+ new_batch[key] = torch.stack(new_batch[key])
1282
+ return dict(new_batch)
1283
+
1284
+
1285
+ def build_d2_test_dataloader(
1286
+ dataset,
1287
+ mapper=None,
1288
+ total_batch_size=None,
1289
+ local_batch_size=None,
1290
+ num_workers=0,
1291
+ collate_fn=None
1292
+ ):
1293
+
1294
+ assert (total_batch_size is None) != (
1295
+ local_batch_size is None
1296
+ ), "Either total_batch_size or local_batch_size must be specified"
1297
+
1298
+ world_size = comm.get_world_size()
1299
+
1300
+ if total_batch_size is not None:
1301
+ assert (
1302
+ total_batch_size > 0 and total_batch_size % world_size == 0
1303
+ ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
1304
+ total_batch_size, world_size
1305
+ )
1306
+ batch_size = total_batch_size // world_size
1307
+
1308
+ if local_batch_size is not None:
1309
+ batch_size = local_batch_size
1310
+
1311
+ logger = logging.getLogger(__name__)
1312
+ if batch_size != 1:
1313
+ logger.warning(
1314
+ "When testing, batch size is set to 1. "
1315
+ "This is the only mode that is supported for d2."
1316
+ )
1317
+
1318
+ return build_detection_test_loader(
1319
+ dataset=dataset,
1320
+ mapper=mapper,
1321
+ sampler=None,
1322
+ num_workers=num_workers,
1323
+ collate_fn=collate_fn,
1324
+ )
1325
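A hedged sketch of how these pieces fit together: one of the test datasets above, batched with `adobe_composition_collate_fn`, wrapped by `build_d2_test_dataloader` (which defers to detectron2's test loader, so the effective batch size is 1). Paths are placeholders.

```python
test_loader = build_d2_test_dataloader(
    dataset=P3M500Test(data_dir='/path/to/P3M-500-P', target_size=1024),  # placeholder path
    local_batch_size=1,
    num_workers=2,
    collate_fn=adobe_composition_collate_fn,
)
for batch in test_loader:
    print(batch['image'].shape, batch['bbox'].shape)  # [1, 3, 1024, 1024], [1, 1, 4]
    break
```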
+
1326
+
1327
+ class AdobeCompositionEvaluator(DatasetEvaluator):
1328
+
1329
+ def __init__(
1330
+ self,
1331
+ save_eval_results_step=-1,
1332
+ output_dir=None,
1333
+ eval_dataset_type=['Adobe'],
1334
+ distributed=True,
1335
+ eval_w_sam_hq_mask = False,
1336
+ ):
1337
+
1338
+ self.save_eval_results_step = save_eval_results_step
1339
+ self.output_dir = output_dir
1340
+ self.eval_index = 0
1341
+ self.eval_dataset_type = eval_dataset_type
1342
+ self.eval_w_sam_hq_mask = eval_w_sam_hq_mask
1343
+
1344
+ self._distributed = distributed
1345
+ self._logger = logging.getLogger(__name__)
1346
+
1347
+ def reset(self):
1348
+ self.eval_metric = dict()
1349
+ for i in self.eval_dataset_type:
1350
+ self.eval_metric[i + '_MSE'] = []
1351
+ self.eval_metric[i + '_SAD'] = []
1352
+ self.eval_metric[i + '_MAD'] = []
1353
+ self.eval_metric[i + '_Grad'] = []
1354
+ self.eval_metric[i + '_Conn'] = []
1355
+
1356
+ os.makedirs(self.output_dir, exist_ok=True) if self.output_dir is not None else None
1357
+
1358
+ def process(self, inputs, outputs):
1359
+ """
1360
+ Args:
1361
+ inputs: {'alpha', 'trimap', 'image', 'bbox', 'image_name'}
1362
+ outputs: [1, 1, H, W] 0. ~ 1.
1363
+ """
1364
+
1365
+ # crop the black pad area
1366
+ assert inputs['image'].shape[-1] == inputs['image'].shape[-2] == 1024 and len(inputs['ori_h_w']) == 1
1367
+ inputs['ori_h_w'] = inputs['ori_h_w'][0]
1368
+ before_pad_h, before_pad_w = int(1024 / max(inputs['ori_h_w']) * inputs['ori_h_w'][0] + 0.5), int(1024 / max(inputs['ori_h_w']) * inputs['ori_h_w'][1] + 0.5)
1369
+ inputs['image'] = inputs['image'][:, :, :before_pad_h, :before_pad_w]
1370
+ inputs['alpha'] = inputs['alpha'][:, :, :before_pad_h, :before_pad_w]
1371
+
1372
+ if self.eval_w_sam_hq_mask:
1373
+ outputs, samhq_low_res_masks = outputs[0][:, :, :before_pad_h, :before_pad_w], outputs[1][:, :, :before_pad_h, :before_pad_w]
1374
+ pred_alpha, label_alpha, samhq_low_res_masks = outputs.cpu().numpy(), inputs['alpha'].numpy(), (samhq_low_res_masks > 0).float().cpu()
1375
+ else:
1376
+ outputs = outputs[:, :, :before_pad_h, :before_pad_w]
1377
+ pred_alpha, label_alpha = outputs.cpu().numpy(), inputs['alpha'].numpy()
1378
+
1379
+ # if 'trimap' in inputs.keys():
1380
+ # inputs['trimap'] = inputs['trimap'][:, :, :before_pad_h, :before_pad_w]
1381
+ # trimap = inputs['trimap'].numpy()
1382
+ # assert np.max(trimap) <= 1 and len(np.unique(trimap)) <= 3
1383
+ # sad_loss_unknown = compute_sad_loss(pred_alpha, label_alpha, trimap, area='unknown')
1384
+ # mse_loss_unknown = compute_mse_loss(pred_alpha, label_alpha, trimap, area='unknown')
1385
+
1386
+ # self.eval_metric[inputs['data_type'][0] + '_unknown_mse (1e-3)'].append(mse_loss_unknown)
1387
+ # self.eval_metric[inputs['data_type'][0] + '_unknown_sad (1e3)'].append(sad_loss_unknown)
1388
+
1389
+ # calculate loss
1390
+ assert np.max(pred_alpha) <= 1 and np.max(label_alpha) <= 1
1391
+ eval_pred = np.uint8(pred_alpha[0, 0] * 255.0 + 0.5) * 1.0
1392
+ eval_gt = label_alpha[0, 0] * 255.0
1393
+
1394
+ detailmap = np.zeros_like(eval_gt) + 128
1395
+ mse_loss_ = compute_mse_loss(eval_pred, eval_gt, detailmap)
1396
+ sad_loss_ = compute_sad_loss(eval_pred, eval_gt, detailmap)[0]
1397
+ mad_loss_ = compute_mad_loss(eval_pred, eval_gt, detailmap)
1398
+ grad_loss_ = compute_gradient_loss(eval_pred, eval_gt, detailmap)
1399
+ conn_loss_ = compute_connectivity_error(eval_pred, eval_gt, detailmap)
1400
+
1401
+ self.eval_metric[inputs['data_type'][0] + '_MSE'].append(mse_loss_)
1402
+ self.eval_metric[inputs['data_type'][0] + '_SAD'].append(sad_loss_)
1403
+ self.eval_metric[inputs['data_type'][0] + '_MAD'].append(mad_loss_)
1404
+ self.eval_metric[inputs['data_type'][0] + '_Grad'].append(grad_loss_)
1405
+ self.eval_metric[inputs['data_type'][0] + '_Conn'].append(conn_loss_)
1406
+
1407
+ # vis results
1408
+ if self.save_eval_results_step != -1 and self.eval_index % self.save_eval_results_step == 0:
1409
+ if self.eval_w_sam_hq_mask:
1410
+ self.save_vis_results(inputs, pred_alpha, samhq_low_res_masks)
1411
+ else:
1412
+ self.save_vis_results(inputs, pred_alpha)
1413
+ self.eval_index += 1
1414
+
1415
+ def save_vis_results(self, inputs, pred_alpha, samhq_low_res_masks=None):
1416
+
1417
+ # image
1418
+ image = inputs['image'][0].permute(1, 2, 0) * 255.0
1419
+ l, u, r, d = int(inputs['bbox'][0, 0, 0].item()), int(inputs['bbox'][0, 0, 1].item()), int(inputs['bbox'][0, 0, 2].item()), int(inputs['bbox'][0, 0, 3].item())
1420
+ red_line = torch.tensor([[255., 0., 0.]], device=image.device, dtype=image.dtype)
1421
+ image[u: d, l, :] = red_line
1422
+ image[u: d, r, :] = red_line
1423
+ image[u, l: r, :] = red_line
1424
+ image[d, l: r, :] = red_line
1425
+ image = np.uint8(image.numpy())
1426
+
1427
+ # trimap, pred_alpha, label_alpha
1428
+ save_results = [image]
1429
+
1430
+ choice = [inputs['trimap'], torch.from_numpy(pred_alpha), inputs['alpha']] if 'trimap' in inputs.keys() else [torch.from_numpy(pred_alpha), inputs['alpha']]
1431
+ for val in choice:
1432
+ val = val[0].permute(1, 2, 0).repeat(1, 1, 3) * 255.0 + 0.5 # +0.5 and int() = round()
1433
+ val = np.uint8(val.numpy())
1434
+ save_results.append(val)
1435
+
1436
+ if samhq_low_res_masks is not None:
1437
+ save_results.append(np.uint8(samhq_low_res_masks[0].permute(1, 2, 0).repeat(1, 1, 3).numpy() * 255.0))
1438
+
1439
+ save_results = np.concatenate(save_results, axis=1)
1440
+ save_name = os.path.join(self.output_dir, inputs['image_name'][0])
1441
+ Image.fromarray(save_results).save(save_name.replace('.jpg', '.png'))
1442
+
1443
+ def evaluate(self):
1444
+
1445
+ if self._distributed:
1446
+ comm.synchronize()
1447
+ eval_metric = comm.gather(self.eval_metric, dst=0)
1448
+
1449
+ if not comm.is_main_process():
1450
+ return {}
1451
+
1452
+ merges_eval_metric = defaultdict(list)
1453
+ for sub_eval_metric in eval_metric:
1454
+ for key, val in sub_eval_metric.items():
1455
+ merges_eval_metric[key] += val
1456
+ eval_metric = merges_eval_metric
1457
+
1458
+ else:
1459
+ eval_metric = self.eval_metric
1460
+
1461
+ eval_results = {}
1462
+
1463
+ for key, val in eval_metric.items():
1464
+ if len(val) != 0:
1465
+ # if 'mse' in key:
1466
+ # eval_results[key] = np.array(val).mean() * 1e3
1467
+ # else:
1468
+ # assert 'sad' in key
1469
+ # eval_results[key] = np.array(val).mean() / 1e3
1470
+ eval_results[key] = np.array(val).mean()
1471
+
1472
+ return eval_results
1473
+
1474
+
1475
+ if __name__ == '__main__':
1476
+ pass
data/evaluate.py ADDED
@@ -0,0 +1,102 @@
1
+ import scipy.ndimage
2
+ import numpy as np
3
+ from skimage.measure import label
4
+ import scipy.ndimage.morphology
5
+
6
+
7
+ def gauss(x, sigma):
8
+ y = np.exp(-x ** 2 / (2 * sigma ** 2)) / (sigma * np.sqrt(2 * np.pi))
9
+ return y
10
+
11
+
12
+ def dgauss(x, sigma):
13
+ y = -x * gauss(x, sigma) / (sigma ** 2)
14
+ return y
15
+
16
+
17
+ def gaussgradient(im, sigma):
18
+ epsilon = 1e-2
19
+ halfsize = np.ceil(sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))).astype(np.int32)
20
+ size = 2 * halfsize + 1
21
+ hx = np.zeros((size, size))
22
+ for i in range(0, size):
23
+ for j in range(0, size):
24
+ u = [i - halfsize, j - halfsize]
25
+ hx[i, j] = gauss(u[0], sigma) * dgauss(u[1], sigma)
26
+
27
+ hx = hx / np.sqrt(np.sum(np.abs(hx) * np.abs(hx)))
28
+ hy = hx.transpose()
29
+
30
+ gx = scipy.ndimage.convolve(im, hx, mode='nearest')
31
+ gy = scipy.ndimage.convolve(im, hy, mode='nearest')
32
+
33
+ return gx, gy
34
+
35
+
36
+ def compute_gradient_loss(pred, target, trimap):
37
+
38
+ pred = pred / 255.0
39
+ target = target / 255.0
40
+
41
+ pred_x, pred_y = gaussgradient(pred, 1.4)
42
+ target_x, target_y = gaussgradient(target, 1.4)
43
+
44
+ pred_amp = np.sqrt(pred_x ** 2 + pred_y ** 2)
45
+ target_amp = np.sqrt(target_x ** 2 + target_y ** 2)
46
+
47
+ error_map = (pred_amp - target_amp) ** 2
48
+ loss = np.sum(error_map[trimap == 128])
49
+
50
+ return loss / 1000.
51
+
52
+
53
+ def getLargestCC(segmentation):
54
+ labels = label(segmentation, connectivity=1)
55
+ largestCC = labels == np.argmax(np.bincount(labels.flat))
56
+ return largestCC
57
+
58
+
59
+ def compute_connectivity_error(pred, target, trimap, step=0.1):
60
+ pred = pred / 255.0
61
+ target = target / 255.0
62
+ h, w = pred.shape
63
+
64
+ thresh_steps = list(np.arange(0, 1 + step, step))
65
+ l_map = np.ones_like(pred, dtype=np.float32) * -1
66
+ for i in range(1, len(thresh_steps)):
67
+ pred_alpha_thresh = (pred >= thresh_steps[i]).astype(np.int32)
68
+ target_alpha_thresh = (target >= thresh_steps[i]).astype(np.int32)
69
+
70
+ omega = getLargestCC(pred_alpha_thresh * target_alpha_thresh).astype(np.int32)
71
+ flag = ((l_map == -1) & (omega == 0)).astype(np.int32)
72
+ l_map[flag == 1] = thresh_steps[i - 1]
73
+
74
+ l_map[l_map == -1] = 1
75
+
76
+ pred_d = pred - l_map
77
+ target_d = target - l_map
78
+ pred_phi = 1 - pred_d * (pred_d >= 0.15).astype(np.int32)
79
+ target_phi = 1 - target_d * (target_d >= 0.15).astype(np.int32)
80
+ loss = np.sum(np.abs(pred_phi - target_phi)[trimap == 128])
81
+
82
+ return loss / 1000.
83
+
84
+
85
+ def compute_mse_loss(pred, target, trimap):
86
+ error_map = (pred - target) / 255.0
87
+ loss = np.sum((error_map ** 2) * (trimap == 128)) / (np.sum(trimap == 128) + 1e-8)
88
+
89
+ return loss
90
+
91
+
92
+ def compute_sad_loss(pred, target, trimap):
93
+ error_map = np.abs((pred - target) / 255.0)
94
+ loss = np.sum(error_map * (trimap == 128))
95
+
96
+ return loss / 1000, np.sum(trimap == 128) / 1000
97
+
98
+ def compute_mad_loss(pred, target, trimap):
99
+ error_map = np.abs((pred - target) / 255.0)
100
+ loss = np.sum(error_map * (trimap == 128)) / (np.sum(trimap == 128) + 1e-8)
101
+
102
+ return loss
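All of these metrics expect `pred` and `target` in the 0-255 range and a trimap whose unknown region is labelled 128 (the evaluator above passes an all-128 map, so scores cover the whole image). SAD, Grad and Conn are reported divided by 1000, while MSE and MAD are means over the unknown pixels. A quick sanity check of the functions above on hypothetical constant mattes:

```python
import numpy as np

pred = np.full((4, 4), 130.0)     # hypothetical prediction, 0-255 scale
target = np.full((4, 4), 120.0)   # hypothetical ground truth
trimap = np.full((4, 4), 128)     # every pixel counted as "unknown"

print(compute_mse_loss(pred, target, trimap))   # (10/255)**2 ≈ 1.54e-3
print(compute_mad_loss(pred, target, trimap))   # 10/255 ≈ 3.92e-2
print(compute_sad_loss(pred, target, trimap))   # (sum |err| / 1000, unknown px / 1000)
```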
data/p3m10k_dataset.py ADDED
@@ -0,0 +1,325 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+ import math
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class GenBBox(object):
12
+ def __init__(self, bbox_offset_factor = 0.1, random_crop_bbox = None, train_or_test = 'train', dataset_type = None, random_auto_matting=None):
13
+ self.bbox_offset_factor = bbox_offset_factor
14
+ self.random_crop_bbox = random_crop_bbox
15
+ self.train_or_test = train_or_test
16
+ self.dataset_type = dataset_type
17
+ self.random_auto_matting = random_auto_matting
18
+
19
+ def __call__(self, sample):
20
+
21
+ alpha = sample['alpha'] # [1, H, W] 0.0 ~ 1.0
22
+ indices = torch.nonzero(alpha[0], as_tuple=True)
23
+
24
+ if len(indices[0]) > 0:
25
+
26
+ min_x, min_y = torch.min(indices[1]), torch.min(indices[0])
27
+ max_x, max_y = torch.max(indices[1]), torch.max(indices[0])
28
+
29
+ if self.random_crop_bbox is not None and np.random.uniform(0, 1) < self.random_crop_bbox:
30
+ ori_h_w = (sample['alpha'].shape[-2], sample['alpha'].shape[-1])
31
+ sample['alpha'] = F.interpolate(sample['alpha'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0]
32
+ sample['image'] = F.interpolate(sample['image'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0]
33
+ sample['trimap'] = F.interpolate(sample['trimap'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='nearest')[0]
34
+ bbox = torch.tensor([[0, 0, ori_h_w[1] - 1, ori_h_w[0] - 1]])
35
+
36
+ elif self.bbox_offset_factor != 0:
37
+ bbox_w = max(1, max_x - min_x)
38
+ bbox_h = max(1, max_y - min_y)
39
+ offset_w = math.ceil(self.bbox_offset_factor * bbox_w)
40
+ offset_h = math.ceil(self.bbox_offset_factor * bbox_h)
41
+
42
+ min_x = max(0, min_x + np.random.randint(-offset_w, offset_w))
43
+ max_x = min(alpha.shape[2] - 1, max_x + np.random.randint(-offset_w, offset_w))
44
+ min_y = max(0, min_y + np.random.randint(-offset_h, offset_h))
45
+ max_y = min(alpha.shape[1] - 1, max_y + np.random.randint(-offset_h, offset_h))
46
+ bbox = torch.tensor([[min_x, min_y, max_x, max_y]])
47
+ else:
48
+ bbox = torch.tensor([[min_x, min_y, max_x, max_y]])
49
+
50
+ if self.random_auto_matting is not None and np.random.uniform(0, 1) < self.random_auto_matting:
51
+ bbox = torch.tensor([[0, 0, alpha.shape[2] - 1, alpha.shape[1] - 1]])
52
+
53
+ else:
54
+ bbox = torch.zeros(1, 4)
55
+
56
+ sample['bbox'] = bbox.float()
57
+ return sample
58
+
59
+ def random_interp():
60
+ return np.random.choice([cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4])
61
+
62
+
63
+ class SplitConcatImage(object):
64
+
65
+ def __init__(self, concat_num=4, wo_mask_to_mattes=False):
66
+ self.concat_num = concat_num
67
+ self.wo_mask_to_mattes = wo_mask_to_mattes
68
+ if self.wo_mask_to_mattes:
69
+ assert self.concat_num == 5
70
+
71
+ def __call__(self, concat_image):
72
+ if isinstance(concat_image, list):
73
+ concat_image, image_path = concat_image[0], concat_image[1]
74
+ else:
75
+ image_path = None
76
+ H, W, _ = concat_image.shape
77
+
78
+ concat_num = self.concat_num
79
+ if image_path is not None:
80
+ if '06-14' in image_path:
81
+ concat_num = 4
82
+ elif 'ori_mask' in image_path or 'SEMat' in image_path:
83
+ concat_num = 3
84
+ else:
85
+ concat_num = 5
86
+
87
+ assert W % concat_num == 0
88
+ W = W // concat_num
89
+
90
+ image = concat_image[:H, :W]
91
+ if self.concat_num != 3:
92
+ trimap = concat_image[:H, (concat_num - 2) * W: (concat_num - 1) * W]
93
+ if self.wo_mask_to_mattes:
94
+ alpha = concat_image[:H, 2 * W: 3 * W]
95
+ else:
96
+ alpha = concat_image[:H, (concat_num - 1) * W: concat_num * W]
97
+ else:
98
+ trimap = concat_image[:H, (concat_num - 1) * W: concat_num * W]
99
+ alpha = concat_image[:H, (concat_num - 2) * W: (concat_num - 1) * W]
100
+
101
+ return {'image': image, 'trimap': trimap, 'alpha': alpha}
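`SplitConcatImage` assumes each sample is stored as one image of `concat_num` equally wide panels laid out horizontally, with the source image first and (by default) the trimap and alpha as the last two panels. A small synthetic check:

```python
import numpy as np

# Five H x W panels, each filled with its index: [image | aux | aux | trimap | alpha].
H, W = 2, 4
panels = [np.full((H, W, 3), i, dtype=np.uint8) for i in range(5)]
concat = np.concatenate(panels, axis=1)

out = SplitConcatImage(concat_num=5)(concat)
print(out['image'][0, 0, 0], out['trimap'][0, 0, 0], out['alpha'][0, 0, 0])  # 0 3 4
```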
102
+
103
+
104
+ class RandomHorizontalFlip(object):
105
+
106
+ def __init__(self, prob=0.5):
107
+ self.prob = prob
108
+
109
+ def __call__(self, sample):
110
+ if np.random.uniform(0, 1) < self.prob:
111
+ for key in sample.keys():
112
+ sample[key] = cv2.flip(sample[key], 1)
113
+ return sample
114
+
115
+ class EmptyAug(object):
116
+ def __call__(self, sample):
117
+ return sample
118
+
119
+ class RandomReszieCrop(object):
120
+
121
+ def __init__(self, output_size=1024, aug_scale_min=0.5, aug_scale_max=1.5):
122
+ self.desired_size = output_size
123
+ self.aug_scale_min = aug_scale_min
124
+ self.aug_scale_max = aug_scale_max
125
+
126
+ def __call__(self, sample):
127
+ H, W, _ = sample['image'].shape
128
+ sample['trimap'] = sample['trimap'][:, :, None].repeat(3, axis=-1)
129
+ sample['alpha'] = sample['alpha'][:, :, None].repeat(3, axis=-1)
130
+
131
+ if self.aug_scale_min == 1.0 and self.aug_scale_max == 1.0:
132
+ crop_H, crop_W = H, W
133
+ crop_y1, crop_y2 = 0, crop_H
134
+ crop_x1, crop_x2 = 0, crop_W
135
+ scale_W, scaled_H = W, H
136
+ elif self.aug_scale_min == -1.0 and self.aug_scale_max == -1.0:
137
+ scale = min(self.desired_size / H, self.desired_size / W)
138
+ scaled_H, scale_W = round(H * scale), round(W * scale)
139
+ crop_H, crop_W = scaled_H, scale_W
140
+ crop_y1, crop_y2 = 0, crop_H
141
+ crop_x1, crop_x2 = 0, crop_W
142
+ else:
143
+ # random size
144
+ random_scale = np.random.uniform(0, 1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min # random_val: 0.5 ~ 1.5
145
+ scaled_size = round(random_scale * self.desired_size)
146
+
147
+ scale = min(scaled_size / H, scaled_size / W)
148
+ scaled_H, scale_W = round(H * scale), round(W * scale)
149
+
150
+ # random crop
151
+ crop_H, crop_W = min(self.desired_size, scaled_H), min(self.desired_size, scale_W) # crop_size
152
+ margin_H, margin_W = max(scaled_H - crop_H, 0), max(scale_W - crop_W, 0)
153
+ offset_H, offset_W = np.random.randint(0, margin_H + 1), np.random.randint(0, margin_W + 1)
154
+ crop_y1, crop_y2 = offset_H, offset_H + crop_H
155
+ crop_x1, crop_x2 = offset_W, offset_W + crop_W
156
+
157
+ for key in sample.keys():
158
+ sample[key] = cv2.resize(sample[key], (scale_W, scaled_H), interpolation=random_interp())[crop_y1: crop_y2, crop_x1: crop_x2, :] # resize and crop
159
+ padding = np.zeros(shape=(self.desired_size, self.desired_size, 3), dtype=sample[key].dtype) # pad to desired_size
160
+ padding[: crop_H, : crop_W, :] = sample[key]
161
+ sample[key] = padding
162
+
163
+ return sample
164
+
165
+
166
+ class RandomJitter(object):
167
+ """
168
+ Randomly jitter the hue, saturation and value of the image.
169
+ """
170
+
171
+ def __call__(self, sample):
172
+
173
+ image = sample['image']
174
+
175
+ # convert to HSV space, convert to float32 image to keep precision during space conversion.
176
+ image = cv2.cvtColor(image.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV)
177
+ # Hue noise
178
+ hue_jitter = np.random.randint(-40, 40)
179
+ image[:, :, 0] = np.remainder(image[:, :, 0].astype(np.float32) + hue_jitter, 360)
180
+ # Saturation noise
181
+ sat_bar = image[:, :, 1].mean()
182
+
183
+ sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10
184
+ sat = image[:, :, 1]
185
+ sat = np.abs(sat + sat_jitter)
186
+ sat[sat>1] = 2 - sat[sat>1]
187
+ image[:, :, 1] = sat
188
+ # Value noise
189
+ val_bar = image[:, :, 2].mean()
190
+
191
+ val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10
192
+ val = image[:, :, 2]
193
+ val = np.abs(val + val_jitter)
194
+ val[val>1] = 2 - val[val>1]
195
+ image[:, :, 2] = val
196
+ # convert back to BGR space
197
+ image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
198
+ sample['image'] = image * 255
199
+
200
+ return sample
201
+
202
+
203
+ class ToTensor(object):
204
+
205
+ def __call__(self, sample):
206
+ image, alpha, trimap = sample['image'][:, :, ::-1], sample['alpha'], sample['trimap']
207
+
208
+ # image
209
+ image = image.transpose((2, 0, 1)) / 255.
210
+ sample['image'] = torch.from_numpy(image).float()
211
+
212
+ # alpha
213
+ alpha = alpha.transpose((2, 0, 1))[0: 1] / 255.
214
+ alpha[alpha < 0 ] = 0
215
+ alpha[alpha > 1] = 1
216
+ sample['alpha'] = torch.from_numpy(alpha).float()
217
+
218
+ # trimap
219
+ trimap = trimap.transpose((2, 0, 1))[0: 1] / 1.
220
+ sample['trimap'] = torch.from_numpy(trimap).float()
221
+ sample['trimap'][sample['trimap'] < 85] = 0
222
+ sample['trimap'][sample['trimap'] >= 170] = 1
223
+ sample['trimap'][sample['trimap'] >= 85] = 0.5
224
+
225
+ return sample
226
+
227
+
228
+ class GenTrimap(object):
229
+ def __init__(self):
230
+ self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,100)]
231
+
232
+ def __call__(self, sample):
233
+ alpha = sample['alpha']
234
+ h, w = alpha.shape
235
+
236
+ max_kernel_size = max(30, int((min(h,w) / 2048) * 30))
237
+
238
+ ### generate trimap
239
+ fg_mask = (alpha / 255.0 + 1e-5).astype(np.int32).astype(np.uint8)
240
+ bg_mask = (1 - alpha / 255.0 + 1e-5).astype(np.int32).astype(np.uint8)
241
+ fg_mask = cv2.erode(fg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
242
+ bg_mask = cv2.erode(bg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
243
+
244
+ trimap = np.ones_like(alpha) * 128
245
+ trimap[fg_mask == 1] = 255
246
+ trimap[bg_mask == 1] = 0
247
+
248
+ trimap = cv2.resize(trimap, (w,h), interpolation=cv2.INTER_NEAREST)
249
+ sample['trimap'] = trimap
250
+
251
+ return sample
252
+
253
+
254
+ class P3MData(Dataset):
255
+ def __init__(
256
+ self,
257
+ data_root_path = '/root/data/my_path_b/public_data/data/matting/P3M-10k/train/blurred_image/',
258
+ output_size = 1024,
259
+ aug_scale_min = 0.8,
260
+ aug_scale_max = 1.5,
261
+ with_bbox = True,
262
+ bbox_offset_factor = 0.05,
263
+ num_ratio = 4.06, # 9421 * 4.06 = 38249.26 (38251)
264
+ ):
265
+
266
+ self.data_root_path = data_root_path
267
+ self.output_size = output_size
268
+ self.aug_scale_min = aug_scale_min
269
+ self.aug_scale_max = aug_scale_max
270
+ self.with_bbox = with_bbox
271
+ self.bbox_offset_factor = bbox_offset_factor
272
+ self.num_ratio = num_ratio
273
+
274
+ self.image_names = os.listdir(self.data_root_path)
275
+ self.image_names = [i for i in self.image_names if 'jpg' in i]
276
+ self.image_names.sort()
277
+
278
+ train_trans = [
279
+ RandomHorizontalFlip(prob=0 if hasattr(self, 'return_image_name') and self.return_image_name else 0.5),
280
+ GenTrimap(),
281
+ RandomReszieCrop(self.output_size, self.aug_scale_min, self.aug_scale_max),
282
+ RandomJitter(),
283
+ ToTensor(),
284
+ GenBBox(bbox_offset_factor=self.bbox_offset_factor)
285
+ ]
286
+ self.transform = transforms.Compose(train_trans)
287
+
288
+ def __getitem__(self, idx):
289
+
290
+ if self.num_ratio is not None:
291
+ if self.num_ratio < 1.0:
292
+ idx = np.random.randint(0, len(self.image_names))
293
+ else:
294
+ idx = idx % len(self.image_names)
295
+
296
+ image_path = os.path.join(self.data_root_path, self.image_names[idx])
297
+ alpha_path = image_path.replace('jpg', 'png').replace('blurred_image', 'mask')
298
+
299
+ sample = self.transform({
300
+ 'image': cv2.imread(image_path),
301
+ 'alpha': cv2.imread(alpha_path, 0),
302
+ })
303
+
304
+ sample['dataset_name'] = 'P3M'
305
+ sample['multi_fg'] = False
306
+
307
+ return sample
308
+
309
+ def __len__(self):
310
+ if self.num_ratio is not None:
311
+ return int(len(self.image_names) * self.num_ratio)
312
+ else:
313
+ return len(self.image_names)
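`num_ratio` controls oversampling: `__len__` reports `num_ratio * N` items and `__getitem__` wraps the index back into range (or samples randomly when the ratio is below 1), so one epoch of joint training revisits P3M-10k roughly `num_ratio` times. A usage sketch with a placeholder data root:

```python
dataset = P3MData(
    data_root_path='/path/to/P3M-10k/train/blurred_image/',  # placeholder path
    num_ratio=4.06,
)
print(len(dataset))                  # ≈ 4.06 x number of training jpgs
sample = dataset[len(dataset) - 1]   # index wraps around to an existing image
print(sample['image'].shape, sample['bbox'])
```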
314
+
315
+
316
+ if __name__ == '__main__':
317
+
318
+ dataset = P3MData()
319
+ data = dataset[0]
320
+ print(len(dataset))
321
+ for key, val in data.items():
322
+ if isinstance(val, torch.Tensor):
323
+ print(key, val.shape, torch.min(val), torch.max(val), torch.unique(val))
324
+ else:
325
+ print(key, val)
data/rand_augment.py ADDED
@@ -0,0 +1,196 @@
1
+ # copyright: https://github.com/ildoonet/pytorch-randaugment
2
+ # code in this file is adapted from rpmcruz/autoaugment
3
+ # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
4
+ # This code is modified version of one of ildoonet, for randaugmentation of fixmatch.
5
+
6
+ import random
7
+
8
+ import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from PIL import Image
13
+
14
+
15
+ def AutoContrast(img, _):
16
+ return PIL.ImageOps.autocontrast(img)
17
+
18
+
19
+ def Brightness(img, v):
20
+ assert v >= 0.0
21
+ return PIL.ImageEnhance.Brightness(img).enhance(v)
22
+
23
+
24
+ def Color(img, v):
25
+ assert v >= 0.0
26
+ return PIL.ImageEnhance.Color(img).enhance(v)
27
+
28
+
29
+ def Contrast(img, v):
30
+ assert v >= 0.0
31
+ return PIL.ImageEnhance.Contrast(img).enhance(v)
32
+
33
+
34
+ def Equalize(img, _):
35
+ return PIL.ImageOps.equalize(img)
36
+
37
+
38
+ def Invert(img, _):
39
+ return PIL.ImageOps.invert(img)
40
+
41
+
42
+ def Identity(img, v):
43
+ return img
44
+
45
+
46
+ def Posterize(img, v): # [4, 8]
47
+ v = int(v)
48
+ v = max(1, v)
49
+ return PIL.ImageOps.posterize(img, v)
50
+
51
+
52
+ def Rotate(img, v): # [-30, 30]
53
+ #assert -30 <= v <= 30
54
+ #if random.random() > 0.5:
55
+ # v = -v
56
+ return img.rotate(v)
57
+
58
+
59
+
60
+ def Sharpness(img, v): # [0.1,1.9]
61
+ assert v >= 0.0
62
+ return PIL.ImageEnhance.Sharpness(img).enhance(v)
63
+
64
+
65
+ def ShearX(img, v): # [-0.3, 0.3]
66
+ #assert -0.3 <= v <= 0.3
67
+ #if random.random() > 0.5:
68
+ # v = -v
69
+ return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
70
+
71
+
72
+ def ShearY(img, v): # [-0.3, 0.3]
73
+ #assert -0.3 <= v <= 0.3
74
+ #if random.random() > 0.5:
75
+ # v = -v
76
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
77
+
78
+
79
+ def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
80
+ #assert -0.3 <= v <= 0.3
81
+ #if random.random() > 0.5:
82
+ # v = -v
83
+ v = v * img.size[0]
84
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
85
+
86
+
87
+ def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
88
+ #assert v >= 0.0
89
+ #if random.random() > 0.5:
90
+ # v = -v
91
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
92
+
93
+
94
+ def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
95
+ #assert -0.3 <= v <= 0.3
96
+ #if random.random() > 0.5:
97
+ # v = -v
98
+ v = v * img.size[1]
99
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
100
+
101
+
102
+ def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
103
+ #assert 0 <= v
104
+ #if random.random() > 0.5:
105
+ # v = -v
106
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
107
+
108
+
109
+ def Solarize(img, v): # [0, 256]
110
+ assert 0 <= v <= 256
111
+ return PIL.ImageOps.solarize(img, v)
112
+
113
+
114
+ def Cutout(img, v): #[0, 60] => percentage: [0, 0.2] => change to [0, 0.5]
115
+ assert 0.0 <= v <= 0.5
116
+ if v <= 0.:
117
+ return img
118
+
119
+ v = v * img.size[0]
120
+ return CutoutAbs(img, v)
121
+
122
+
123
+ def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
124
+ # assert 0 <= v <= 20
125
+ if v < 0:
126
+ return img
127
+ w, h = img.size
128
+ x0 = np.random.uniform(w)
129
+ y0 = np.random.uniform(h)
130
+
131
+ x0 = int(max(0, x0 - v / 2.))
132
+ y0 = int(max(0, y0 - v / 2.))
133
+ x1 = min(w, x0 + v)
134
+ y1 = min(h, y0 + v)
135
+
136
+ xy = (x0, y0, x1, y1)
137
+ color = (125, 123, 114)
138
+ # color = (0, 0, 0)
139
+ img = img.copy()
140
+ PIL.ImageDraw.Draw(img).rectangle(xy, color)
141
+ return img
142
+
143
+
144
+ def augment_list():
145
+ l = [
146
+ (AutoContrast, 0, 1),
147
+ (Brightness, 0.05, 0.95),
148
+ (Color, 0.05, 0.95),
149
+ (Contrast, 0.05, 0.95),
150
+ (Equalize, 0, 1),
151
+ (Identity, 0, 1),
152
+ (Posterize, 4, 8),
153
+ # (Rotate, -30, 30),
154
+ (Sharpness, 0.05, 0.95),
155
+ # (ShearX, -0.3, 0.3),
156
+ # (ShearY, -0.3, 0.3),
157
+ (Solarize, 0, 256),
158
+ # (TranslateX, -0.3, 0.3),
159
+ # (TranslateY, -0.3, 0.3)
160
+ ]
161
+ return l
162
+
163
+
164
+ class RandAugment:
165
+ def __init__(self, n, m):
166
+ self.n = n
167
+ self.m = m # [0, 30] in fixmatch, deprecated.
168
+ self.augment_list = augment_list()
169
+
170
+
171
+ def __call__(self, img, cutout=True):
172
+ ops = random.choices(self.augment_list, k=self.n)
173
+ for op, min_val, max_val in ops:
174
+ val = min_val + float(max_val - min_val)*random.random()
175
+ img = op(img, val)
176
+ if cutout:
177
+ cutout_val = random.random() * 0.5
178
+ img = Cutout(img, cutout_val) #for fixmatch
179
+ return img
180
+
181
+
182
+ if __name__ == '__main__':
183
+ # randaug = RandAugment(3,5)
184
+ # print(randaug)
185
+ # for item in randaug.augment_list:
186
+ # print(item)
187
+ import os
188
+
189
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
190
+ img = PIL.Image.open('./u.jpg')
191
+ randaug = RandAugment(3,6)
192
+ img = randaug(img)
193
+ import matplotlib
194
+ from matplotlib import pyplot as plt
195
+ plt.imshow(img)
196
+ plt.show()
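A minimal usage sketch for the RandAugment policy above; the image path and the (n, m) values are illustrative, not taken from the repo:

  from PIL import Image
  randaug = RandAugment(n=2, m=10)            # m is kept for API compatibility; magnitudes are re-sampled per op
  img = Image.open('example.jpg').convert('RGB')
  weak = randaug(img, cutout=False)           # photometric ops only
  strong = randaug(img)                       # ops followed by a random Cutout, FixMatch-style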
data/refmatte_dataset.py ADDED
@@ -0,0 +1,418 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+ import random
8
+ import imgaug.augmenters as iaa
9
+ import numbers
10
+ import math
11
+
12
+
13
+ def random_interp():
14
+ return np.random.choice([cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4])
15
+
16
+ class RandomAffine(object):
17
+ """
18
+ Random affine translation
19
+ """
20
+ def __init__(self, degrees, translate=None, scale=None, shear=None, flip=None, resample=False, fillcolor=0):
21
+ if isinstance(degrees, numbers.Number):
22
+ if degrees < 0:
23
+ raise ValueError("If degrees is a single number, it must be positive.")
24
+ self.degrees = (-degrees, degrees)
25
+ else:
26
+ assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
27
+ "degrees should be a list or tuple and it must be of length 2."
28
+ self.degrees = degrees
29
+
30
+ if translate is not None:
31
+ assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
32
+ "translate should be a list or tuple and it must be of length 2."
33
+ for t in translate:
34
+ if not (0.0 <= t <= 1.0):
35
+ raise ValueError("translation values should be between 0 and 1")
36
+ self.translate = translate
37
+
38
+ if scale is not None:
39
+ assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
40
+ "scale should be a list or tuple and it must be of length 2."
41
+ for s in scale:
42
+ if s <= 0:
43
+ raise ValueError("scale values should be positive")
44
+ self.scale = scale
45
+
46
+ if shear is not None:
47
+ if isinstance(shear, numbers.Number):
48
+ if shear < 0:
49
+ raise ValueError("If shear is a single number, it must be positive.")
50
+ self.shear = (-shear, shear)
51
+ else:
52
+ assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
53
+ "shear should be a list or tuple and it must be of length 2."
54
+ self.shear = shear
55
+ else:
56
+ self.shear = shear
57
+
58
+ self.resample = resample
59
+ self.fillcolor = fillcolor
60
+ self.flip = flip
61
+
62
+ @staticmethod
63
+ def get_params(degrees, translate, scale_ranges, shears, flip, img_size):
64
+ """Get parameters for affine transformation
65
+
66
+ Returns:
67
+ sequence: params to be passed to the affine transformation
68
+ """
69
+ angle = random.uniform(degrees[0], degrees[1])
70
+ if translate is not None:
71
+ max_dx = translate[0] * img_size[0]
72
+ max_dy = translate[1] * img_size[1]
73
+ translations = (np.round(random.uniform(-max_dx, max_dx)),
74
+ np.round(random.uniform(-max_dy, max_dy)))
75
+ else:
76
+ translations = (0, 0)
77
+
78
+ if scale_ranges is not None:
79
+ scale = (random.uniform(scale_ranges[0], scale_ranges[1]),
80
+ random.uniform(scale_ranges[0], scale_ranges[1]))
81
+ else:
82
+ scale = (1.0, 1.0)
83
+
84
+ if shears is not None:
85
+ shear = random.uniform(shears[0], shears[1])
86
+ else:
87
+ shear = 0.0
88
+
89
+ if flip is not None:
90
+ flip = (np.random.rand(2) < flip).astype(np.int32) * 2 - 1
91
+
92
+ return angle, translations, scale, shear, flip
93
+
94
+ def __call__(self, sample):
95
+ fg, alpha = sample['fg'], sample['alpha']
96
+ rows, cols, ch = fg.shape
97
+ if np.maximum(rows, cols) < 1024:
98
+ params = self.get_params((0, 0), self.translate, self.scale, self.shear, self.flip, fg.size)
99
+ else:
100
+ params = self.get_params(self.degrees, self.translate, self.scale, self.shear, self.flip, fg.size)
101
+
102
+ center = (cols * 0.5 + 0.5, rows * 0.5 + 0.5)
103
+ M = self._get_inverse_affine_matrix(center, *params)
104
+ M = np.array(M).reshape((2, 3))
105
+
106
+ fg = cv2.warpAffine(fg, M, (cols, rows), flags=random_interp() + cv2.WARP_INVERSE_MAP)
107
+ alpha = cv2.warpAffine(alpha, M, (cols, rows), flags=random_interp() + cv2.WARP_INVERSE_MAP)
108
+
109
+ sample['fg'], sample['alpha'] = fg, alpha
110
+
111
+ return sample
112
+
113
+ @staticmethod
114
+ def _get_inverse_affine_matrix(center, angle, translate, scale, shear, flip):
115
+ # Helper method to compute inverse matrix for affine transformation
116
+
117
+ # As it is explained in PIL.Image.rotate
118
+ # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
119
+ # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
120
+ # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
121
+ # RSS is rotation with scale and shear matrix
122
+ # It is different from the original function in torchvision
123
+ # The order are changed to flip -> scale -> rotation -> shear
124
+ # x and y have different scale factors
125
+ # RSS(shear, a, scale, f) = [ cos(a + shear)*scale_x*f -sin(a + shear)*scale_y 0]
126
+ # [ sin(a)*scale_x*f cos(a)*scale_y 0]
127
+ # [ 0 0 1]
128
+ # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
129
+
130
+ angle = math.radians(angle)
131
+ shear = math.radians(shear)
132
+ scale_x = 1.0 / scale[0] * flip[0]
133
+ scale_y = 1.0 / scale[1] * flip[1]
134
+
135
+ # Inverted rotation matrix with scale and shear
136
+ d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
137
+ matrix = [
138
+ math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0,
139
+ -math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0
140
+ ]
141
+ matrix = [m / d for m in matrix]
142
+
143
+ # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
144
+ matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
145
+ matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])
146
+
147
+ # Apply center translation: C * RSS^-1 * C^-1 * T^-1
148
+ matrix[2] += center[0]
149
+ matrix[5] += center[1]
150
+
151
+ return matrix
152
+
153
+
154
+ class GenTrimap(object):
155
+ def __init__(self):
156
+ self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,100)]
157
+
158
+ def __call__(self, sample):
159
+ alpha = sample['alpha']
160
+ h, w = alpha.shape
161
+
162
+ max_kernel_size = max(30, int((min(h,w) / 2048) * 30))
163
+
164
+ ### generate trimap
165
+ fg_mask = (alpha + 1e-5).astype(np.int32).astype(np.uint8)
166
+ bg_mask = (1 - alpha + 1e-5).astype(np.int32).astype(np.uint8)
167
+ fg_mask = cv2.erode(fg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
168
+ bg_mask = cv2.erode(bg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
169
+
170
+ trimap = np.ones_like(alpha) * 128
171
+ trimap[fg_mask == 1] = 255
172
+ trimap[bg_mask == 1] = 0
173
+
174
+ trimap = cv2.resize(trimap, (w,h), interpolation=cv2.INTER_NEAREST)
175
+ sample['trimap'] = trimap
176
+
177
+ return sample
178
+
179
+
180
+ class RandomCrop(object):
181
+ """
182
+ Crop randomly the image in a sample, retain the center 1/4 images, and resize to 'output_size'
183
+
184
+ :param output_size (tuple or int): Desired output size. If int, square crop
185
+ is made.
186
+ """
187
+
188
+ def __init__(self, output_size=(1024, 1024)):
189
+ assert isinstance(output_size, (int, tuple))
190
+ if isinstance(output_size, int):
191
+ self.output_size = (output_size, output_size)
192
+ else:
193
+ assert len(output_size) == 2
194
+ self.output_size = output_size
195
+ self.margin = output_size[0] // 2
196
+
197
+ def __call__(self, sample):
198
+ fg, alpha, trimap, name = sample['fg'], sample['alpha'], sample['trimap'], sample['image_name']
199
+ bg = sample['bg']
200
+ h, w = trimap.shape
201
+ bg = cv2.resize(bg, (w, h), interpolation=random_interp())
202
+ if w < self.output_size[0]+1 or h < self.output_size[1]+1:
203
+ ratio = 1.1*self.output_size[0]/h if h < w else 1.1*self.output_size[1]/w
204
+ # self.logger.warning("Size of {} is {}.".format(name, (h, w)))
205
+ while h < self.output_size[0]+1 or w < self.output_size[1]+1:
206
+ fg = cv2.resize(fg, (int(w*ratio), int(h*ratio)), interpolation=random_interp())
207
+ alpha = cv2.resize(alpha, (int(w*ratio), int(h*ratio)),
208
+ interpolation=random_interp())
209
+ trimap = cv2.resize(trimap, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST)
210
+ bg = cv2.resize(bg, (int(w*ratio), int(h*ratio)), interpolation=random_interp())
211
+ h, w = trimap.shape
212
+ small_trimap = cv2.resize(trimap, (w//4, h//4), interpolation=cv2.INTER_NEAREST)
213
+ unknown_list = list(zip(*np.where(small_trimap[self.margin//4:(h-self.margin)//4,
214
+ self.margin//4:(w-self.margin)//4] == 128)))
215
+ unknown_num = len(unknown_list)
216
+ if len(unknown_list) < 10:
217
+ left_top = (np.random.randint(0, h-self.output_size[0]+1), np.random.randint(0, w-self.output_size[1]+1))
218
+ else:
219
+ idx = np.random.randint(unknown_num)
220
+ left_top = (unknown_list[idx][0]*4, unknown_list[idx][1]*4)
221
+
222
+ fg_crop = fg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
223
+ alpha_crop = alpha[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
224
+ bg_crop = bg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
225
+ trimap_crop = trimap[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
226
+
227
+ if len(np.where(trimap==128)[0]) == 0:
228
+ fg_crop = cv2.resize(fg, self.output_size[::-1], interpolation=random_interp())
229
+ alpha_crop = cv2.resize(alpha, self.output_size[::-1], interpolation=random_interp())
230
+ trimap_crop = cv2.resize(trimap, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
231
+ bg_crop = cv2.resize(bg, self.output_size[::-1], interpolation=random_interp())
232
+
233
+ sample.update({'fg': fg_crop, 'alpha': alpha_crop, 'trimap': trimap_crop, 'bg': bg_crop})
234
+ return sample
235
+
236
+
237
+ class Composite_Seg(object):
238
+ def __call__(self, sample):
239
+ fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha']
240
+ fg[fg < 0 ] = 0
241
+ fg[fg > 255] = 255
242
+ image = fg
243
+ sample['image'] = image
244
+ return sample
245
+
246
+
247
+ class ToTensor(object):
248
+ """
249
+ Convert ndarrays in sample to Tensors with normalization.
250
+ """
251
+ def __init__(self, phase="test", real_world_aug = False):
252
+ # self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
253
+ # self.std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
254
+ self.mean = torch.tensor([0.0, 0.0, 0.0]).view(3,1,1)
255
+ self.std = torch.tensor([1.0, 1.0, 1.0]).view(3,1,1)
256
+ self.phase = phase
257
+ if real_world_aug:
258
+ self.RWA = iaa.SomeOf((1, None), [
259
+ iaa.LinearContrast((0.6, 1.4)),
260
+ iaa.JpegCompression(compression=(0, 60)),
261
+ iaa.GaussianBlur(sigma=(0.0, 3.0)),
262
+ iaa.AdditiveGaussianNoise(scale=(0, 0.1*255))
263
+ ], random_order=True)
264
+ else:
265
+ self.RWA = None
266
+
267
+ def get_box_from_alpha(self, alpha_final):
268
+ bi_mask = np.zeros_like(alpha_final)
269
+ bi_mask[alpha_final>0.5] = 1
270
+ #bi_mask[alpha_final<=0.5] = 0
271
+ fg_set = np.where(bi_mask != 0)
272
+ if len(fg_set[1]) == 0 or len(fg_set[0]) == 0:
273
+ x_min = random.randint(1, 511)
274
+ x_max = random.randint(1, 511) + x_min
275
+ y_min = random.randint(1, 511)
276
+ y_max = random.randint(1, 511) + y_min
277
+ else:
278
+ x_min = np.min(fg_set[1])
279
+ x_max = np.max(fg_set[1])
280
+ y_min = np.min(fg_set[0])
281
+ y_max = np.max(fg_set[0])
282
+ bbox = np.array([x_min, y_min, x_max, y_max])
283
+ #cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0,255,0), 2)
284
+ #cv2.imwrite('../outputs/test.jpg', image)
285
+ #cv2.imwrite('../outputs/test_gt.jpg', alpha_single)
286
+ return bbox
287
+
288
+ def __call__(self, sample):
289
+ # convert GBR images to RGB
290
+ image, alpha, trimap = sample['image'][:,:,::-1], sample['alpha'], sample['trimap']
291
+
292
+ alpha[alpha < 0 ] = 0
293
+ alpha[alpha > 1] = 1
294
+
295
+ bbox = self.get_box_from_alpha(alpha)
296
+
297
+ if self.phase == 'train' and self.RWA is not None and np.random.rand() < 0.5:
298
+ image[image > 255] = 255
299
+ image[image < 0] = 0
300
+ image = np.round(image).astype(np.uint8)
301
+ image = np.expand_dims(image, axis=0)
302
+ image = self.RWA(images=image)
303
+ image = image[0, ...]
304
+
305
+ # swap color axis because
306
+ # numpy image: H x W x C
307
+ # torch image: C X H X W
308
+ image = image.transpose((2, 0, 1)).astype(np.float32)
309
+ alpha = np.expand_dims(alpha.astype(np.float32), axis=0)
310
+ trimap[trimap < 85] = 0
311
+ trimap[trimap >= 170] = 2
312
+ trimap[trimap >= 85] = 1
313
+ #image = cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255,0,0), 3)
314
+ #cv2.imwrite(os.path.join('outputs', 'img_bbox.png'), image.astype('uint8'))
315
+ # normalize image
316
+ image /= 255.
317
+
318
+ if self.phase == "train":
319
+ # convert GBR images to RGB
320
+ fg = sample['fg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
321
+ sample['fg'] = torch.from_numpy(fg).sub_(self.mean).div_(self.std)
322
+ bg = sample['bg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
323
+ sample['bg'] = torch.from_numpy(bg).sub_(self.mean).div_(self.std)
324
+ del sample['image_name']
325
+
326
+ sample['boxes'] = torch.from_numpy(bbox).to(torch.float)[None,...]
327
+
328
+ sample['image'], sample['alpha'], sample['trimap'] = \
329
+ torch.from_numpy(image), torch.from_numpy(alpha), torch.from_numpy(trimap).to(torch.long)
330
+ sample['image'] = sample['image'].sub_(self.mean).div_(self.std)
331
+ sample['trimap'] = sample['trimap'][None,...].float()
332
+
333
+ return sample
334
+
335
+
336
+ class RefMatteData(Dataset):
337
+ def __init__(
338
+ self,
339
+ data_root_path,
340
+ num_ratio = 0.34,
341
+ ):
342
+ self.data_root_path = data_root_path
343
+ self.num_ratio = num_ratio
344
+
345
+ self.rim_img = [os.path.join(data_root_path, name) for name in sorted(os.listdir(data_root_path))]
346
+ self.rim_pha = [os.path.join(data_root_path.replace('img', 'mask'), name) for name in sorted(os.listdir(data_root_path.replace('img', 'mask')))]
347
+ self.rim_num = len(self.rim_pha)
348
+
349
+ self.transform_spd = transforms.Compose([
350
+ RandomAffine(degrees=30, scale=[0.8, 1.5], shear=10, flip=0.5),
351
+ GenTrimap(),
352
+ RandomCrop((1024, 1024)),
353
+ Composite_Seg(),
354
+ ToTensor(phase="train", real_world_aug=False)
355
+ ])
356
+
357
+ def __getitem__(self, idx):
358
+ if self.num_ratio is not None:
359
+ if self.num_ratio < 1.0 or idx >= self.rim_num:
360
+ idx = np.random.randint(0, self.rim_num)
361
+ alpha = cv2.imread(self.rim_pha[idx % self.rim_num], 0).astype(np.float32)/255
362
+ alpha_img_name = self.rim_pha[idx % self.rim_num].split('/')[-1]
363
+ fg_img_name = alpha_img_name[:-6] + '.jpg'
364
+
365
+ fg = cv2.imread(os.path.join(self.data_root_path, fg_img_name))
366
+
367
+ if np.random.rand() < 0.25:
368
+ fg = cv2.resize(fg, (1280, 1280), interpolation=random_interp())
369
+ alpha = cv2.resize(alpha, (1280, 1280), interpolation=random_interp())
370
+
371
+ image_name = alpha_img_name # os.path.split(self.rim_img[idx % self.rim_num])[-1]
372
+ sample = {'fg': fg, 'alpha': alpha, 'bg': fg, 'image_name': image_name}
373
+ sample = self.transform_spd(sample)
374
+
375
+ converted_sample = {
376
+ 'image': sample['image'],
377
+ 'trimap': sample['trimap'] / 2.0,
378
+ 'alpha': sample['alpha'],
379
+ 'bbox': sample['boxes'],
380
+ 'dataset_name': 'RefMatte',
381
+ 'multi_fg': False,
382
+ }
383
+ return converted_sample
384
+
385
+ def __len__(self):
386
+ if self.num_ratio is not None:
387
+ return int(self.rim_num * self.num_ratio) # 112506 * 0.34 = 38252 (COCONut_num-38251 + 1)
388
+ else:
389
+ return self.rim_num # 112506
390
+
391
+
392
+
393
+ if __name__ == '__main__':
394
+ dataset = RefMatteData(
395
+ data_root_path = '/data/my_path_b/public_data/data/matting/RefMatte/RefMatte/train/img',
396
+ num_ratio=0.34,
397
+ )
398
+ data = dataset[0]
399
+ '''
400
+ fg torch.Size([3, 1024, 1024]) tensor(-2.1179) tensor(2.6400)
401
+ alpha torch.Size([1, 1024, 1024]) tensor(0.) tensor(1.)
402
+ bg torch.Size([3, 1024, 1024]) tensor(-2.1179) tensor(2.6400)
403
+ trimap torch.Size([1, 1024, 1024]) 0.0 or 1.0 or 2.0
404
+ image torch.Size([3, 1024, 1024]) tensor(-2.1179) tensor(2.6400)
405
+ boxes torch.Size([1, 4]) tensor(72.) tensor(676.) 0.0~1024.0
406
+
407
+ COCONut:
408
+ image torch.Size([3, 1024, 1024]) tensor(0.0006) tensor(0.9991)
409
+ trimap torch.Size([1, 1024, 1024]) 0.0 or 0.5 or 1.0
410
+ alpha torch.Size([1, 1024, 1024]) tensor(0.) tensor(1.)
411
+ bbox torch.Size([1, 4]) tensor(0.) tensor(590.)
412
+ dataset_name: 'COCONut'
413
+ '''
414
+ for key, val in data.items():
415
+ if isinstance(val, torch.Tensor):
416
+ print(key, val.shape, torch.min(val), torch.max(val))
417
+ else:
418
+ print(key, val.shape)
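A hedged sketch of batching this dataset for training; the data path, batch size, and worker count are illustrative rather than taken from the SEMat configs:

  from torch.utils.data import DataLoader
  dataset = RefMatteData(data_root_path='/path/to/RefMatte/train/img', num_ratio=0.34)
  loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
  batch = next(iter(loader))
  # batch['image']: [B, 3, 1024, 1024], batch['alpha']: [B, 1, 1024, 1024],
  # batch['trimap']: [B, 1, 1024, 1024] with values in {0, 0.5, 1}, batch['bbox']: [B, 1, 4]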
engine/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .mattingtrainer import MattingTrainer
engine/hooks.py ADDED
@@ -0,0 +1,52 @@
1
+ import inspect
2
+ import detectron2.utils.comm as comm
3
+ from detectron2.engine import EvalHook as _EvalHook
4
+ from detectron2.evaluation.testing import flatten_results_dict
5
+
6
+
7
+ class EvalHook(_EvalHook):
8
+ def __init__(self, eval_period, eval_function):
9
+ super().__init__(eval_period, eval_function)
10
+ func_args = inspect.getfullargspec(eval_function).args
11
+ assert {"final_iter", "next_iter"}.issubset(set(func_args)), (
12
+ f"Eval function must have either 'final_iter' or 'next_iter' as an argument."
13
+ f"Got {func_args} instead."
14
+ )
15
+
16
+ def _do_eval(self, final_iter=False, next_iter=0):
17
+ results = self._func(final_iter=final_iter, next_iter=next_iter)
18
+
19
+ if results:
20
+ assert isinstance(
21
+ results, dict
22
+ ), "Eval function must return a dict. Got {} instead.".format(results)
23
+
24
+ flattened_results = flatten_results_dict(results)
25
+ for k, v in flattened_results.items():
26
+ try:
27
+ v = float(v)
28
+ except Exception as e:
29
+ raise ValueError(
30
+ "[EvalHook] eval_function should return a nested dict of float. "
31
+ "Got '{}: {}' instead.".format(k, v)
32
+ ) from e
33
+ self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
34
+
35
+ # Evaluation may take different time among workers.
36
+ # A barrier make them start the next iteration together.
37
+ comm.synchronize()
38
+
39
+ def after_step(self):
40
+ next_iter = self.trainer.iter + 1
41
+ if self._period > 0 and next_iter % self._period == 0:
42
+ # do the last eval in after_train
43
+ if next_iter != self.trainer.max_iter:
44
+ self._do_eval(next_iter=next_iter)
45
+
46
+ def after_train(self):
47
+ # This condition is to prevent the eval from running after a failed training
48
+ if self.trainer.iter + 1 >= self.trainer.max_iter:
49
+ self._do_eval(final_iter=True)
50
+ # func is likely a closure that holds reference to the trainer
51
+ # therefore we clean it to avoid circular reference in the end
52
+ del self._func
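A hedged sketch of an eval_function that satisfies the signature check above; the metric names and values are purely illustrative:

  def eval_wrapper(final_iter=False, next_iter=0):
      # EvalHook asserts that both keyword arguments exist; the function returns a
      # (possibly nested) dict of floats, which is flattened and written to storage.
      return {'matting': {'MSE': 0.005, 'SAD': 30.2}}

  hook = EvalHook(eval_period=5000, eval_function=eval_wrapper)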
engine/mattingtrainer.py ADDED
@@ -0,0 +1,171 @@
1
+ from detectron2.engine import AMPTrainer
2
+ import torch
3
+ import time
4
+ import logging
5
+
6
+ logger = logging.getLogger("detectron2")
7
+
8
+ import typing
9
+ from collections import defaultdict
10
+ import tabulate
11
+ from torch import nn
12
+
13
+
14
+ def parameter_count(model: nn.Module, trainable_only: bool = False) -> typing.DefaultDict[str, int]:
15
+ """
16
+ Count parameters of a model and its submodules.
17
+
18
+ Args:
19
+ model: a torch module
20
+
21
+ Returns:
22
+ dict (str-> int): the key is either a parameter name or a module name.
23
+ The value is the number of elements in the parameter, or in all
24
+ parameters of the module. The key "" corresponds to the total
25
+ number of parameters of the model.
26
+ """
27
+ r = defaultdict(int)
28
+ for name, prm in model.named_parameters():
29
+ if trainable_only:
30
+ if not prm.requires_grad:
31
+ continue
32
+ size = prm.numel()
33
+ name = name.split(".")
34
+ for k in range(0, len(name) + 1):
35
+ prefix = ".".join(name[:k])
36
+ r[prefix] += size
37
+ return r
38
+
39
+
40
+ def parameter_count_table(
41
+ model: nn.Module, max_depth: int = 3, trainable_only: bool = False
42
+ ) -> str:
43
+ """
44
+ Format the parameter count of the model (and its submodules or parameters)
45
+ in a nice table. It looks like this:
46
+
47
+ ::
48
+
49
+ | name | #elements or shape |
50
+ |:--------------------------------|:---------------------|
51
+ | model | 37.9M |
52
+ | backbone | 31.5M |
53
+ | backbone.fpn_lateral3 | 0.1M |
54
+ | backbone.fpn_lateral3.weight | (256, 512, 1, 1) |
55
+ | backbone.fpn_lateral3.bias | (256,) |
56
+ | backbone.fpn_output3 | 0.6M |
57
+ | backbone.fpn_output3.weight | (256, 256, 3, 3) |
58
+ | backbone.fpn_output3.bias | (256,) |
59
+ | backbone.fpn_lateral4 | 0.3M |
60
+ | backbone.fpn_lateral4.weight | (256, 1024, 1, 1) |
61
+ | backbone.fpn_lateral4.bias | (256,) |
62
+ | backbone.fpn_output4 | 0.6M |
63
+ | backbone.fpn_output4.weight | (256, 256, 3, 3) |
64
+ | backbone.fpn_output4.bias | (256,) |
65
+ | backbone.fpn_lateral5 | 0.5M |
66
+ | backbone.fpn_lateral5.weight | (256, 2048, 1, 1) |
67
+ | backbone.fpn_lateral5.bias | (256,) |
68
+ | backbone.fpn_output5 | 0.6M |
69
+ | backbone.fpn_output5.weight | (256, 256, 3, 3) |
70
+ | backbone.fpn_output5.bias | (256,) |
71
+ | backbone.top_block | 5.3M |
72
+ | backbone.top_block.p6 | 4.7M |
73
+ | backbone.top_block.p7 | 0.6M |
74
+ | backbone.bottom_up | 23.5M |
75
+ | backbone.bottom_up.stem | 9.4K |
76
+ | backbone.bottom_up.res2 | 0.2M |
77
+ | backbone.bottom_up.res3 | 1.2M |
78
+ | backbone.bottom_up.res4 | 7.1M |
79
+ | backbone.bottom_up.res5 | 14.9M |
80
+ | ...... | ..... |
81
+
82
+ Args:
83
+ model: a torch module
84
+ max_depth (int): maximum depth to recursively print submodules or
85
+ parameters
86
+
87
+ Returns:
88
+ str: the table to be printed
89
+ """
90
+ count: typing.DefaultDict[str, int] = parameter_count(model, trainable_only)
91
+ # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
92
+ param_shape: typing.Dict[str, typing.Tuple] = {
93
+ k: tuple(v.shape) for k, v in model.named_parameters()
94
+ }
95
+
96
+ # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
97
+ table: typing.List[typing.Tuple] = []
98
+
99
+ def format_size(x: int) -> str:
100
+ if x > 1e8:
101
+ return "{:.1f}G".format(x / 1e9)
102
+ if x > 1e5:
103
+ return "{:.1f}M".format(x / 1e6)
104
+ if x > 1e2:
105
+ return "{:.1f}K".format(x / 1e3)
106
+ return str(x)
107
+
108
+ def fill(lvl: int, prefix: str) -> None:
109
+ if lvl >= max_depth:
110
+ return
111
+ for name, v in count.items():
112
+ if name.count(".") == lvl and name.startswith(prefix):
113
+ indent = " " * (lvl + 1)
114
+ if name in param_shape:
115
+ table.append((indent + name, indent + str(param_shape[name])))
116
+ else:
117
+ table.append((indent + name, indent + format_size(v)))
118
+ fill(lvl + 1, name + ".")
119
+
120
+ table.append(("model", format_size(count.pop(""))))
121
+ fill(0, "")
122
+
123
+ old_ws = tabulate.PRESERVE_WHITESPACE
124
+ tabulate.PRESERVE_WHITESPACE = True
125
+ tab = tabulate.tabulate(table, headers=["name", "#elements or shape"], tablefmt="pipe")
126
+ tabulate.PRESERVE_WHITESPACE = old_ws
127
+ return tab
128
+
129
+
130
+ def cycle(iterable):
131
+ while True:
132
+ for x in iterable:
133
+ yield x
134
+
135
+ class MattingTrainer(AMPTrainer):
136
+ def __init__(self, model, data_loader, optimizer, grad_scaler=None):
137
+ super().__init__(model, data_loader, optimizer, grad_scaler=grad_scaler)
138
+ self.data_loader_iter = iter(cycle(self.data_loader))
139
+
140
+ # print model parameters
141
+ logger.info("All parameters: \n" + parameter_count_table(model))
142
+ logger.info("Trainable parameters: \n" + parameter_count_table(model, trainable_only=True, max_depth=8))
143
+
144
+ def run_step(self):
145
+ """
146
+ Implement the AMP training logic.
147
+ """
148
+ assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
149
+ assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
150
+ from torch.cuda.amp import autocast
151
+
152
+ #matting pass
153
+ start = time.perf_counter()
154
+ data = next(self.data_loader_iter)
155
+ data_time = time.perf_counter() - start
156
+
157
+ with autocast():
158
+ loss_dict = self.model(data)
159
+ if isinstance(loss_dict, torch.Tensor):
160
+ losses = loss_dict
161
+ loss_dict = {"total_loss": loss_dict}
162
+ else:
163
+ losses = sum(loss_dict.values())
164
+
165
+ self.optimizer.zero_grad()
166
+ self.grad_scaler.scale(losses).backward()
167
+
168
+ self._write_metrics(loss_dict, data_time)
169
+
170
+ self.grad_scaler.step(self.optimizer)
171
+ self.grad_scaler.update()
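A small sketch of the parameter-count helpers above on a toy module (the module itself is arbitrary):

  import torch.nn as nn
  toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 2))
  total = parameter_count(toy)['']                  # total parameter count is stored under the '' key
  print(parameter_count_table(toy, max_depth=2))    # pipe-style table like the one in the docstring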
modeling/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .backbone import *
2
+ from .criterion import *
3
+ from .decoder import *
4
+ from .meta_arch import *
5
+ from .semantic_enhanced_matting import *
modeling/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (272 Bytes).
 
modeling/backbone/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .backbone import *
2
+ from .vit import *
modeling/backbone/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (198 Bytes).
 
modeling/backbone/__pycache__/backbone.cpython-38.pyc ADDED
Binary file (3.23 kB).
 
modeling/backbone/__pycache__/utils.cpython-38.pyc ADDED
Binary file (6.11 kB).
 
modeling/backbone/__pycache__/vit.cpython-38.pyc ADDED
Binary file (12.3 kB).
 
modeling/backbone/backbone.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from abc import ABCMeta, abstractmethod
3
+ from typing import Dict
4
+ import torch.nn as nn
5
+
6
+ from detectron2.layers import ShapeSpec
7
+
8
+ __all__ = ["Backbone"]
9
+
10
+
11
+ class Backbone(nn.Module, metaclass=ABCMeta):
12
+ """
13
+ Abstract base class for network backbones.
14
+ """
15
+
16
+ def __init__(self):
17
+ """
18
+ The `__init__` method of any subclass can specify its own set of arguments.
19
+ """
20
+ super().__init__()
21
+
22
+ @abstractmethod
23
+ def forward(self):
24
+ """
25
+ Subclasses must override this method, but adhere to the same return type.
26
+
27
+ Returns:
28
+ dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
29
+ """
30
+ pass
31
+
32
+ @property
33
+ def size_divisibility(self) -> int:
34
+ """
35
+ Some backbones require the input height and width to be divisible by a
36
+ specific integer. This is typically true for encoder / decoder type networks
37
+ with lateral connection (e.g., FPN) for which feature maps need to match
38
+ dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
39
+ input size divisibility is required.
40
+ """
41
+ return 0
42
+
43
+ @property
44
+ def padding_constraints(self) -> Dict[str, int]:
45
+ """
46
+ This property is a generalization of size_divisibility. Some backbones and training
47
+ recipes require specific padding constraints, such as enforcing divisibility by a specific
48
+ integer (e.g., FPN) or padding to a square (e.g., ViTDet with large-scale jitter
49
+ in :paper:vitdet). `padding_constraints` contains these optional items like:
50
+ {
51
+ "size_divisibility": int,
52
+ "square_size": int,
53
+ # Future options are possible
54
+ }
55
+ `size_divisibility` will read from here if presented and `square_size` indicates the
56
+ square padding size if `square_size` > 0.
57
+
58
+ TODO: use type of Dict[str, int] to avoid torchscipt issues. The type of padding_constraints
59
+ could be generalized as TypedDict (Python 3.8+) to support more types in the future.
60
+ """
61
+ return {}
62
+
63
+ def output_shape(self):
64
+ """
65
+ Returns:
66
+ dict[str->ShapeSpec]
67
+ """
68
+ # this is a backward-compatible default
69
+ return {
70
+ name: ShapeSpec(
71
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
72
+ )
73
+ for name in self._out_features
74
+ }
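A hedged sketch of a minimal concrete Backbone; the layer and feature names are illustrative:

  import torch
  import torch.nn as nn

  class TinyBackbone(Backbone):
      def __init__(self):
          super().__init__()
          self.stem = nn.Conv2d(3, 64, kernel_size=4, stride=4)
          # attributes consumed by the default output_shape() above
          self._out_features = ['p1']
          self._out_feature_channels = {'p1': 64}
          self._out_feature_strides = {'p1': 4}

      def forward(self, x):
          return {'p1': self.stem(x)}

  feats = TinyBackbone()(torch.randn(1, 3, 64, 64))   # {'p1': tensor of shape [1, 64, 16, 16]}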
modeling/backbone/utils.py ADDED
@@ -0,0 +1,186 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ __all__ = [
8
+ "window_partition",
9
+ "window_unpartition",
10
+ "add_decomposed_rel_pos",
11
+ "get_abs_pos",
12
+ "PatchEmbed",
13
+ ]
14
+
15
+
16
+ def window_partition(x, window_size):
17
+ """
18
+ Partition into non-overlapping windows with padding if needed.
19
+ Args:
20
+ x (tensor): input tokens with [B, H, W, C].
21
+ window_size (int): window size.
22
+
23
+ Returns:
24
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
25
+ (Hp, Wp): padded height and width before partition
26
+ """
27
+ B, H, W, C = x.shape
28
+
29
+ pad_h = (window_size - H % window_size) % window_size
30
+ pad_w = (window_size - W % window_size) % window_size
31
+ if pad_h > 0 or pad_w > 0:
32
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
33
+ Hp, Wp = H + pad_h, W + pad_w
34
+
35
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
36
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37
+ return windows, (Hp, Wp)
38
+
39
+
40
+ def window_unpartition(windows, window_size, pad_hw, hw):
41
+ """
42
+ Window unpartition into original sequences and removing padding.
43
+ Args:
44
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
45
+ window_size (int): window size.
46
+ pad_hw (Tuple): padded height and width (Hp, Wp).
47
+ hw (Tuple): original height and width (H, W) before padding.
48
+
49
+ Returns:
50
+ x: unpartitioned sequences with [B, H, W, C].
51
+ """
52
+ Hp, Wp = pad_hw
53
+ H, W = hw
54
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
55
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
56
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
57
+
58
+ if Hp > H or Wp > W:
59
+ x = x[:, :H, :W, :].contiguous()
60
+ return x
61
+
62
+
63
+ def get_rel_pos(q_size, k_size, rel_pos):
64
+ """
65
+ Get relative positional embeddings according to the relative positions of
66
+ query and key sizes.
67
+ Args:
68
+ q_size (int): size of query q.
69
+ k_size (int): size of key k.
70
+ rel_pos (Tensor): relative position embeddings (L, C).
71
+
72
+ Returns:
73
+ Extracted positional embeddings according to relative positions.
74
+ """
75
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
76
+ # Interpolate rel pos if needed.
77
+ if rel_pos.shape[0] != max_rel_dist:
78
+ # Interpolate rel pos.
79
+ rel_pos_resized = F.interpolate(
80
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
81
+ size=max_rel_dist,
82
+ mode="linear",
83
+ )
84
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
85
+ else:
86
+ rel_pos_resized = rel_pos
87
+
88
+ # Scale the coords with short length if shapes for q and k are different.
89
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
90
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
91
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
92
+
93
+ return rel_pos_resized[relative_coords.long()]
94
+
95
+
96
+ def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
97
+ """
98
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
99
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
100
+ Args:
101
+ attn (Tensor): attention map.
102
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
103
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
104
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
105
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
106
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
107
+
108
+ Returns:
109
+ attn (Tensor): attention map with added relative positional embeddings.
110
+ """
111
+ q_h, q_w = q_size
112
+ k_h, k_w = k_size
113
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
114
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
115
+
116
+ B, _, dim = q.shape
117
+ r_q = q.reshape(B, q_h, q_w, dim)
118
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
119
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
120
+
121
+ attn = (
122
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
123
+ ).view(B, q_h * q_w, k_h * k_w)
124
+
125
+ return attn
126
+
127
+
128
+ def get_abs_pos(abs_pos, has_cls_token, hw):
129
+ """
130
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
131
+ dimension for the original embeddings.
132
+ Args:
133
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
134
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
135
+ hw (Tuple): size of input image tokens.
136
+
137
+ Returns:
138
+ Absolute positional embeddings after processing with shape (1, H, W, C)
139
+ """
140
+ h, w = hw
141
+ if has_cls_token:
142
+ abs_pos = abs_pos[:, 1:]
143
+ xy_num = abs_pos.shape[1]
144
+ size = int(math.sqrt(xy_num))
145
+ assert size * size == xy_num
146
+
147
+ if size != h or size != w:
148
+ new_abs_pos = F.interpolate(
149
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
150
+ size=(h, w),
151
+ mode="bicubic",
152
+ align_corners=False,
153
+ )
154
+
155
+ return new_abs_pos.permute(0, 2, 3, 1)
156
+ else:
157
+ return abs_pos.reshape(1, h, w, -1)
158
+
159
+
160
+ class PatchEmbed(nn.Module):
161
+ """
162
+ Image to Patch Embedding.
163
+ """
164
+
165
+ def __init__(
166
+ self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
167
+ ):
168
+ """
169
+ Args:
170
+ kernel_size (Tuple): kernel size of the projection layer.
171
+ stride (Tuple): stride of the projection layer.
172
+ padding (Tuple): padding size of the projection layer.
173
+ in_chans (int): Number of input image channels.
174
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
175
+ """
176
+ super().__init__()
177
+
178
+ self.proj = nn.Conv2d(
179
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
180
+ )
181
+
182
+ def forward(self, x):
183
+ x = self.proj(x)
184
+ # B C H W -> B H W C
185
+ x = x.permute(0, 2, 3, 1)
186
+ return x
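A hedged round-trip check for the two window helpers above; the tensor sizes are illustrative:

  import torch
  x = torch.randn(2, 50, 50, 96)                           # [B, H, W, C], H/W not divisible by the window
  windows, (Hp, Wp) = window_partition(x, window_size=14)  # zero-padded to 56x56 -> [2*16, 14, 14, 96]
  y = window_unpartition(windows, 14, (Hp, Wp), (50, 50))  # padding cropped away again
  assert torch.equal(x, y)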
modeling/backbone/vit.py ADDED
@@ -0,0 +1,404 @@
1
+ import logging
2
+ import math
3
+ import fvcore.nn.weight_init as weight_init
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import functional as F
7
+ from detectron2.layers import CNNBlockBase, Conv2d, get_norm
8
+ from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
9
+ from fairscale.nn.checkpoint import checkpoint_wrapper
10
+ from timm.models.layers import DropPath, Mlp, trunc_normal_
11
+ from .backbone import Backbone
12
+ from .utils import (
13
+ PatchEmbed,
14
+ add_decomposed_rel_pos,
15
+ get_abs_pos,
16
+ window_partition,
17
+ window_unpartition,
18
+ )
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ __all__ = ["ViT"]
24
+
25
+
26
+ class Attention(nn.Module):
27
+ """Multi-head Attention block with relative position embeddings."""
28
+
29
+ def __init__(
30
+ self,
31
+ dim,
32
+ num_heads=8,
33
+ qkv_bias=True,
34
+ use_rel_pos=False,
35
+ rel_pos_zero_init=True,
36
+ input_size=None,
37
+ ):
38
+ """
39
+ Args:
40
+ dim (int): Number of input channels.
41
+ num_heads (int): Number of attention heads.
42
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
43
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
44
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
45
+ input_size (int or None): Input resolution for calculating the relative positional
46
+ parameter size.
47
+ """
48
+ super().__init__()
49
+ self.num_heads = num_heads
50
+ head_dim = dim // num_heads
51
+ self.scale = head_dim**-0.5
52
+
53
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
54
+ self.proj = nn.Linear(dim, dim)
55
+
56
+ self.use_rel_pos = use_rel_pos
57
+ if self.use_rel_pos:
58
+ # initialize relative positional embeddings
59
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
60
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
61
+
62
+ if not rel_pos_zero_init:
63
+ trunc_normal_(self.rel_pos_h, std=0.02)
64
+ trunc_normal_(self.rel_pos_w, std=0.02)
65
+
66
+ def forward(self, x):
67
+ B, H, W, _ = x.shape
68
+ # qkv with shape (3, B, nHead, H * W, C)
69
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
70
+ # q, k, v with shape (B * nHead, H * W, C)
71
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
72
+
73
+ attn = (q * self.scale) @ k.transpose(-2, -1)
74
+
75
+ if self.use_rel_pos:
76
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
77
+
78
+ attn = attn.softmax(dim=-1)
79
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
80
+ x = self.proj(x)
81
+
82
+ return x
83
+
84
+ class LayerNorm(nn.Module):
85
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
86
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
87
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
88
+ with shape (batch_size, channels, height, width).
89
+ """
90
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
91
+ super().__init__()
92
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
93
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
94
+ self.eps = eps
95
+ self.data_format = data_format
96
+ if self.data_format not in ["channels_last", "channels_first"]:
97
+ raise NotImplementedError
98
+ self.normalized_shape = (normalized_shape, )
99
+
100
+ def forward(self, x):
101
+ if self.data_format == "channels_last":
102
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
103
+ elif self.data_format == "channels_first":
104
+ u = x.mean(1, keepdim=True)
105
+ s = (x - u).pow(2).mean(1, keepdim=True)
106
+ x = (x - u) / torch.sqrt(s + self.eps)
107
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
108
+ return x
109
+
110
+ class ResBottleneckBlock(CNNBlockBase):
111
+ """
112
+ The standard bottleneck residual block without the last activation layer.
113
+ It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
114
+ """
115
+
116
+ def __init__(
117
+ self,
118
+ in_channels,
119
+ out_channels,
120
+ bottleneck_channels,
121
+ norm="LN",
122
+ act_layer=nn.GELU,
123
+ conv_kernels=3,
124
+ conv_paddings=1,
125
+ ):
126
+ """
127
+ Args:
128
+ in_channels (int): Number of input channels.
129
+ out_channels (int): Number of output channels.
130
+ bottleneck_channels (int): number of output channels for the 3x3
131
+ "bottleneck" conv layers.
132
+ norm (str or callable): normalization for all conv layers.
133
+ See :func:`layers.get_norm` for supported format.
134
+ act_layer (callable): activation for all conv layers.
135
+ """
136
+ super().__init__(in_channels, out_channels, 1)
137
+
138
+ self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
139
+ self.norm1 = get_norm(norm, bottleneck_channels)
140
+ self.act1 = act_layer()
141
+
142
+ self.conv2 = Conv2d(
143
+ bottleneck_channels,
144
+ bottleneck_channels,
145
+ conv_kernels,
146
+ padding=conv_paddings,
147
+ bias=False,
148
+ )
149
+ self.norm2 = get_norm(norm, bottleneck_channels)
150
+ self.act2 = act_layer()
151
+
152
+ self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
153
+ self.norm3 = get_norm(norm, out_channels)
154
+
155
+ for layer in [self.conv1, self.conv2, self.conv3]:
156
+ weight_init.c2_msra_fill(layer)
157
+ for layer in [self.norm1, self.norm2]:
158
+ layer.weight.data.fill_(1.0)
159
+ layer.bias.data.zero_()
160
+ # zero init last norm layer.
161
+ self.norm3.weight.data.zero_()
162
+ self.norm3.bias.data.zero_()
163
+
164
+ def forward(self, x):
165
+ out = x
166
+ for layer in self.children():
167
+ out = layer(out)
168
+
169
+ out = x + out
170
+ return out
171
+
172
+
173
+ class Block(nn.Module):
174
+ """Transformer blocks with support of window attention and residual propagation blocks"""
175
+
176
+ def __init__(
177
+ self,
178
+ dim,
179
+ num_heads,
180
+ mlp_ratio=4.0,
181
+ qkv_bias=True,
182
+ drop_path=0.0,
183
+ norm_layer=nn.LayerNorm,
184
+ act_layer=nn.GELU,
185
+ use_rel_pos=False,
186
+ rel_pos_zero_init=True,
187
+ window_size=0,
188
+ use_cc_attn = False,
189
+ use_residual_block=False,
190
+ use_convnext_block=False,
191
+ input_size=None,
192
+ res_conv_kernel_size=3,
193
+ res_conv_padding=1,
194
+ ):
195
+ """
196
+ Args:
197
+ dim (int): Number of input channels.
198
+ num_heads (int): Number of attention heads in each ViT block.
199
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
200
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
201
+ drop_path (float): Stochastic depth rate.
202
+ norm_layer (nn.Module): Normalization layer.
203
+ act_layer (nn.Module): Activation layer.
204
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
205
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
206
+ window_size (int): Window size for window attention blocks. If it equals 0, then not
207
+ use window attention.
208
+ use_residual_block (bool): If True, use a residual block after the MLP block.
209
+ input_size (int or None): Input resolution for calculating the relative positional
210
+ parameter size.
211
+ """
212
+ super().__init__()
213
+ self.norm1 = norm_layer(dim)
214
+ self.attn = Attention(
215
+ dim,
216
+ num_heads=num_heads,
217
+ qkv_bias=qkv_bias,
218
+ use_rel_pos=use_rel_pos,
219
+ rel_pos_zero_init=rel_pos_zero_init,
220
+ input_size=input_size if window_size == 0 else (window_size, window_size),
221
+ )
222
+
223
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
224
+ self.norm2 = norm_layer(dim)
225
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
226
+
227
+ self.window_size = window_size
228
+
229
+ self.use_residual_block = use_residual_block
230
+ if use_residual_block:
231
+ # Use a residual block with bottleneck channel as dim // 2
232
+ self.residual = ResBottleneckBlock(
233
+ in_channels=dim,
234
+ out_channels=dim,
235
+ bottleneck_channels=dim // 2,
236
+ norm="LN",
237
+ act_layer=act_layer,
238
+ conv_kernels=res_conv_kernel_size,
239
+ conv_paddings=res_conv_padding,
240
+ )
241
+ self.use_convnext_block = use_convnext_block
242
+ if use_convnext_block:
243
+ self.convnext = ConvNextBlock(dim = dim)
244
+
245
+ if use_cc_attn:
246
+ self.attn = CrissCrossAttention(dim)
247
+
248
+
249
+ def forward(self, x):
250
+ shortcut = x
251
+ x = self.norm1(x)
252
+ # Window partition
253
+ if self.window_size > 0:
254
+ H, W = x.shape[1], x.shape[2]
255
+ x, pad_hw = window_partition(x, self.window_size)
256
+
257
+ x = self.attn(x)
258
+
259
+ # Reverse window partition
260
+ if self.window_size > 0:
261
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
262
+
263
+ x = shortcut + self.drop_path(x)
264
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
265
+
266
+ if self.use_residual_block:
267
+ x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
268
+ if self.use_convnext_block:
269
+ x = self.convnext(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
270
+
271
+ return x
272
+
273
+
274
+ class ViT(Backbone):
275
+ """
276
+ This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
277
+ "Exploring Plain Vision Transformer Backbones for Object Detection",
278
+ https://arxiv.org/abs/2203.16527
279
+ """
280
+
281
+ def __init__(
282
+ self,
283
+ img_size=1024,
284
+ patch_size=16,
285
+ in_chans=3,
286
+ embed_dim=768,
287
+ depth=12,
288
+ num_heads=12,
289
+ mlp_ratio=4.0,
290
+ qkv_bias=True,
291
+ drop_path_rate=0.0,
292
+ norm_layer=nn.LayerNorm,
293
+ act_layer=nn.GELU,
294
+ use_abs_pos=True,
295
+ use_rel_pos=False,
296
+ rel_pos_zero_init=True,
297
+ window_size=0,
298
+ window_block_indexes=(),
299
+ residual_block_indexes=(),
300
+ use_act_checkpoint=False,
301
+ pretrain_img_size=224,
302
+ pretrain_use_cls_token=True,
303
+ out_feature="last_feat",
304
+ res_conv_kernel_size=3,
305
+ res_conv_padding=1,
306
+ ):
307
+ """
308
+ Args:
309
+ img_size (int): Input image size.
310
+ patch_size (int): Patch size.
311
+ in_chans (int): Number of input image channels.
312
+ embed_dim (int): Patch embedding dimension.
313
+ depth (int): Depth of ViT.
314
+ num_heads (int): Number of attention heads in each ViT block.
315
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
316
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
317
+ drop_path_rate (float): Stochastic depth rate.
318
+ norm_layer (nn.Module): Normalization layer.
319
+ act_layer (nn.Module): Activation layer.
320
+ use_abs_pos (bool): If True, use absolute positional embeddings.
321
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
322
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
323
+ window_size (int): Window size for window attention blocks.
324
+ window_block_indexes (list): Indexes for blocks using window attention.
325
+ residual_block_indexes (list): Indexes for blocks using conv propagation.
326
+ use_act_checkpoint (bool): If True, use activation checkpointing.
327
+ pretrain_img_size (int): input image size for pretraining models.
328
+ pretrain_use_cls_token (bool): If True, pretraining models use class token.
329
+ out_feature (str): name of the feature from the last block.
330
+ """
331
+ super().__init__()
332
+ self.pretrain_use_cls_token = pretrain_use_cls_token
333
+
334
+ self.patch_embed = PatchEmbed(
335
+ kernel_size=(patch_size, patch_size),
336
+ stride=(patch_size, patch_size),
337
+ in_chans=in_chans,
338
+ embed_dim=embed_dim,
339
+ )
340
+
341
+ if use_abs_pos:
342
+ # Initialize absolute positional embedding with pretrain image size.
343
+ num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
344
+ num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
345
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
346
+ else:
347
+ self.pos_embed = None
348
+
349
+ # stochastic depth decay rule
350
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
351
+
352
+ self.blocks = nn.ModuleList()
353
+ for i in range(depth):
354
+ block = Block(
355
+ dim=embed_dim,
356
+ num_heads=num_heads,
357
+ mlp_ratio=mlp_ratio,
358
+ qkv_bias=qkv_bias,
359
+ drop_path=dpr[i],
360
+ norm_layer=norm_layer,
361
+ act_layer=act_layer,
362
+ use_rel_pos=use_rel_pos,
363
+ rel_pos_zero_init=rel_pos_zero_init,
364
+ window_size=window_size if i in window_block_indexes else 0,
365
+ use_residual_block=i in residual_block_indexes,
366
+ input_size=(img_size // patch_size, img_size // patch_size),
367
+ res_conv_kernel_size=res_conv_kernel_size,
368
+ res_conv_padding=res_conv_padding,
369
+ )
370
+ if use_act_checkpoint:
371
+ block = checkpoint_wrapper(block)
372
+ self.blocks.append(block)
373
+
374
+ self._out_feature_channels = {out_feature: embed_dim}
375
+ self._out_feature_strides = {out_feature: patch_size}
376
+ self._out_features = [out_feature]
377
+
378
+ if self.pos_embed is not None:
379
+ trunc_normal_(self.pos_embed, std=0.02)
380
+
381
+ self.apply(self._init_weights)
382
+
383
+ def _init_weights(self, m):
384
+ if isinstance(m, nn.Linear):
385
+ trunc_normal_(m.weight, std=0.02)
386
+ if isinstance(m, nn.Linear) and m.bias is not None:
387
+ nn.init.constant_(m.bias, 0)
388
+ elif isinstance(m, nn.LayerNorm):
389
+ nn.init.constant_(m.bias, 0)
390
+ nn.init.constant_(m.weight, 1.0)
391
+
392
+ def forward(self, x):
393
+ x = self.patch_embed(x)
394
+ if self.pos_embed is not None:
395
+ x = x + get_abs_pos(
396
+ self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
397
+ )
398
+
399
+ for blk in self.blocks:
400
+ x = blk(x)
401
+
402
+ outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
403
+
404
+ return outputs['last_feat']
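A hedged instantiation sketch for this backbone; the hyper-parameters below follow common ViT-B/ViTDet settings and are not taken from a specific SEMat config:

  import torch
  vit = ViT(img_size=1024, patch_size=16, embed_dim=768, depth=12, num_heads=12,
            use_rel_pos=True, window_size=14,
            window_block_indexes=[0, 1, 3, 4, 6, 7, 9, 10],   # global attention in blocks 2, 5, 8, 11
            out_feature='last_feat')
  feat = vit(torch.randn(1, 3, 1024, 1024))                   # [1, 768, 64, 64]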
modeling/criterion/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .matting_criterion import MattingCriterion
modeling/criterion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (217 Bytes).
 
modeling/criterion/__pycache__/matting_criterion.cpython-38.pyc ADDED
Binary file (7.89 kB).
 
modeling/criterion/matting_criterion.py ADDED
@@ -0,0 +1,271 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from collections import defaultdict
5
+
6
+
7
+ class MattingCriterion(nn.Module):
8
+ def __init__(
9
+ self,
10
+ *,
11
+ losses,
12
+ image_size = 1024,
13
+ ):
14
+ super(MattingCriterion, self).__init__()
15
+ self.losses = losses
16
+ self.image_size = image_size
17
+
18
+ def loss_gradient_penalty(self, sample_map, preds, targets):
19
+
20
+ #sample_map for unknown area
21
+ if torch.sum(sample_map) == 0:
22
+ scale = 0
23
+ else:
24
+ scale = sample_map.shape[0] * (self.image_size ** 2) / torch.sum(sample_map)
25
+
26
+ #gradient in x
27
+ sobel_x_kernel = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]]).type(dtype=preds.type())
28
+ delta_pred_x = F.conv2d(preds, weight=sobel_x_kernel, padding=1)
29
+ delta_gt_x = F.conv2d(targets, weight=sobel_x_kernel, padding=1)
30
+
31
+ #gradient in y
32
+ sobel_y_kernel = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]]).type(dtype=preds.type())
33
+ delta_pred_y = F.conv2d(preds, weight=sobel_y_kernel, padding=1)
34
+ delta_gt_y = F.conv2d(targets, weight=sobel_y_kernel, padding=1)
35
+
36
+ #loss
37
+ loss = (F.l1_loss(delta_pred_x * sample_map, delta_gt_x * sample_map) * scale + \
38
+ F.l1_loss(delta_pred_y * sample_map, delta_gt_y * sample_map) * scale + \
39
+ 0.01 * torch.mean(torch.abs(delta_pred_x * sample_map)) * scale + \
40
+ 0.01 * torch.mean(torch.abs(delta_pred_y * sample_map)) * scale)
41
+
42
+ return dict(loss_gradient_penalty=loss)
43
+
44
+ def loss_pha_laplacian(self, preds, targets):
45
+ loss = laplacian_loss(preds, targets)
46
+ return dict(loss_pha_laplacian=loss)
47
+
48
+ def unknown_l1_loss(self, sample_map, preds, targets):
49
+
50
+ if torch.sum(sample_map) == 0:
51
+ scale = 0
52
+ else:
53
+ scale = sample_map.shape[0] * (self.image_size ** 2) / torch.sum(sample_map)
54
+ # scale = 1
55
+
56
+ loss = F.l1_loss(preds * sample_map, targets * sample_map) * scale
57
+
58
+ return dict(unknown_l1_loss=loss)
59
+
60
+ def known_l1_loss(self, sample_map, preds, targets):
61
+ new_sample_map = torch.zeros_like(sample_map)
62
+ new_sample_map[sample_map==0] = 1
63
+
64
+ if torch.sum(new_sample_map) == 0:
65
+ scale = 0
66
+ else:
67
+ scale = new_sample_map.shape[0] * (self.image_size ** 2) / torch.sum(new_sample_map)
68
+ # scale = 1
69
+
70
+ loss = F.l1_loss(preds * new_sample_map, targets * new_sample_map) * scale
71
+
72
+ return dict(known_l1_loss=loss)
73
+
74
+ def get_loss(self, k, sample_map, preds, targets):
75
+ if k=='unknown_l1_loss' or k=='known_l1_loss' or k=='loss_gradient_penalty':
76
+ losses = getattr(self, k)(sample_map, preds, targets)
77
+ else:
78
+ losses = getattr(self, k)(preds, targets)
79
+ assert len(list(losses.keys())) == 1
80
+ return losses[list(losses.keys())[0]]
81
+
82
+ def forward(self, sample_map, preds, targets, batch_weight=None):
83
+ losses = {i: torch.tensor(0.0, device=sample_map.device) for i in self.losses}
84
+ for k in self.losses:
85
+ if batch_weight is None:
86
+ losses[k] += self.get_loss(k, sample_map, preds, targets)
87
+ else:
88
+ for i, loss_weight in enumerate(batch_weight):
89
+ if loss_weight == -1.0 and k != 'known_l1_loss':
90
+ continue
91
+ else:
92
+ losses[k] += self.get_loss(k, sample_map[i: i + 1], preds[i: i + 1], targets[i: i + 1]) * abs(loss_weight)
93
+ return losses
94
+
95
+
96
+ #-----------------Laplacian Loss-------------------------#
97
+ def laplacian_loss(pred, true, max_levels=5):
98
+ kernel = gauss_kernel(device=pred.device, dtype=pred.dtype)
99
+ pred_pyramid = laplacian_pyramid(pred, kernel, max_levels)
100
+ true_pyramid = laplacian_pyramid(true, kernel, max_levels)
101
+ loss = 0
102
+ for level in range(max_levels):
103
+ loss += (2 ** level) * F.l1_loss(pred_pyramid[level], true_pyramid[level])
104
+ return loss / max_levels
105
+
106
+ def laplacian_pyramid(img, kernel, max_levels):
107
+ current = img
108
+ pyramid = []
109
+ for _ in range(max_levels):
110
+ current = crop_to_even_size(current)
111
+ down = downsample(current, kernel)
112
+ up = upsample(down, kernel)
113
+ diff = current - up
114
+ pyramid.append(diff)
115
+ current = down
116
+ return pyramid
117
+
118
+ def gauss_kernel(device='cpu', dtype=torch.float32):
119
+ kernel = torch.tensor([[1, 4, 6, 4, 1],
120
+ [4, 16, 24, 16, 4],
121
+ [6, 24, 36, 24, 6],
122
+ [4, 16, 24, 16, 4],
123
+ [1, 4, 6, 4, 1]], device=device, dtype=dtype)
124
+ kernel /= 256
125
+ kernel = kernel[None, None, :, :]
126
+ return kernel
127
+
128
+ def gauss_convolution(img, kernel):
129
+ B, C, H, W = img.shape
130
+ img = img.reshape(B * C, 1, H, W)
131
+ img = F.pad(img, (2, 2, 2, 2), mode='reflect')
132
+ img = F.conv2d(img, kernel)
133
+ img = img.reshape(B, C, H, W)
134
+ return img
135
+
136
+ def downsample(img, kernel):
137
+ img = gauss_convolution(img, kernel)
138
+ img = img[:, :, ::2, ::2]
139
+ return img
140
+
141
+ def upsample(img, kernel):
142
+ B, C, H, W = img.shape
143
+ out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype)
144
+ out[:, :, ::2, ::2] = img * 4
145
+ out = gauss_convolution(out, kernel)
146
+ return out
147
+
148
+ def crop_to_even_size(img):
149
+ H, W = img.shape[2:]
150
+ H = H - H % 2
151
+ W = W - W % 2
152
+ return img[:, :, :H, :W]
153
+
154
+ def normalized_focal_loss(pred, gt, gamma=2, class_num=3, norm=True, beta_detach=False, beta_sum_detach=False):
155
+ pred_logits = F.softmax(pred, dim=1) # [B, 3, H, W]
156
+ gt_one_hot = F.one_hot(gt, class_num).permute(0, 3, 1, 2) # [B, 3, H, W]
157
+ p = (pred_logits * gt_one_hot).sum(dim=1) # [B, H, W]
158
+ beta = (1 - p) ** gamma # [B, H, W]
159
+ beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True) / (pred.shape[-1] * pred.shape[-2]) # [B, 1, 1]
160
+
161
+ if beta_detach:
162
+ beta = beta.detach()
163
+ if beta_sum_detach:
164
+ beta_sum = beta_sum.detach()
165
+
166
+ if norm:
167
+ loss = 1 / beta_sum * beta * (-torch.log(p))
168
+ return torch.mean(loss)
169
+ else:
170
+ loss = beta * (-torch.log(p))
171
+ return torch.mean(loss)
172
+
173
+ class GHMC(nn.Module):
174
+ def __init__(self, bins=10, momentum=0.75, loss_weight=1.0, device='cuda', norm=False):
175
+ super(GHMC, self).__init__()
176
+ self.bins = bins
177
+ self.momentum = momentum
178
+ self.edges = torch.arange(bins + 1).float().to(device) / bins
179
+ self.edges[-1] += 1e-6
180
+ if momentum > 0:
181
+ self.acc_sum = torch.zeros(bins).to(device)
182
+ self.loss_weight = loss_weight
183
+ self.device = device
184
+ self.norm = norm
185
+
186
+ def forward(self, pred, target, *args, **kwargs):
187
+ """Calculate the GHM-C loss.
188
+ Args:
189
+ pred (float tensor of size [batch_num, class_num]):
190
+ The direct prediction of classification fc layer.
191
+ target (float tensor of size [batch_num, class_num]):
192
+ Binary class target for each sample.
193
+ label_weight (float tensor of size [batch_num, class_num]):
194
+ the value is 1 if the sample is valid and 0 if ignored.
195
+ Returns:
196
+ The gradient harmonized loss.
197
+ """
198
+
199
+ # the target should be binary class label
200
+ # if pred.dim() != target.dim():
201
+ # target, label_weight = _expand_binary_labels(
202
+ # target, label_weight, pred.size(-1))
203
+ # target, label_weight = target.float(), label_weight.float()
204
+ # pdb.set_trace()
205
+
206
+ # pred: [B, C, H, W], target: [B, H, W]
207
+ pred = pred.permute(0, 2, 3, 1).reshape(-1, 3) # [B x H x W, C]
208
+ target = target.reshape(-1) # [B x H x W]
209
+ # self.acc_sum = self.acc_sum.type(pred.dtype)
210
+
211
+ edges = self.edges
212
+ mmt = self.momentum
213
+ weights = torch.zeros((target.shape),dtype=pred.dtype).to(self.device)
214
+
215
+ # gradient length
216
+ #g = 1 - torch.index_select(F.softmax(pred,dim=1).detach(), dim=0, index=target)
217
+ g = 1 - torch.gather(F.softmax(pred,dim=1).detach(),dim=1,index=target.unsqueeze(1))
218
+ #g = torch.abs(pred.softmax(2).detach() - target)
219
+
220
+ tot = 1.0
221
+ n = 0 # n valid bins
222
+ for i in range(self.bins):
223
+ inds = (g >= edges[i]) & (g < edges[i+1])
224
+ num_in_bin = inds.sum().item()
225
+ if num_in_bin > 0:
226
+ idx = torch.nonzero(inds)[:, 0]
227
+ if mmt > 0:
228
+ self.acc_sum[i] = mmt * self.acc_sum[i] \
229
+ + (1 - mmt) * num_in_bin
230
+ # pdb.set_trace()#scatter_ index_put_
231
+ #BB=torch.nonzero(inds)
232
+ _weight_idx = tot / self.acc_sum[i]
233
+ weights = weights.to(dtype=_weight_idx.dtype)
234
+ weights[idx] = _weight_idx
235
+ # weights.scatter_(0, torch.nonzero(inds)[:,0], tot / self.acc_sum[i])
236
+ # # weights.index_put_(inds, tot / self.acc_sum[i])
237
+ # weights[inds] = tot / self.acc_sum[i] # * torch.ones((len(inds)))
238
+ else:
239
+ weights[idx] = tot / num_in_bin
240
+ n += 1
241
+ if n > 0:
242
+ weights = weights / n
243
+
244
+ # pdb.set_trace()
245
+ # loss = (weights * F.cross_entropy(pred, target, reduction='none')).sum() / tot / pred.shape[0]
246
+ if self.norm:
247
+ weights = weights / torch.sum(weights).detach()
248
+
249
+ loss = - ((weights.unsqueeze(1) * torch.gather(F.log_softmax(pred, dim=1), dim=1, index=target.unsqueeze(1))).sum() ) # / pred.shape[0]
250
+
251
+ # loss3= F.cross_entropy(pred, target, reduction='mean')
252
+ # loss4 = - ((torch.gather(F.log_softmax(pred, dim=1), dim=1, index=target.unsqueeze(1))).sum() / pred.shape[0])
253
+
254
+ # pro = F.softmax(logits, dim=1)
255
+ #
256
+ # label_onehot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), 1)
257
+ # with torch.no_grad():
258
+ # weight_matrix = (1 - pro) ** self.gamma
259
+ # # pdb.set_trace()
260
+ # fl = - (weight_matrix * (label_onehot * (pro + self.eps).log())).sum() / pro.shape[0]
261
+
262
+ return loss
263
+
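# A minimal GHMC usage sketch (illustrative, not part of the commit); it assumes the
# constructor's `device` argument is applied to the internal `edges` / `acc_sum` buffers.
ghmc = GHMC(bins=10, momentum=0.75, device='cpu')
logits = torch.randn(2, 3, 256, 256)              # dense trimap logits [B, 3, H, W]
labels = torch.randint(0, 3, (2, 256, 256))       # per-pixel class indices [B, H, W]
loss = ghmc(logits, labels)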
264
+ if __name__ == '__main__':
265
+ pred = torch.randn(2, 3, 1024, 1024)
266
+ gt = torch.argmax(torch.randn(2, 3, 1024, 1024), dim=1)
267
+ loss = normalized_focal_loss(pred, gt)
268
+ print(loss)
269
+
270
+
271
+
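A minimal usage sketch for the criterion above (the shapes and the unknown-region `sample_map` are illustrative, not taken from the training configs):

import torch

criterion = MattingCriterion(
    losses=['unknown_l1_loss', 'known_l1_loss', 'loss_pha_laplacian', 'loss_gradient_penalty']
)
pred_alpha = torch.rand(2, 1, 1024, 1024, requires_grad=True)   # predicted alpha in [0, 1]
gt_alpha = torch.rand(2, 1, 1024, 1024)                         # ground-truth alpha
sample_map = (torch.rand(2, 1, 1024, 1024) > 0.9).float()       # 1 inside the unknown region

loss_dict = criterion(sample_map, pred_alpha, gt_alpha)         # dict of the four losses above
sum(loss_dict.values()).backward()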
modeling/decoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .detail_capture import Detail_Capture, Ori_Detail_Capture
modeling/decoder/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (239 Bytes).
 
modeling/decoder/__pycache__/detail_capture.cpython-38.pyc ADDED
Binary file (5.37 kB).
 
modeling/decoder/__pycache__/unet_detail_capture.cpython-38.pyc ADDED
Binary file (10.8 kB).
 
modeling/decoder/detail_capture.py ADDED
@@ -0,0 +1,185 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ class Basic_Conv3x3(nn.Module):
6
+ """
7
+ Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
8
+ """
9
+ def __init__(
10
+ self,
11
+ in_chans,
12
+ out_chans,
13
+ stride=2,
14
+ padding=1,
15
+ ):
16
+ super().__init__()
17
+ self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
18
+ self.bn = nn.BatchNorm2d(out_chans)
19
+ self.relu = nn.ReLU(True)
20
+
21
+ def forward(self, x):
22
+ x = self.conv(x)
23
+ x = self.bn(x)
24
+ x = self.relu(x)
25
+
26
+ return x
27
+
28
+ class ConvStream(nn.Module):
29
+ """
30
+ Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
31
+ """
32
+ def __init__(
33
+ self,
34
+ in_chans = 4,
35
+ out_chans = [48, 96, 192],
36
+ ):
37
+ super().__init__()
38
+ self.convs = nn.ModuleList()
39
+
40
+ self.conv_chans = out_chans.copy()
41
+ self.conv_chans.insert(0, in_chans)
42
+
43
+ for i in range(len(self.conv_chans)-1):
44
+ in_chan_ = self.conv_chans[i]
45
+ out_chan_ = self.conv_chans[i+1]
46
+ self.convs.append(
47
+ Basic_Conv3x3(in_chan_, out_chan_)
48
+ )
49
+
50
+ def forward(self, x):
51
+ out_dict = {'D0': x}
52
+ for i in range(len(self.convs)):
53
+ x = self.convs[i](x)
54
+ name_ = 'D'+str(i+1)
55
+ out_dict[name_] = x
56
+
57
+ return out_dict
58
+
59
+ class Fusion_Block(nn.Module):
60
+ """
61
+ Simple fusion block to fuse feature from ConvStream and Plain Vision Transformer.
62
+ """
63
+ def __init__(
64
+ self,
65
+ in_chans,
66
+ out_chans,
67
+ ):
68
+ super().__init__()
69
+ self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)
70
+
71
+ def forward(self, x, D):
72
+ F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
73
+ out = torch.cat([D, F_up], dim=1)
74
+ out = self.conv(out)
75
+
76
+ return out
77
+
78
+ class Matting_Head(nn.Module):
79
+ """
80
+ Simple Matting Head, containing only conv3x3 and conv1x1 layers.
81
+ """
82
+ def __init__(
83
+ self,
84
+ in_chans = 32,
85
+ mid_chans = 16,
86
+ ):
87
+ super().__init__()
88
+ self.matting_convs = nn.Sequential(
89
+ nn.Conv2d(in_chans, mid_chans, 3, 1, 1),
90
+ nn.BatchNorm2d(mid_chans),
91
+ nn.ReLU(True),
92
+ nn.Conv2d(mid_chans, 1, 1, 1, 0)
93
+ )
94
+
95
+ def forward(self, x):
96
+ x = self.matting_convs(x)
97
+
98
+ return x
99
+
100
+ class Detail_Capture(nn.Module):
101
+ """
102
+ Simple and Lightweight Detail Capture Module for ViT Matting.
103
+ """
104
+ def __init__(
105
+ self,
106
+ in_chans = [384, 1],
107
+ img_chans=4,
108
+ convstream_out = [48, 96, 192],
109
+ fusion_out = [256, 128, 64, 32],
110
+ ):
111
+ super().__init__()
112
+ assert len(fusion_out) == len(convstream_out) + 1
113
+
114
+ self.convstream = ConvStream(in_chans=img_chans, out_chans=convstream_out)
115
+ self.conv_chans = self.convstream.conv_chans # [4, 48, 96, 192]
116
+
117
+ self.fusion_blks = nn.ModuleList()
118
+ self.fus_channs = fusion_out.copy()
119
+ self.fus_channs.insert(0, in_chans[0]) # [384, 256, 128, 64, 32]
120
+ for i in range(len(self.fus_channs)-1):
121
+ in_channels = self.fus_channs[i] + self.conv_chans[-(i+1)] if i != 2 else in_chans[1] + self.conv_chans[-(i+1)] # [256 + 192 = 448, 256 + 96 = 352, 128 + 48 = 176, 64 + 4 = 68]
122
+ out_channels = self.fus_channs[i+1] # [256, 128, 64, 32]
123
+ self.fusion_blks.append(
124
+ Fusion_Block(
125
+ in_chans = in_channels,
126
+ out_chans = out_channels,
127
+ )
128
+ )
129
+
130
+ self.matting_head = Matting_Head( # 32 --> 1
131
+ in_chans = fusion_out[-1],
132
+ )
133
+
134
+ def forward(self, features, images):
135
+ detail_features = self.convstream(images) # [1, 4, 672, 992] --> D0: [1, 4, 672, 992], D1: [1, 48, 336, 496], D2: [1, 96, 168, 248], D3: [1, 192, 84, 124]
136
+ for i in range(len(self.fusion_blks)): # D3
137
+ d_name_ = 'D'+str(len(self.fusion_blks)-i-1)
138
+ features = self.fusion_blks[i](features, detail_features[d_name_])
139
+
140
+ phas = torch.sigmoid(self.matting_head(features))
141
+
142
+ return {'phas': phas}
143
+
144
+
145
+ class Ori_Detail_Capture(nn.Module):
146
+ """
147
+ Simple and Lightweight Detail Capture Module for ViT Matting.
148
+ """
149
+ def __init__(
150
+ self,
151
+ in_chans = 384,
152
+ img_chans=4,
153
+ convstream_out = [48, 96, 192],
154
+ fusion_out = [256, 128, 64, 32],
155
+ ):
156
+ super().__init__()
157
+ assert len(fusion_out) == len(convstream_out) + 1
158
+
159
+ self.convstream = ConvStream(in_chans = img_chans)
160
+ self.conv_chans = self.convstream.conv_chans
161
+
162
+ self.fusion_blks = nn.ModuleList()
163
+ self.fus_channs = fusion_out.copy()
164
+ self.fus_channs.insert(0, in_chans)
165
+ for i in range(len(self.fus_channs)-1):
166
+ self.fusion_blks.append(
167
+ Fusion_Block(
168
+ in_chans = self.fus_channs[i] + self.conv_chans[-(i+1)],
169
+ out_chans = self.fus_channs[i+1],
170
+ )
171
+ )
172
+
173
+ self.matting_head = Matting_Head(
174
+ in_chans = fusion_out[-1],
175
+ )
176
+
177
+ def forward(self, features, images):
178
+ detail_features = self.convstream(images)
179
+ for i in range(len(self.fusion_blks)):
180
+ d_name_ = 'D'+str(len(self.fusion_blks)-i-1)
181
+ features = self.fusion_blks[i](features, detail_features[d_name_])
182
+
183
+ phas = torch.sigmoid(self.matting_head(features))
184
+
185
+ return {'phas': phas}
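A quick shape-check sketch for `Ori_Detail_Capture` above (tensor sizes are illustrative; it expects 1/16-resolution ViT features plus a 4-channel image/guidance input):

import torch

decoder = Ori_Detail_Capture(in_chans=384, img_chans=4)
vit_feats = torch.randn(1, 384, 32, 32)     # ViT features at 1/16 of the input resolution
images = torch.randn(1, 4, 512, 512)        # RGB + one guidance channel (e.g. a trimap or mask)
out = decoder(vit_feats, images)
print(out['phas'].shape)                    # torch.Size([1, 1, 512, 512])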
modeling/decoder/unet_detail_capture.py ADDED
@@ -0,0 +1,429 @@
1
+ import cv2
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ # NOTE: GenTrimapTorch below calls a binary `erosion` op; its import is currently disabled:
+ # from nnMorpho.binary_operators import erosion
6
+ from detectron2.layers.batch_norm import NaiveSyncBatchNorm
7
+
8
+
9
+ class GenTrimapTorch(object):
10
+ def __init__(self, max_kernal=200):
11
+ self.max_kernal = max_kernal
12
+ self.erosion_kernels = [None] + [torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))).float().cuda() for size in range(1, self.max_kernal)]
13
+
14
+ def __call__(self, mask, kernel_size):
15
+
16
+ fg_width = kernel_size
17
+ bg_width = kernel_size
18
+
19
+ fg_mask = mask
20
+ bg_mask = 1 - mask
21
+
22
+ fg_mask = erosion(fg_mask, self.erosion_kernels[fg_width], border='a')
23
+ bg_mask = erosion(bg_mask, self.erosion_kernels[bg_width], border='a')
24
+
25
+ trimap = torch.ones_like(mask) * 0.5
26
+ trimap[fg_mask == 1] = 1.0
27
+ trimap[bg_mask == 1] = 0.0
28
+
29
+ return trimap
30
+
31
+
32
+ class LayerNorm2d(nn.Module):
33
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
34
+ super().__init__()
35
+ self.weight = nn.Parameter(torch.ones(num_channels))
36
+ self.bias = nn.Parameter(torch.zeros(num_channels))
37
+ self.eps = eps
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ u = x.mean(1, keepdim=True)
41
+ s = (x - u).pow(2).mean(1, keepdim=True)
42
+ x = (x - u) / torch.sqrt(s + self.eps)
43
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
44
+ return x
45
+
46
+
47
+ class BasicDownBlock(nn.Module):
48
+ def __init__(self, in_channel, out_channel, res = True, norm=LayerNorm2d, block_num=1, kernel_size=3):
49
+ super().__init__()
50
+
51
+ self.res = res
52
+ self.basic_layer = nn.ModuleList()
53
+ for i in range(block_num):
54
+ if i == 0:
55
+ basic_layer_in_ch = in_channel
56
+ stride = 2
57
+ else:
58
+ basic_layer_in_ch = out_channel
59
+ stride = 1
60
+ self.basic_layer.append(nn.GELU())
61
+ self.basic_layer.append(nn.Sequential(
62
+ nn.Conv2d(basic_layer_in_ch, out_channel, kernel_size, stride, kernel_size // 2),
63
+ norm(out_channel),
64
+ nn.GELU(),
65
+ nn.Conv2d(out_channel, out_channel, kernel_size, 1, kernel_size // 2),
66
+ norm(out_channel),
67
+ ))
68
+ self.act = nn.GELU()
69
+
70
+ if self.res:
71
+ self.res_layer = nn.Conv2d(in_channel, out_channel, kernel_size, 2, kernel_size // 2)
72
+
73
+ def forward(self, x):
74
+
75
+ out = x
76
+ for layer in self.basic_layer:
77
+ out = layer(out)
78
+
79
+ if self.res:
80
+ identity = self.res_layer(x)
81
+ else:
82
+ # `out` is already computed here, so its spatial size is known
83
+ identity = F.interpolate(x, size=(out.shape[-2], out.shape[-1]), mode='bilinear', align_corners=False)
84
+ out = out + identity
85
+ out = self.act(out)
86
+
87
+ return out
88
+
89
+
90
+ class BasicUpBlock(nn.Module):
91
+
92
+ def __init__( self, in_channel, out_channel, res = True, skip_connect = 'concat', norm=LayerNorm2d, block_num=1, kernel_size=3):
93
+ super().__init__()
94
+ assert skip_connect in {'sum', 'concat'}
95
+
96
+ self.res = res
97
+ self.skip_connect = skip_connect
98
+ self.basic_layer = nn.ModuleList()
99
+ for i in range(block_num):
100
+ if i == 0:
101
+ basic_layer_in_ch = in_channel
102
+ first_conv = nn.ConvTranspose2d(basic_layer_in_ch, out_channel, 2, 2)
103
+ else:
104
+ basic_layer_in_ch = out_channel
105
+ first_conv = nn.Conv2d(out_channel, out_channel, kernel_size, 1, kernel_size // 2)
106
+ self.basic_layer.append(nn.GELU())
107
+ self.basic_layer.append(nn.Sequential(
108
+ first_conv,
109
+ norm(out_channel),
110
+ nn.GELU(),
111
+ nn.Conv2d(out_channel, out_channel, kernel_size, 1, kernel_size // 2),
112
+ norm(out_channel),
113
+ ))
114
+ self.act = nn.GELU()
115
+
116
+ if self.res:
117
+ self.res_layer = nn.Conv2d(in_channel, out_channel, kernel_size, 1, kernel_size // 2)
118
+
119
+
120
+ def forward(self, x, skip_feat, concat_feat=None):
121
+
122
+ if self.skip_connect == 'sum':
123
+ x = x + skip_feat
124
+ else:
125
+ x = torch.concat((x, skip_feat), dim=1)
126
+
127
+ if concat_feat is not None:
128
+ x = torch.concat((x, concat_feat), dim=1)
129
+
130
+ out = x
131
+ for layer in self.basic_layer:
132
+ out = layer(out)
133
+ # out = self.basic_layer(x)
134
+
135
+ identity = F.interpolate(x, size=(out.shape[-2], out.shape[-1]), mode='bilinear', align_corners=False)
136
+ if self.res:
137
+ identity = self.res_layer(identity)
138
+
139
+ out = out + identity
140
+ out = self.act(out)
141
+
142
+ return out
143
+
144
+
145
+
146
+ class DetailUNet(nn.Module):
147
+ def __init__(
148
+ self,
149
+ img_feat_in = 4,
150
+ vit_early_feat_in = 768,
151
+ matting_feat_in = 5,
152
+ downsample_in_out = [(4, 32), (32, 64), (64, 128), (128, 256)],
153
+ upsample_in_out = [(256, 128), (128, 64), (64, 32), (32, 16)],
154
+ matting_head_in = 16,
155
+ skip_connect = 'sum',
156
+ norm_type = 'LN',
157
+ ):
158
+ super().__init__()
159
+
160
+ assert len(downsample_in_out) == len(upsample_in_out)
161
+ downsample_in_out[0] = (img_feat_in, downsample_in_out[0][1])
162
+
163
+ assert norm_type in {'BN', 'LN', 'SyncBN'}
164
+ if norm_type == 'BN':
165
+ self.norm = torch.nn.BatchNorm2d
166
+ elif norm_type == 'SyncBN':
167
+ self.norm = NaiveSyncBatchNorm
168
+ else:
169
+ self.norm = LayerNorm2d
170
+
171
+ self.down_blks = nn.ModuleList()
172
+ for in_ch, out_ch in downsample_in_out:
173
+ self.down_blks.append(
174
+ BasicDownBlock(in_ch, out_ch, norm=self.norm)
175
+ )
176
+
177
+ self.mid_layer = nn.Sequential(
178
+ nn.Conv2d(vit_early_feat_in, downsample_in_out[-1][1], 1, 1),
179
+ self.norm(downsample_in_out[-1][1]),
180
+ nn.GELU(),
181
+ )
182
+
183
+ self.up_blks = nn.ModuleList()
184
+ for i, (in_ch, out_ch) in enumerate(upsample_in_out):
185
+ if i == 2:
186
+ in_ch += matting_feat_in
187
+ self.up_blks.append(
188
+ BasicUpBlock(in_ch, out_ch, skip_connect=skip_connect, norm=self.norm)
189
+ )
190
+
191
+ self.matting_head = nn.Conv2d(matting_head_in, 1, 3, 1, 1)
192
+
193
+
194
+ def forward(self, x, vit_early_feat, matting_feat, return_alpha_logits=False):
195
+ details = []
196
+ dfeatures = x
197
+
198
+ for i in range(len(self.down_blks)):
199
+ dfeatures = self.down_blks[i](dfeatures)
200
+ details.append(dfeatures)
201
+
202
+ out = self.mid_layer(vit_early_feat)
203
+ for i in range(len(self.up_blks)):
204
+ if i == 2:
205
+ out = self.up_blks[i](out, details[-i - 1], matting_feat)
206
+ else:
207
+ out = self.up_blks[i](out, details[-i - 1])
208
+ alpha = self.matting_head(out)
209
+ if return_alpha_logits:
210
+ return alpha, out
211
+ else:
212
+ return alpha
213
+
214
+
215
+ class MattingDetailDecoder(nn.Module):
216
+ def __init__(
217
+ self,
218
+ img_feat_in = 4,
219
+ vit_intern_feat_in = 1024,
220
+ vit_intern_feat_index = [0, 1, 2, 3],
221
+ downsample_in_out = [(4, 32), (32, 64), (64, 128), (128, 256)],
222
+ upsample_in_out = [(256, 128), (128, 64), (64, 32), (32, 16)],
223
+ matting_head_in = 16,
224
+ skip_connect = 'sum',
225
+ norm_type = 'BN',
226
+ norm_mask_logits = 6.5,
227
+ with_trimap = False,
228
+ min_kernel_size = 20,
229
+ kernel_div = 10,
230
+ concat_gen_trimap = False,
231
+ wo_hq_features = False,
232
+ block_num = 1,
233
+ wo_big_kernel = False,
234
+ sam2_multi_scale_feates = False,
235
+ ):
236
+ super().__init__()
237
+
238
+ assert len(downsample_in_out) == len(upsample_in_out)
239
+ assert skip_connect in {'sum', 'concat'}
240
+ downsample_in_out[0] = (img_feat_in, downsample_in_out[0][1])
241
+
242
+ self.vit_intern_feat_in = vit_intern_feat_in
243
+ self.vit_intern_feat_index = vit_intern_feat_index
244
+ self.norm_mask_logits = norm_mask_logits
245
+ self.with_trimap = with_trimap
246
+ self.min_kernel_size = min_kernel_size
247
+ self.kernel_div = kernel_div
248
+ self.concat_gen_trimap = concat_gen_trimap
249
+ self.wo_hq_features = wo_hq_features
250
+ self.block_num = block_num
251
+ self.wo_big_kernel = wo_big_kernel
252
+ self.sam2_multi_scale_feates = sam2_multi_scale_feates
253
+ if self.sam2_multi_scale_feates:
254
+ assert downsample_in_out[0][0] == 6
255
+ downsample_in_out = [(4, 32), (32, 64), (64 + 32, 128), (128 + 64, 256)]
256
+ upsample_in_out = [(256, 128), (128, 64), (64, 32), (32, 16)]
257
+
258
+ if self.with_trimap and not self.concat_gen_trimap:
259
+ self.gen_trimap = GenTrimapTorch()
260
+ assert norm_type in {'BN', 'LN', 'SyncBN'}
261
+ if norm_type == 'BN':
262
+ self.norm = torch.nn.BatchNorm2d
263
+ elif norm_type == 'SyncBN':
264
+ self.norm = NaiveSyncBatchNorm
265
+ else:
266
+ self.norm = LayerNorm2d
267
+
268
+ if self.block_num >= 2 and not self.wo_big_kernel:
269
+ self.big_kernel_process = nn.Sequential(
270
+ nn.Conv2d(img_feat_in, 16, kernel_size=13, stride=1, padding=6),
271
+ self.norm(16),
272
+ nn.GELU(),
273
+ nn.Conv2d(16, 32, kernel_size=13, stride=1, padding=6),
274
+ self.norm(32),
275
+ nn.GELU(),
276
+ )
277
+ downsample_in_out[0] = (32, downsample_in_out[0][1])
278
+
279
+ if not self.sam2_multi_scale_feates:
280
+ self.vit_feat_proj = nn.ModuleDict()
281
+ for idx in self.vit_intern_feat_index:
282
+ self.vit_feat_proj[str(idx)] = nn.Conv2d(self.vit_intern_feat_in, self.vit_intern_feat_in // len(self.vit_intern_feat_index), 1, 1)
283
+ self.vit_feat_aggregation = nn.Sequential(
284
+ nn.Conv2d(self.vit_intern_feat_in // len(self.vit_intern_feat_index) * len(self.vit_intern_feat_index), downsample_in_out[-1][1], 3, 1, 1),
285
+ self.norm(downsample_in_out[-1][1]),
286
+ nn.GELU(),
287
+ )
288
+
289
+ self.down_blks = nn.ModuleList()
290
+ for in_ch, out_ch in downsample_in_out:
291
+ self.down_blks.append(
292
+ BasicDownBlock(in_ch, out_ch, norm=self.norm, block_num=self.block_num, kernel_size=5 if self.block_num >= 2 else 3)
293
+ )
294
+
295
+ if self.sam2_multi_scale_feates:
296
+ self.mid_layer = nn.ModuleList([
297
+ nn.Sequential(
298
+ nn.Conv2d(32, 32, 1, 1),
299
+ self.norm(32),
300
+ nn.GELU(),
301
+ ),
302
+ nn.Sequential(
303
+ nn.Conv2d(64, 64, 1, 1),
304
+ self.norm(64),
305
+ nn.GELU(),
306
+ ),
307
+ nn.Sequential(
308
+ nn.Conv2d(256, 256, 1, 1),
309
+ self.norm(256),
310
+ nn.GELU(),
311
+ ),
312
+ nn.Sequential(
313
+ nn.Conv2d(512, 256, 3, 1, 1),
314
+ self.norm(256),
315
+ nn.GELU(),
316
+ ),
317
+ ])
318
+ else:
319
+ self.mid_layer = nn.Sequential(
320
+ nn.Conv2d(downsample_in_out[-1][1] * 2, downsample_in_out[-1][1], 1, 1),
321
+ self.norm(downsample_in_out[-1][1]),
322
+ nn.GELU(),
323
+ )
324
+
325
+ self.up_blks = nn.ModuleList()
326
+ for _, (in_ch, out_ch) in enumerate(upsample_in_out):
327
+ if skip_connect == 'concat':
328
+ self.up_blks.append(BasicUpBlock(in_ch * 2, out_ch, skip_connect=skip_connect, norm=self.norm, block_num=self.block_num))
329
+ else:
330
+ self.up_blks.append(BasicUpBlock(in_ch, out_ch, skip_connect=skip_connect, norm=self.norm, block_num=self.block_num))
331
+
332
+ self.matting_head = nn.Conv2d(matting_head_in, 1, 3, 1, 1)
333
+
334
+ if self.norm_mask_logits == 'BN':
335
+ self.logits_norm = self.norm(1)
336
+
337
+
338
+ def preprocess_inputs(self, images, hq_features, pred_trimap):
339
+
340
+ if self.wo_hq_features:
341
+ return images
342
+
343
+ if isinstance(self.norm_mask_logits, float):
344
+ norm_hq_features = hq_features / self.norm_mask_logits
345
+ elif self.norm_mask_logits == 'BN':
346
+ norm_hq_features = self.logits_norm(hq_features)
347
+ elif self.norm_mask_logits == 'Sigmoid':
348
+ if hq_features.shape[1] == 1:
349
+ norm_hq_features = torch.sigmoid(hq_features)
350
+ else:
351
+ norm_hq_features = torch.softmax(hq_features, dim=1)
352
+ elif self.norm_mask_logits:
353
+ norm_hq_features = hq_features / torch.std(hq_features, dim=(1, 2, 3), keepdim=True)
354
+ else:
355
+ norm_hq_features = hq_features
356
+
357
+ if self.concat_gen_trimap:
358
+ pred_trimap = F.interpolate(pred_trimap, size=(images.shape[-2], images.shape[-1]), mode='bilinear', align_corners=False)
359
+ pred_trimap = torch.argmax(pred_trimap, dim=1, keepdim=True).float() / 2.0
360
+ norm_hq_features = torch.concat((norm_hq_features, pred_trimap.detach()), dim=1)
361
+ elif self.with_trimap:
362
+ mask = (norm_hq_features > 0).float()
363
+ for i_batch in range(images.shape[0]):
364
+ mask_area = torch.sum(mask[i_batch])
365
+ kernel_size = max(self.min_kernel_size, int((mask_area ** 0.5) / self.kernel_div))
366
+ kernel_size = min(kernel_size, self.gen_trimap.max_kernal - 1)
367
+ mask[i_batch, 0] = self.gen_trimap(mask[i_batch, 0], kernel_size=kernel_size)
368
+ trimaps = mask
369
+ norm_hq_features = torch.concat((norm_hq_features, trimaps), dim=1)
370
+
371
+ conditional_images = torch.concatenate((images, norm_hq_features), dim=1)
372
+ return conditional_images
373
+
374
+ def forward(self, images, hq_features, vit_intern_feat, return_alpha_logits=False, pred_trimap=None):
375
+
376
+ condition_input = self.preprocess_inputs(images, hq_features, pred_trimap)
377
+
378
+ if not self.sam2_multi_scale_feates:
379
+ # aggregate 4 vit_intern_feat
380
+ # assert len(vit_intern_feat) == self.vit_intern_feat_num
381
+ vit_feats = []
382
+ for idx in self.vit_intern_feat_index:
383
+ vit_feats.append(self.vit_feat_proj[str(idx)](vit_intern_feat[idx].permute(0, 3, 1, 2)))
384
+ vit_feats = torch.concat(vit_feats, dim=1)
385
+ vit_aggregation_feats = self.vit_feat_aggregation(vit_feats)
386
+
387
+ details = []
388
+ dfeatures = condition_input
389
+
390
+ if hasattr(self, 'big_kernel_process'):
391
+ dfeatures = self.big_kernel_process(dfeatures)
392
+
393
+ for i in range(len(self.down_blks)):
394
+ if self.sam2_multi_scale_feates:
395
+ if i == 2:
396
+ dfeatures = torch.concat((dfeatures, self.mid_layer[0](vit_intern_feat['high_res_feats'][0])), dim=1)
397
+ elif i == 3:
398
+ dfeatures = torch.concat((dfeatures, self.mid_layer[1](vit_intern_feat['high_res_feats'][1])), dim=1)
399
+ dfeatures = self.down_blks[i](dfeatures)
400
+ details.append(dfeatures)
401
+
402
+ if self.sam2_multi_scale_feates:
403
+ out = torch.concat((details[-1], self.mid_layer[2](vit_intern_feat['image_embed'])), dim=1)
404
+ out = self.mid_layer[3](out)
405
+ else:
406
+ out = self.mid_layer(torch.concat((details[-1], vit_aggregation_feats), dim=1))
407
+ for i in range(len(self.up_blks)):
408
+ out = self.up_blks[i](out, details[-i - 1])
409
+ alpha = torch.sigmoid(self.matting_head(out))
410
+ if return_alpha_logits:
411
+ return alpha, out
412
+ else:
413
+ return alpha
414
+
415
+
416
+
417
+ if __name__ == '__main__':
418
+
419
+ from engine.mattingtrainer import parameter_count_table
420
+
421
+ model = MattingDetailDecoder(img_feat_in = 5, vit_intern_feat_index=[0])
422
+ x = torch.randn((2, 5, 1024, 1024))
423
+ hq_features = torch.randn((2, 1, 1024, 1024))
424
+ vit_feat = [torch.randn((2, 64, 64, 1024)) for _ in range(4)]
425
+
426
+ out = model(x, hq_features, vit_feat)
427
+ print(out.shape)
428
+
429
+ print("Trainable parameters: \n" + parameter_count_table(model, trainable_only=True, max_depth=5))
modeling/meta_arch/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sam_hq_matting import SamHqMatte
modeling/meta_arch/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (208 Bytes).
 
modeling/meta_arch/__pycache__/sam_hq_matting.cpython-38.pyc ADDED
Binary file (18.2 kB).
 
modeling/meta_arch/sam_hq_matting.py ADDED
@@ -0,0 +1,671 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+ import os
6
+ import numpy as np
7
+ from PIL import Image
8
+ from copy import deepcopy
9
+ from collections import defaultdict
10
+
11
+ from detectron2.structures import ImageList
12
+ from detectron2.utils.comm import get_local_rank
13
+ from modeling.semantic_enhanced_matting.predictor import SamPredictor
14
+ from modeling.semantic_enhanced_matting.condition_conv import ConditionConv, ConditionEmbedding, ConditionAdd, BBoxEmbedInteract, BBoxInteract, BBoxInteractInOut
15
+ from modeling.semantic_enhanced_matting.modeling.image_encoder import PatchEmbed
16
+ from modeling.semantic_enhanced_matting.modeling.common import LayerNorm2d
17
+ from modeling.decoder.unet_detail_capture import MattingDetailDecoder
18
+ from modeling.semantic_enhanced_matting.feature_fusion import FeatureFusion
19
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
20
+
21
+ from modeling.semantic_enhanced_matting.modeling.mask_decoder_hq_matting import MaskDecoderHQMatting
22
+ from modeling.semantic_enhanced_matting.modeling import TwoWayTransformer
23
+
24
+ from peft import LoraConfig, get_peft_model
25
+ from peft.tuners.lora.layer import LoraLayer
26
+ from peft.tuners.tuners_utils import BaseTunerLayer
27
+
28
+ from data.rand_augment import RandAugment
29
+ import random
30
+ import kornia.filters as kf
31
+
32
+
33
+ class SamHqMatte(nn.Module):
34
+
35
+ target_length = 1024
36
+
37
+ def __init__(
38
+ self,
39
+ *,
40
+ sam_model,
41
+ hq_token_only,
42
+ hq_features_type,
43
+ matting_decoder,
44
+ criterion,
45
+ pixel_mean,
46
+ pixel_std,
47
+ multimask_output=False,
48
+ vis_period=None,
49
+ output_dir=None,
50
+ lora_rank = None,
51
+ lora_alpha = None,
52
+ lora_target_modules = ["qkv", "proj"],
53
+ lora_dropout = 0.1,
54
+ w_dora = False,
55
+ w_rslora = False,
56
+ lora_on_mask_decoder = False,
57
+ frozen_sam_hq_reg = None,
58
+ reg_margin = 0.85,
59
+ w_attention_mask = False,
60
+ alpha_reg_range = None,
61
+ alpha_reg_weight = 1.0,
62
+ coconut_pl = False,
63
+ coconut_pl_alpha = 1.0,
64
+ coconut_self_training = False,
65
+ eval_w_sam_hq_mask = False,
66
+ backbone_condition = False,
67
+ condition_wo_conv = False,
68
+ w_only_bbox_cond = False,
69
+ coconut_only_known_l1 = False,
70
+ backbone_bbox_prompt = None,
71
+ backbone_bbox_prompt_loc = [2, 3],
72
+ backbone_bbox_prompt_loss_weight = 1.0,
73
+ concat_gen_trimap = False,
74
+ multi_matting_decoder = None,
75
+ w_all_logits = False,
76
+ bbox_prompt_all_block = None,
77
+ matting_token = False,
78
+ test_w_hq_token = False,
79
+ sam_hq_token_reg = None,
80
+ feat_cross_attn_fusion = False,
81
+ trimap_loss_type = None,
82
+ reg_on_sam_logits = False,
83
+ reg_w_bce_loss = False,
84
+ complex_trimap_pred_layer = False,
85
+ matting_token_sup = None,
86
+ matting_token_sup_loss_weight = None,
87
+ sam2 = False,
88
+ ):
89
+ super(SamHqMatte, self).__init__()
90
+
91
+ self.sam_model = sam_model
92
+ self.sam_predictor = SamPredictor(self.sam_model) if not sam2 else SAM2ImagePredictor(self.sam_model) # already in eval mode and no_grad
93
+ self.hq_token_only = hq_token_only
94
+ self.multimask_output = multimask_output
95
+ self.hq_features_type = hq_features_type
96
+
97
+ self.matting_decoder = matting_decoder
98
+
99
+ self.criterion = criterion
100
+
101
+ self.register_buffer(
102
+ "pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False
103
+ )
104
+ self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)
105
+ assert (
106
+ self.pixel_mean.shape == self.pixel_std.shape
107
+ ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"
108
+
109
+ self.vis_period = vis_period
110
+ if output_dir is not None and output_dir != '?':
111
+ self.output_dir = os.path.join(output_dir, 'vis_results')
112
+ os.makedirs(self.output_dir, exist_ok=True)
113
+ self.train_iter_index = 0
114
+
115
+ self.lora_rank = lora_rank
116
+ self.lora_alpha = lora_alpha
117
+ self.lora_target_modules = lora_target_modules
118
+ self.lora_dropout = lora_dropout
119
+ self.w_dora = w_dora
120
+ self.w_rslora = w_rslora
121
+ self.lora_on_mask_decoder = lora_on_mask_decoder
122
+ self.frozen_sam_hq_reg = frozen_sam_hq_reg
123
+ self.reg_margin = reg_margin
124
+ self.w_attention_mask = w_attention_mask
125
+ self.alpha_reg_range = alpha_reg_range
126
+ self.alpha_reg_weight = alpha_reg_weight
127
+ self.coconut_pl = coconut_pl
128
+ self.coconut_pl_alpha = coconut_pl_alpha
129
+ self.coconut_self_training = coconut_self_training
130
+ self.eval_w_sam_hq_mask = eval_w_sam_hq_mask
131
+ self.backbone_condition = backbone_condition
132
+ self.condition_wo_conv = condition_wo_conv
133
+ self.w_only_bbox_cond = w_only_bbox_cond
134
+ self.coconut_only_known_l1 = coconut_only_known_l1
135
+ self.backbone_bbox_prompt = backbone_bbox_prompt
136
+ self.backbone_bbox_prompt_loc = backbone_bbox_prompt_loc
137
+ self.backbone_bbox_prompt_loss_weight = backbone_bbox_prompt_loss_weight
138
+ self.concat_gen_trimap = concat_gen_trimap
139
+ self.multi_matting_decoder = multi_matting_decoder
140
+ self.w_all_logits = w_all_logits
141
+ self.bbox_prompt_all_block = bbox_prompt_all_block
142
+ self.matting_token = matting_token
143
+ self.test_w_hq_token = test_w_hq_token
144
+ self.sam_hq_token_reg = sam_hq_token_reg
145
+ self.feat_cross_attn_fusion = feat_cross_attn_fusion
146
+ self.trimap_loss_type = trimap_loss_type
147
+ self.reg_on_sam_logits = reg_on_sam_logits
148
+ self.reg_w_bce_loss = reg_w_bce_loss
149
+ self.complex_trimap_pred_layer = complex_trimap_pred_layer
150
+ self.matting_token_sup = matting_token_sup
151
+ self.sam2 = sam2
152
+ assert self.matting_token_sup in {'alpha', 'trimap', None}
153
+ self.matting_token_sup_loss_weight = matting_token_sup_loss_weight
154
+ if self.matting_token_sup is not None:
155
+ assert self.backbone_bbox_prompt in {'bbox', None}
156
+ if self.frozen_sam_hq_reg is not None:
157
+ assert self.lora_rank is not None
158
+ if self.w_attention_mask:
159
+ self.attention_head = deepcopy(self.matting_decoder)
160
+ if self.coconut_self_training:
161
+ self.rand_aug = RandAugment(3,6)
162
+ self.warm_iter_coconut_self_training = 5000
163
+ if self.backbone_condition:
164
+ assert self.lora_rank is not None
165
+ if self.backbone_bbox_prompt is not None:
166
+ assert self.lora_rank is not None
167
+ if self.w_all_logits:
168
+ self.sam_predictor.model.mask_decoder.w_all_logits = True
169
+ if self.bbox_prompt_all_block:
170
+ assert self.lora_rank is not None
171
+ if self.matting_token and not self.sam2:
172
+ self.sam_predictor.model.mask_decoder.hq_token_only = self.hq_token_only
173
+
174
+ @property
175
+ def device(self):
176
+ return self.pixel_mean.device
177
+
178
+ def init_lora(self, model=None):
179
+ if model is not None and self.lora_rank >= 1:
180
+ if self.lora_on_mask_decoder:
181
+ self.lora_target_modules += ["q_proj", "k_proj", "v_proj", "out_proj"]
182
+ modules_to_save = None
183
+ else:
184
+ modules_to_save = ['matting_decoder']
185
+
186
+ lora_config = LoraConfig(
187
+ r=self.lora_rank,
188
+ lora_alpha=self.lora_alpha,
189
+ use_rslora=self.w_rslora,
190
+ use_dora=self.w_dora,
191
+ init_lora_weights="gaussian",
192
+ target_modules=self.lora_target_modules,
193
+ lora_dropout=self.lora_dropout,
194
+ modules_to_save=modules_to_save
195
+ )
196
+ model = get_peft_model(model, lora_config)
197
+ if self.lora_on_mask_decoder:
198
+ for n, p in model.matting_decoder.named_parameters():
199
+ if n.split('modules_to_save.default.')[-1] in model.matting_decoder.trainable_params_str:
200
+ p.requires_grad = True
201
+ else:
202
+ for n, p in model.matting_decoder.named_parameters():
203
+ if n.split('modules_to_save.default.')[-1] in model.matting_decoder.frozen_params_str:
204
+ p.requires_grad = False
205
+ return model
206
+ elif self.lora_rank >= 1:
207
+ lora_config = LoraConfig(
208
+ r=self.lora_rank,
209
+ lora_alpha=self.lora_alpha,
210
+ use_rslora=self.w_rslora,
211
+ use_dora=self.w_dora,
212
+ init_lora_weights="gaussian",
213
+ target_modules=self.lora_target_modules,
214
+ lora_dropout=self.lora_dropout,
215
+ )
216
+ self.sam_predictor.model.image_encoder = get_peft_model(self.sam_predictor.model.image_encoder, lora_config)
217
+
218
+ if self.sam2:
219
+ for n, p in self.sam_predictor.model.image_encoder.named_parameters():
220
+ if 'bbox_mask' in n:
221
+ p.requires_grad = True
222
+
223
+ if self.backbone_condition:
224
+ if self.w_only_bbox_cond:
225
+ self.condition_embedding = ConditionEmbedding(condition_num = 4, pos_embedding_dim = 160)
226
+ else:
227
+ self.condition_embedding = ConditionEmbedding(condition_num = 5, pos_embedding_dim = 128)
228
+
229
+ if self.condition_wo_conv:
230
+ self.condition_conv = nn.ModuleList([ConditionAdd() for _ in range(4)])
231
+ else:
232
+ self.condition_conv = nn.ModuleList([ConditionConv(
233
+ in_channels = self.sam_predictor.model.image_encoder.embed_dim,
234
+ out_channels = self.sam_predictor.model.image_encoder.embed_dim,
235
+ bottleneck_channels = 512
236
+ ) for _ in range(4)])
237
+
238
+ if self.backbone_bbox_prompt is not None and not self.sam2:
239
+ self.condition_layer = nn.ModuleDict()
240
+ self.condition_layer['patch_embed'] = PatchEmbed(
241
+ kernel_size=(self.sam_predictor.model.image_encoder.patch_size, self.sam_predictor.model.image_encoder.patch_size),
242
+ stride=(self.sam_predictor.model.image_encoder.patch_size, self.sam_predictor.model.image_encoder.patch_size),
243
+ in_chans=4,
244
+ embed_dim=self.sam_predictor.model.image_encoder.embed_dim,
245
+ )
246
+ if self.multi_matting_decoder is None:
247
+ if self.backbone_bbox_prompt in {'trimap', 'alpha_trimap'}:
248
+ transformer_dim = self.sam_predictor.model.image_encoder.embed_dim
249
+ for i in self.backbone_bbox_prompt_loc:
250
+ if self.complex_trimap_pred_layer:
251
+ self.condition_layer['{}_pred_layer'.format(i)] = nn.Sequential(
252
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 2, kernel_size=2, stride=2),
253
+ LayerNorm2d(transformer_dim // 2), # 512
254
+ nn.GELU(),
255
+ nn.Conv2d(transformer_dim // 2, transformer_dim // 4, kernel_size=3, stride=1, padding=1),
256
+ LayerNorm2d(transformer_dim // 4), # 256
257
+ nn.GELU(),
258
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
259
+ LayerNorm2d(transformer_dim // 8), # 128
260
+ nn.GELU(),
261
+ nn.Conv2d(transformer_dim // 8, transformer_dim // 16, kernel_size=3, stride=1, padding=1),
262
+ LayerNorm2d(transformer_dim // 16), # 64
263
+ nn.GELU(),
264
+ nn.Conv2d(transformer_dim // 16, 3, kernel_size=3, stride=1, padding=1),
265
+ )
266
+ else:
267
+ self.condition_layer['{}_pred_layer'.format(i)] = nn.Sequential(
268
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
269
+ LayerNorm2d(transformer_dim // 4),
270
+ nn.GELU(),
271
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
272
+ nn.GELU(),
273
+ nn.Conv2d(transformer_dim // 8, 3, kernel_size=1, stride=1),
274
+ )
275
+ elif self.backbone_bbox_prompt == 'alpha':
276
+ transformer_dim = self.sam_predictor.model.image_encoder.embed_dim
277
+ for i in self.backbone_bbox_prompt_loc:
278
+ if self.complex_trimap_pred_layer:
279
+ self.condition_layer['{}_pred_layer'.format(i)] = nn.Sequential(
280
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 2, kernel_size=2, stride=2),
281
+ LayerNorm2d(transformer_dim // 2), # 512
282
+ nn.GELU(),
283
+ nn.Conv2d(transformer_dim // 2, transformer_dim // 4, kernel_size=3, stride=1, padding=1),
284
+ LayerNorm2d(transformer_dim // 4), # 256
285
+ nn.GELU(),
286
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
287
+ LayerNorm2d(transformer_dim // 8), # 128
288
+ nn.GELU(),
289
+ nn.Conv2d(transformer_dim // 8, transformer_dim // 16, kernel_size=3, stride=1, padding=1),
290
+ LayerNorm2d(transformer_dim // 16), # 64
291
+ nn.GELU(),
292
+ nn.Conv2d(transformer_dim // 16, 1, kernel_size=3, stride=1, padding=1),
293
+ nn.Sigmoid()
294
+ )
295
+ else:
296
+ self.condition_layer['{}_pred_layer'.format(i)] = nn.Sequential(
297
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
298
+ LayerNorm2d(transformer_dim // 4),
299
+ nn.GELU(),
300
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
301
+ nn.GELU(),
302
+ nn.Conv2d(transformer_dim // 8, 1, kernel_size=1, stride=1),
303
+ nn.Sigmoid()
304
+ )
305
+ if self.bbox_prompt_all_block is not None:
306
+ if self.bbox_prompt_all_block == 'reuse_cross-self-attn':
307
+ self.condition_layer['prompt_layer'] = BBoxInteract(
308
+ position_point_embedding = deepcopy(self.sam_predictor.model.prompt_encoder.pe_layer),
309
+ point_weight = deepcopy(self.sam_predictor.model.prompt_encoder.point_embeddings)
310
+ )
311
+ elif self.bbox_prompt_all_block == 'in-out-bbox_cross-self-attn':
312
+ self.condition_layer['prompt_layer'] = BBoxInteractInOut(downsample_rate = 2)
313
+ else:
314
+ embed_type, interact_type = self.bbox_prompt_all_block.split('_')
315
+ self.condition_layer['prompt_layer'] = BBoxEmbedInteract(embed_type, interact_type)
316
+
317
+ if self.feat_cross_attn_fusion:
318
+ self.condition_layer['feature_fusion'] = FeatureFusion(in_channels=self.sam_predictor.model.image_encoder.embed_dim, attn_compression_ratio=8)
319
+
320
+ def condition_bbox_and_instance_num(self):
321
+ self.sam_predictor.model.image_encoder.conv_necks = None
322
+
323
+ def forward_samhq_and_matting_decoder(self, images, bbox, condition_proj=None, return_hq_token=False):
324
+ # get features from SAM image encoder
325
+ if self.sam2:
326
+ interm_features, sam2_logits, matting_logits, pred_trimap = self.forward_samhq(images, bbox, condition_proj)
327
+ sam2_logits = F.interpolate(sam2_logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
328
+ matting_logits = F.interpolate(matting_logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
329
+ sam_hq_matting_token = {
330
+ 'masks_hq': sam2_logits,
331
+ 'masks_matting': matting_logits
332
+ }
333
+ hq_features = matting_logits
334
+ low_res_masks = matting_logits
335
+ else:
336
+ if self.matting_token:
337
+ features, image_pe, sparse_embeddings, dense_embeddings, interm_features, sam_hq_matting_token, pred_trimap = self.forward_samhq(images, bbox, condition_proj)
338
+ if return_hq_token:
339
+ return sam_hq_matting_token['masks_hq']
340
+ else:
341
+ if not self.training and self.test_w_hq_token:
342
+ low_res_masks, hq_features = sam_hq_matting_token['masks_hq'], sam_hq_matting_token['masks_hq']
343
+ else:
344
+ low_res_masks, hq_features = sam_hq_matting_token['masks_matting'], sam_hq_matting_token['masks_matting']
345
+ else:
346
+ features, image_pe, sparse_embeddings, dense_embeddings, interm_features, hq_features, sam_logits, low_res_masks, pred_trimap = self.forward_samhq(images, bbox, condition_proj)
347
+ if return_hq_token:
348
+ return hq_features
349
+ sam_hq_matting_token = {'masks_hq': hq_features, 'masks_sam': sam_logits}
350
+
351
+ # get alpha from our proposed matting_decoder
352
+ if isinstance(self.matting_decoder, MattingDetailDecoder):
353
+ pred_alpha = self.matting_decoder(
354
+ images = images,
355
+ hq_features = hq_features,
356
+ vit_intern_feat = interm_features,
357
+ return_alpha_logits = (self.alpha_reg_range is not None),
358
+ pred_trimap = pred_trimap
359
+ )
360
+ else:
361
+ pred_alpha = self.matting_decoder(
362
+ image_embeddings = features, # [B, 256, 64, 64]
363
+ image_pe = image_pe,
364
+ sparse_prompt_embeddings = sparse_embeddings,
365
+ dense_prompt_embeddings = dense_embeddings,
366
+ multimask_output = False,
367
+ interm_embeddings = interm_features, # [B, 256, 64, 64]
368
+ hq_features = hq_features,
369
+ images = images,
370
+ return_alpha_logits = (self.alpha_reg_range is not None),
371
+ pred_trimap = pred_trimap
372
+ )
373
+ return low_res_masks, pred_alpha, pred_trimap, sam_hq_matting_token
374
+
375
+ def forward(self, batched_inputs): # image: [1, 3, 643, 960]: 0.0~1.0, trimap: [1, 1, 643, 960]: 0.0~1.0
376
+
377
+ inputs = self.preprocess_inputs(batched_inputs)
378
+ images, bbox, gt_alpha, trimap, condition = inputs['images'], inputs['bbox'], inputs['alpha'], inputs['trimap'], inputs['condition']
379
+
380
+ if self.backbone_condition:
381
+ condition_proj = self.condition_embedding(condition)
382
+ elif self.backbone_bbox_prompt is not None or self.bbox_prompt_all_block is not None:
383
+ condition_proj = bbox
384
+ else:
385
+ condition_proj = None
386
+
387
+ low_res_masks, pred_alpha, pred_trimap, sam_hq_matting_token = self.forward_samhq_and_matting_decoder(images, bbox, condition_proj)
388
+
389
+ assert not self.training
390
+ if self.eval_w_sam_hq_mask:
391
+ self.sam_predictor.model.image_encoder.disable_adapter_layers()
392
+ with torch.no_grad():
393
+ ori_features, ori_interm_features = self.sam_predictor.model.image_encoder(images)
394
+ samhq_low_res_masks = self.forward_samhq_others(images, bbox, ori_features, ori_interm_features)[-1]
395
+ samhq_low_res_masks = F.interpolate(samhq_low_res_masks, size=(images.shape[-2], images.shape[-1]), mode='bilinear', align_corners=False)
396
+ self.sam_predictor.model.image_encoder.enable_adapter_layers()
397
+
398
+ return pred_alpha, samhq_low_res_masks
399
+ else:
400
+ return pred_alpha
401
+
402
+ def forward_samhq_image_encoder(self, images, condition_proj=None):
403
+ if self.sam2:
404
+ backbone_out = self.sam_predictor.model.forward_image([images, condition_proj])
405
+ _, vision_feats, _, _ = self.sam_predictor.model._prepare_backbone_features(backbone_out)
406
+ # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
407
+ if self.sam_predictor.model.directly_add_no_mem_embed:
408
+ vision_feats[-1] = vision_feats[-1] + self.sam_predictor.model.no_mem_embed
409
+ feats = [
410
+ feat.permute(1, 2, 0).view(feat.shape[1], -1, *feat_size)
411
+ for feat, feat_size in zip(vision_feats[::-1], self.sam_predictor._bb_feat_sizes[::-1])
412
+ ][::-1]
413
+ return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}, None, None
414
+ else:
415
+ if self.backbone_condition:
416
+ condition_layer = self.condition_conv
417
+ elif self.backbone_bbox_prompt:
418
+ condition_layer = self.condition_layer
419
+ else:
420
+ condition_layer = None
421
+ # [B, 3, 1024, 1024]: -2. ~ 2. --> [B, 256, 64, 64], 4 x [B, 64, 64, 768]
422
+ features, interm_features, pred_trimap = self.sam_predictor.model.image_encoder(images, condition_proj, condition_layer)
423
+ return features, interm_features, pred_trimap
424
+
425
+ # @torch.no_grad()
426
+ def forward_samhq_others(self, images, bbox, features, interm_features):
427
+ if self.sam2:
428
+ sam2_logits, matting_logits = self.sam_predictor.predict_batch_boxes_and_features(bbox, features)
429
+ return features, sam2_logits, matting_logits
430
+
431
+ image_pe = self.sam_predictor.model.prompt_encoder.get_dense_pe()
432
+
433
+ cat_sparse_embeddings = []
434
+ cat_dense_prompt_embeddings = []
435
+ cat_hq_features = []
436
+ cat_sam_logits = []
437
+ cat_low_res_masks = []
438
+ cat_sam_hq_matting_token = defaultdict(list)
439
+
440
+ for idx in range(images.shape[0]):
441
+ # get hq_features from SAM_HQ mask decoder
442
+
443
+ # Embed prompts
444
+ sparse_embeddings, dense_embeddings = self.sam_predictor.model.prompt_encoder(
445
+ points=None,
446
+ # boxes=bbox[idx: idx + 1],
447
+ boxes=bbox[idx], # [N, 4]
448
+ masks=None,
449
+ ) # [B, 2, 256], [B, 256, 64, 64]
450
+
451
+ # Predict masks
452
+ if isinstance(self.sam_predictor.model.mask_decoder, MaskDecoderHQMatting):
453
+ sam_hq_matting_token = self.sam_predictor.model.mask_decoder(
454
+ image_embeddings = features[idx: idx + 1],
455
+ image_pe = image_pe,
456
+ sparse_prompt_embeddings = sparse_embeddings,
457
+ dense_prompt_embeddings = dense_embeddings,
458
+ multimask_output = self.multimask_output,
459
+ interm_embeddings = [interm_feature[idx: idx + 1] for interm_feature in interm_features],
460
+ )
461
+ for key in sam_hq_matting_token.keys():
462
+ cat_sam_hq_matting_token[key].append(sam_hq_matting_token[key])
463
+ else:
464
+ low_res_masks, masks_sam, hq_features = self.sam_predictor.model.mask_decoder(
465
+ image_embeddings = features[idx: idx + 1],
466
+ image_pe = image_pe,
467
+ sparse_prompt_embeddings = sparse_embeddings,
468
+ dense_prompt_embeddings = dense_embeddings,
469
+ multimask_output = self.multimask_output,
470
+ hq_token_only = self.hq_token_only,
471
+ interm_embeddings = [interm_feature[idx: idx + 1] for interm_feature in interm_features],
472
+ return_hq_features_type = self.hq_features_type
473
+ )
474
+ cat_hq_features.append(hq_features)
475
+ cat_sam_logits.append(masks_sam)
476
+ cat_low_res_masks.append(low_res_masks)
477
+
478
+ cat_sparse_embeddings.append(sparse_embeddings)
479
+ cat_dense_prompt_embeddings.append(dense_embeddings)
480
+
481
+ sparse_embeddings = torch.stack(cat_sparse_embeddings, dim=0) # [B, 1, 2, 256]
482
+ dense_embeddings = torch.stack(cat_dense_prompt_embeddings, dim=0) # [B, 1, 256, 64, 64]
483
+
484
+ if self.matting_token:
485
+ for key in cat_sam_hq_matting_token.keys():
486
+ cat_sam_hq_matting_token[key] = torch.cat(cat_sam_hq_matting_token[key], dim=0)
487
+ cat_sam_hq_matting_token[key] = F.interpolate(cat_sam_hq_matting_token[key], size=images.shape[-2:], mode='bilinear', align_corners=False)
488
+ sam_hq_matting_token = cat_sam_hq_matting_token
489
+ return features, image_pe, sparse_embeddings, dense_embeddings, interm_features, sam_hq_matting_token
490
+ else:
491
+ hq_features = torch.cat(cat_hq_features, dim=0) # [B, 1, 256, 256]
492
+ low_res_masks = torch.cat(cat_low_res_masks, dim=0) # [B, 1, 256, 256]
493
+ hq_features = F.interpolate(hq_features, size=images.shape[-2:], mode='bilinear', align_corners=False) # [B, 1, 256, 256] --> [B, 1, 1024, 1024]
494
+ sam_logits = torch.cat(cat_sam_logits, dim=0)
495
+ sam_logits = F.interpolate(sam_logits, size=images.shape[-2:], mode='bilinear', align_corners=False) # [B, 1, 256, 256] --> [B, 1, 1024, 1024]
496
+ return features, image_pe, sparse_embeddings, dense_embeddings, interm_features, hq_features, sam_logits, low_res_masks
497
+
498
+ def forward_samhq(self, images, bbox, condition_proj=None):
499
+ if self.lora_rank is None:
500
+ with torch.no_grad():
501
+ features, interm_features, pred_trimap = self.forward_samhq_image_encoder(images, condition_proj)
502
+ else:
503
+ features, interm_features, pred_trimap = self.forward_samhq_image_encoder(images, condition_proj)
504
+
505
+ return self.forward_samhq_others(images, bbox, features, interm_features) + (pred_trimap, )
506
+
507
+ def get_frozen_sam_logits(self, images, bbox, mask_type='hq'):
508
+
509
+ if self.sam2:
510
+ features, _, _ = self.forward_samhq_image_encoder(images)
511
+ sam2_logits = self.sam_predictor.predict_batch_boxes_and_features(bbox, features, wo_matting_token=True)
512
+ sam2_logits = F.interpolate(sam2_logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
513
+ return sam2_logits
514
+
515
+ assert mask_type in {'hq', 'sam'}
516
+ features, interm_features, _ = self.forward_samhq_image_encoder(images)
517
+ image_pe = self.sam_predictor.model.prompt_encoder.get_dense_pe()
518
+
519
+ cat_logits = []
520
+ for idx in range(images.shape[0]):
521
+ sparse_embeddings, dense_embeddings = self.sam_predictor.model.prompt_encoder(points=None, boxes=bbox[idx], masks=None)
522
+
523
+ low_res_masks, masks_sam, hq_features = self.sam_predictor.model.frozen_mask_decoder(
524
+ image_embeddings = features[idx: idx + 1],
525
+ image_pe = image_pe,
526
+ sparse_prompt_embeddings = sparse_embeddings,
527
+ dense_prompt_embeddings = dense_embeddings,
528
+ multimask_output = self.multimask_output,
529
+ hq_token_only = self.hq_token_only,
530
+ interm_embeddings = [interm_feature[idx: idx + 1] for interm_feature in interm_features],
531
+ return_hq_features_type = self.hq_features_type
532
+ )
533
+ if mask_type == 'hq':
534
+ cat_logits.append(hq_features)
535
+ else:
536
+ cat_logits.append(masks_sam)
537
+
538
+ logits = torch.cat(cat_logits, dim=0) # [B, 1, 256, 256]
539
+ logits = F.interpolate(logits, size=images.shape[-2:], mode='bilinear', align_corners=False) # [B, 1, 256, 256] --> [B, 1, 1024, 1024]
540
+ return logits
+ 
+ def vis_training_results(self, **kwargs):
+     # expected kwargs: images, bbox, trimap, low_res_masks, pred_alpha, alpha
+     self.train_iter_index += 1
+     if self.train_iter_index % self.vis_period == 0:
+         batch_save_results = []
+         save_path = os.path.join(self.output_dir, '{:06d}_rank{}.jpg'.format(self.train_iter_index, get_local_rank()))
+ 
+         # typical (name, shape, min, max) of the inputs:
+         # [('images', (4, 3, 1024, 1024), -2.117904, 2.64), ('bbox', (4, 1, 4), 0.0, 1023.0), ('trimap', (4, 1, 1024, 1024), 0.0, 1.0), ('low_res_masks', (4, 1, 256, 256), -20.38, 10.15), ('pred_alpha', (4, 1, 1024, 1024), 0.1547, 0.791), ('alpha', (4, 1, 1024, 1024), 0.0, 1.0)]
+         for key in kwargs.keys():
+             if key == 'bbox':
+                 continue
+             # convert every tensor to a [B, H, W, 3] uint8 array in 0~255
+             if key == 'images':
+                 kwargs[key] = kwargs[key] * self.pixel_std + self.pixel_mean
+                 kwargs[key] = kwargs[key].permute(0, 2, 3, 1) * 255.0
+                 for i in range(kwargs['images'].shape[0]):
+                     l, u, r, d = int(kwargs['bbox'][i, 0, 0].item()), int(kwargs['bbox'][i, 0, 1].item()), int(kwargs['bbox'][i, 0, 2].item()), int(kwargs['bbox'][i, 0, 3].item())
+                     red_line = torch.tensor([[255., 0., 0.]], device=kwargs[key].device, dtype=kwargs[key].dtype)
+                     kwargs[key][i, u: d, l, :] = red_line
+                     kwargs[key][i, u: d, r, :] = red_line
+                     kwargs[key][i, u, l: r, :] = red_line
+                     kwargs[key][i, d, l: r, :] = red_line
+             elif key in {'low_res_masks', 'frozen_hq_token'}:
+                 if torch.max(kwargs[key]) <= 1:  # coconut ori alpha
+                     kwargs[key] = kwargs[key].permute(0, 2, 3, 1).repeat(1, 1, 1, 3) * 255.0
+                 else:
+                     kwargs[key] = F.interpolate(kwargs[key], size=(kwargs['images'].shape[-3], kwargs['images'].shape[-2]), mode='bilinear', align_corners=False)
+                     kwargs[key] = (kwargs[key] > self.sam_predictor.model.mask_threshold).float().permute(0, 2, 3, 1).repeat(1, 1, 1, 3) * 255.0
+             else:
+                 kwargs[key] = kwargs[key].permute(0, 2, 3, 1).repeat(1, 1, 1, 3) * 255.0
+ 
+             kwargs[key] = np.uint8(kwargs[key].detach().cpu().numpy())
+ 
+         for i in range(kwargs['images'].shape[0]):
+             save_results = []
+             for key in kwargs.keys():
+                 if key != 'bbox':
+                     save_results.append(kwargs[key][i])
+             batch_save_results.append(np.concatenate(save_results, axis=1))
+ 
+         Image.fromarray(np.concatenate(batch_save_results, axis=0)).save(save_path)
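The visualization hook above paints each prompt box directly into the de-normalized image tensor by setting its four border lines to red before all tensors are tiled into one JPEG. The same border-drawing idea on a single [H, W, 3] tensor (values and box corners are illustrative):

import torch

image = torch.zeros(1024, 1024, 3)      # H x W x RGB in 0~255
l, u, r, d = 100, 200, 400, 600         # left, up, right, down corners of the box
red = torch.tensor([255.0, 0.0, 0.0])
image[u:d, l, :] = red                  # left edge
image[u:d, r, :] = red                  # right edge
image[u, l:r, :] = red                  # top edge
image[d, l:r, :] = red                  # bottom edge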
+ 
+ def preprocess_inputs(self, batched_inputs):
+     """
+     Normalize, pad and batch the input images.
+     """
+     output = dict()
+ 
+     if "alpha" in batched_inputs:
+         alpha = batched_inputs["alpha"].to(self.device)
+     else:
+         alpha = None
+ 
+     bbox = batched_inputs["bbox"].to(self.device)
+ 
+     if self.training and self.coconut_self_training and sum([i == 'COCONut' for i in batched_inputs['dataset_name']]) >= 1:
+         output['coconut_ori_img'] = []
+         output['coconut_trimap'] = []
+         output['coconut_bbox'] = []
+         output['coconut_idx'] = []
+         for i, dataset_name in enumerate(batched_inputs['dataset_name']):
+             if dataset_name == 'COCONut':
+                 # generate coconut_aug_img
+                 img_np = np.uint8(batched_inputs["image"][i].permute(1, 2, 0).cpu().numpy() * 255.)
+                 strong_aug_img = self.rand_aug(Image.fromarray(img_np), cutout = False)
+                 strong_aug_img_tensor = torch.from_numpy(np.array(strong_aug_img)).to(self.device).permute(2, 0, 1)[None] / 255.
+                 blur_kernel_sigma = 1.0 + random.random()  # random sigma in [1.0, 2.0)
+                 blur_filter = kf.GaussianBlur2d((101, 101), (blur_kernel_sigma, blur_kernel_sigma))
+                 blur_strong_aug_img_tensor = blur_filter(strong_aug_img_tensor)[0]
+ 
+                 output['coconut_ori_img'].append(batched_inputs["image"][i])
+                 batched_inputs["image"][i] = blur_strong_aug_img_tensor
+ 
+                 # generate coconut_trimap
+                 coconut_mask = (alpha[i] != 0).float()
+                 mask_area = torch.sum(coconut_mask)
+                 kernel_size = max(self.matting_decoder.min_kernel_size, int((mask_area ** 0.5) / 7))  # self.matting_decoder.kernel_div
+                 kernel_size = min(kernel_size, self.matting_decoder.gen_trimap.max_kernal - 1)
+                 output['coconut_trimap'].append(self.matting_decoder.gen_trimap(coconut_mask[0], kernel_size=kernel_size)[None])
+ 
+                 output['coconut_bbox'].append(bbox[i])
+                 output['coconut_idx'].append(i)
+ 
+         output['coconut_ori_img'] = torch.stack(output['coconut_ori_img']).to(self.device)
+         output['coconut_ori_img'] = (output['coconut_ori_img'] - self.pixel_mean) / self.pixel_std
+         output['coconut_trimap'] = torch.stack(output['coconut_trimap']).to(self.device)
+         output['coconut_bbox'] = torch.stack(output['coconut_bbox']).to(self.device)
+ 
+     images = batched_inputs["image"].to(self.device)
+     images = (images - self.pixel_mean) / self.pixel_std
+     assert images.shape[-2] == images.shape[-1] == 1024
+ 
+     if 'trimap' in batched_inputs.keys():
+         trimap = batched_inputs["trimap"].to(self.device)
+         assert len(torch.unique(trimap)) <= 3
+     else:
+         trimap = None
+ 
+     output['images'] = images
+     output['bbox'] = bbox
+     output['alpha'] = alpha
+     output['trimap'] = trimap
+ 
+     if 'hr_images' in batched_inputs.keys():
+         hr_images = batched_inputs["hr_images"].to(self.device)
+         hr_images = (hr_images - self.pixel_mean) / self.pixel_std
+         _, _, H, W = hr_images.shape
+         if hr_images.shape[-1] % 16 != 0 or hr_images.shape[-2] % 16 != 0:
+             new_H = (16 - hr_images.shape[-2] % 16) + H if hr_images.shape[-2] % 16 != 0 else H
+             new_W = (16 - hr_images.shape[-1] % 16) + W if hr_images.shape[-1] % 16 != 0 else W
+             new_hr_images = torch.zeros((hr_images.shape[0], hr_images.shape[1], new_H, new_W)).to(self.device)
+             new_hr_images[:, :, :H, :W] = hr_images[:, :, :, :]
+             del hr_images
+             hr_images = new_hr_images
+         output['hr_images'] = hr_images
+         output['hr_images_ori_h_w'] = (H, W)
+ 
+     if 'dataset_name' in batched_inputs.keys():
+         output['dataset_name'] = batched_inputs["dataset_name"]
+ 
+     if self.backbone_condition:
+         if self.w_only_bbox_cond:
+             output['condition'] = output['bbox'][:, 0, :]
+         else:
+             multi_fg_float = batched_inputs["multi_fg"].to(bbox.device).float()[:, None] * 512
+             output['condition'] = torch.concat((output['bbox'][:, 0, :], multi_fg_float), dim=-1)
+     else:
+         output['condition'] = None
+ 
+     return output
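Among other things, preprocess_inputs zero-pads the optional high-resolution images on the bottom/right so both sides become multiples of 16, and keeps the original (H, W) so the prediction can be cropped back. That padding step in isolation (shapes chosen for illustration):

import torch

hr_images = torch.randn(1, 3, 1030, 1501)
_, _, H, W = hr_images.shape
new_H = H + (16 - H % 16) % 16          # 1040
new_W = W + (16 - W % 16) % 16          # 1504
padded = torch.zeros(hr_images.shape[0], hr_images.shape[1], new_H, new_W)
padded[:, :, :H, :W] = hr_images        # original content stays in the top-left corner
# (H, W) is stored alongside the tensor so the output can be cropped back later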
modeling/semantic_enhanced_matting/__init__.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ 
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ 
+ from .build_sam import (
+     build_sam,
+     build_sam_vit_h,
+     build_sam_vit_l,
+     build_sam_vit_b,
+     sam_model_registry,
+ )
+ from .build_sam_baseline import sam_model_registry_baseline
+ from .predictor import SamPredictor
+ from .automatic_mask_generator import SamAutomaticMaskGenerator
+ from .mask_decoder_matting import MaskDecoderMatting
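The package __init__ re-exports the SAM builders and the predictor, so downstream code can build a backbone from the registry and wrap it in a predictor. A hedged usage sketch, assuming this fork keeps the upstream segment-anything calling convention (the model key and checkpoint path are placeholders):

from modeling.semantic_enhanced_matting import sam_model_registry, SamPredictor

sam = sam_model_registry['vit_b'](checkpoint='path/to/sam_vit_b.pth')  # placeholder checkpoint
predictor = SamPredictor(sam)
# predictor.set_image(...) and predictor.predict(box=...) then follow the usual SAM prompting API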
modeling/semantic_enhanced_matting/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (567 Bytes).
 
modeling/semantic_enhanced_matting/__pycache__/automatic_mask_generator.cpython-38.pyc ADDED
Binary file (11.5 kB).