File size: 6,066 Bytes
150d962
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
diff --git a/IS-Net/Inference.py b/IS-Net/Inference.py
index 0b2907d..ca8484b 100644
--- a/IS-Net/Inference.py
+++ b/IS-Net/Inference.py
@@ -40,7 +40,7 @@ if __name__ == "__main__":
             im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
             im_tensor = F.upsample(torch.unsqueeze(im_tensor,0), input_size, mode="bilinear").type(torch.uint8)
             image = torch.divide(im_tensor,255.0)
-            image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
+            #image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
 
             if torch.cuda.is_available():
                 image=image.cuda()
diff --git a/IS-Net/train_valid_inference_main.py b/IS-Net/train_valid_inference_main.py
index 375bb26..ad9043c 100644
--- a/IS-Net/train_valid_inference_main.py
+++ b/IS-Net/train_valid_inference_main.py
@@ -536,10 +536,10 @@ def main(train_datasets,
                                                              cache_size = hypar["cache_size"],
                                                              cache_boost = hypar["cache_boost_train"],
                                                              my_transforms = [
-                                                                             GOSRandomHFlip(), ## this line can be uncommented for horizontal flip augmetation
+                                                                             #GOSRandomHFlip(), ## this line can be uncommented for horizontal flip augmetation
                                                                              # GOSResize(hypar["input_size"]),
                                                                              # GOSRandomCrop(hypar["crop_size"]), ## this line can be uncommented for randomcrop augmentation
-                                                                              GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
+                                                                              #GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
                                                                               ],
                                                              batch_size = hypar["batch_size_train"],
                                                              shuffle = True)
@@ -547,7 +547,7 @@ def main(train_datasets,
                                                              cache_size = hypar["cache_size"],
                                                              cache_boost = hypar["cache_boost_train"],
                                                              my_transforms = [
-                                                                              GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
+                                                                              #GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
                                                                               ],
                                                              batch_size = hypar["batch_size_valid"],
                                                              shuffle = False)
@@ -561,7 +561,7 @@ def main(train_datasets,
                                                           cache_size = hypar["cache_size"],
                                                           cache_boost = hypar["cache_boost_valid"],
                                                           my_transforms = [
-                                                                           GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
+                                                                           #GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
                                                                            # GOSResize(hypar["input_size"])
                                                                            ],
                                                           batch_size=hypar["batch_size_valid"],
@@ -618,19 +618,19 @@ if __name__ == "__main__":
     train_datasets, valid_datasets = [], []
     dataset_1, dataset_1 = {}, {}
 
-    dataset_tr = {"name": "DIS5K-TR",
-                 "im_dir": "../DIS5K/DIS-TR/im",
-                 "gt_dir": "../DIS5K/DIS-TR/gt",
-                 "im_ext": ".jpg",
+    dataset_tr = {"name": "training",
+                 "im_dir": "../training/im",
+                 "gt_dir": "../training/gt",
+                 "im_ext": ".png",
                  "gt_ext": ".png",
-                 "cache_dir":"../DIS5K-Cache/DIS-TR"}
+                 "cache_dir":"../cache/training"}
 
-    dataset_vd = {"name": "DIS5K-VD",
-                 "im_dir": "../DIS5K/DIS-VD/im",
-                 "gt_dir": "../DIS5K/DIS-VD/gt",
-                 "im_ext": ".jpg",
+    dataset_vd = {"name": "validation",
+                 "im_dir": "../validation/im",
+                 "gt_dir": "../validation/gt",
+                 "im_ext": ".png",
                  "gt_ext": ".png",
-                 "cache_dir":"../DIS5K-Cache/DIS-VD"}
+                 "cache_dir":"../cache/validation"}
 
     dataset_te1 = {"name": "DIS5K-TE1",
                  "im_dir": "../DIS5K/DIS-TE1/im",
@@ -685,7 +685,7 @@ if __name__ == "__main__":
     if hypar["mode"] == "train":
         hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
         hypar["model_path"] ="../saved_models/IS-Net-test" ## model weights saving (or restoring) path
-        hypar["restore_model"] = "" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
+        hypar["restore_model"] = "isnet-base-model.pth" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
         hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
         hypar["gt_encoder_model"] = ""
     else: ## configure the segmentation output path and the to-be-used model weights path