ormbg / dis-repo.patch
schirrmacher's picture
Upload folder using huggingface_hub
150d962 verified
raw
history blame
6.07 kB
diff --git a/IS-Net/Inference.py b/IS-Net/Inference.py
index 0b2907d..ca8484b 100644
--- a/IS-Net/Inference.py
+++ b/IS-Net/Inference.py
@@ -40,7 +40,7 @@ if __name__ == "__main__":
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
im_tensor = F.upsample(torch.unsqueeze(im_tensor,0), input_size, mode="bilinear").type(torch.uint8)
image = torch.divide(im_tensor,255.0)
- image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
+ #image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
if torch.cuda.is_available():
image=image.cuda()
diff --git a/IS-Net/train_valid_inference_main.py b/IS-Net/train_valid_inference_main.py
index 375bb26..ad9043c 100644
--- a/IS-Net/train_valid_inference_main.py
+++ b/IS-Net/train_valid_inference_main.py
@@ -536,10 +536,10 @@ def main(train_datasets,
cache_size = hypar["cache_size"],
cache_boost = hypar["cache_boost_train"],
my_transforms = [
- GOSRandomHFlip(), ## this line can be uncommented for horizontal flip augmetation
+ #GOSRandomHFlip(), ## this line can be uncommented for horizontal flip augmetation
# GOSResize(hypar["input_size"]),
# GOSRandomCrop(hypar["crop_size"]), ## this line can be uncommented for randomcrop augmentation
- GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
+ #GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
],
batch_size = hypar["batch_size_train"],
shuffle = True)
@@ -547,7 +547,7 @@ def main(train_datasets,
cache_size = hypar["cache_size"],
cache_boost = hypar["cache_boost_train"],
my_transforms = [
- GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
+ #GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
],
batch_size = hypar["batch_size_valid"],
shuffle = False)
@@ -561,7 +561,7 @@ def main(train_datasets,
cache_size = hypar["cache_size"],
cache_boost = hypar["cache_boost_valid"],
my_transforms = [
- GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
+ #GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
# GOSResize(hypar["input_size"])
],
batch_size=hypar["batch_size_valid"],
@@ -618,19 +618,19 @@ if __name__ == "__main__":
train_datasets, valid_datasets = [], []
dataset_1, dataset_1 = {}, {}
- dataset_tr = {"name": "DIS5K-TR",
- "im_dir": "../DIS5K/DIS-TR/im",
- "gt_dir": "../DIS5K/DIS-TR/gt",
- "im_ext": ".jpg",
+ dataset_tr = {"name": "training",
+ "im_dir": "../training/im",
+ "gt_dir": "../training/gt",
+ "im_ext": ".png",
"gt_ext": ".png",
- "cache_dir":"../DIS5K-Cache/DIS-TR"}
+ "cache_dir":"../cache/training"}
- dataset_vd = {"name": "DIS5K-VD",
- "im_dir": "../DIS5K/DIS-VD/im",
- "gt_dir": "../DIS5K/DIS-VD/gt",
- "im_ext": ".jpg",
+ dataset_vd = {"name": "validation",
+ "im_dir": "../validation/im",
+ "gt_dir": "../validation/gt",
+ "im_ext": ".png",
"gt_ext": ".png",
- "cache_dir":"../DIS5K-Cache/DIS-VD"}
+ "cache_dir":"../cache/validation"}
dataset_te1 = {"name": "DIS5K-TE1",
"im_dir": "../DIS5K/DIS-TE1/im",
@@ -685,7 +685,7 @@ if __name__ == "__main__":
if hypar["mode"] == "train":
hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
hypar["model_path"] ="../saved_models/IS-Net-test" ## model weights saving (or restoring) path
- hypar["restore_model"] = "" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
+ hypar["restore_model"] = "isnet-base-model.pth" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
hypar["gt_encoder_model"] = ""
else: ## configure the segmentation output path and the to-be-used model weights path