hyliu committed on
Commit
d6ec83b
1 Parent(s): 8cdfa55

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. SR/README.md +78 -0
  2. SR/code/.vscode/settings.json +3 -0
  3. SR/code/__init__.py +0 -0
  4. SR/code/__pycache__/option.cpython-35.pyc +0 -0
  5. SR/code/__pycache__/option.cpython-36.pyc +0 -0
  6. SR/code/__pycache__/option.cpython-37.pyc +0 -0
  7. SR/code/__pycache__/template.cpython-35.pyc +0 -0
  8. SR/code/__pycache__/template.cpython-36.pyc +0 -0
  9. SR/code/__pycache__/template.cpython-37.pyc +0 -0
  10. SR/code/__pycache__/trainer.cpython-35.pyc +0 -0
  11. SR/code/__pycache__/trainer.cpython-36.pyc +0 -0
  12. SR/code/__pycache__/trainer.cpython-37.pyc +0 -0
  13. SR/code/__pycache__/utility.cpython-35.pyc +0 -0
  14. SR/code/__pycache__/utility.cpython-36.pyc +0 -0
  15. SR/code/__pycache__/utility.cpython-37.pyc +0 -0
  16. SR/code/bootstrap.sh +5 -0
  17. SR/code/data/__init__.py +52 -0
  18. SR/code/data/__pycache__/__init__.cpython-35.pyc +0 -0
  19. SR/code/data/__pycache__/__init__.cpython-36.pyc +0 -0
  20. SR/code/data/__pycache__/__init__.cpython-37.pyc +0 -0
  21. SR/code/data/__pycache__/benchmark.cpython-35.pyc +0 -0
  22. SR/code/data/__pycache__/benchmark.cpython-36.pyc +0 -0
  23. SR/code/data/__pycache__/common.cpython-35.pyc +0 -0
  24. SR/code/data/__pycache__/common.cpython-36.pyc +0 -0
  25. SR/code/data/__pycache__/common.cpython-37.pyc +0 -0
  26. SR/code/data/__pycache__/demo.cpython-36.pyc +0 -0
  27. SR/code/data/__pycache__/div2k.cpython-35.pyc +0 -0
  28. SR/code/data/__pycache__/div2k.cpython-36.pyc +0 -0
  29. SR/code/data/__pycache__/div2k.cpython-37.pyc +0 -0
  30. SR/code/data/__pycache__/srdata.cpython-35.pyc +0 -0
  31. SR/code/data/__pycache__/srdata.cpython-36.pyc +0 -0
  32. SR/code/data/__pycache__/srdata.cpython-37.pyc +0 -0
  33. SR/code/data/benchmark.py +25 -0
  34. SR/code/data/common.py +72 -0
  35. SR/code/data/demo.py +39 -0
  36. SR/code/data/div2k.py +32 -0
  37. SR/code/data/div2kjpeg.py +20 -0
  38. SR/code/data/sr291.py +6 -0
  39. SR/code/data/srdata.py +157 -0
  40. SR/code/data/video.py +44 -0
  41. SR/code/dataloader.py +158 -0
  42. SR/code/demo.sb +5 -0
  43. SR/code/lambda_networks/__init__.py +4 -0
  44. SR/code/lambda_networks/__pycache__/__init__.cpython-37.pyc +0 -0
  45. SR/code/lambda_networks/__pycache__/lambda_networks.cpython-37.pyc +0 -0
  46. SR/code/lambda_networks/__pycache__/rlambda_networks.cpython-37.pyc +0 -0
  47. SR/code/lambda_networks/lambda_networks.py +137 -0
  48. SR/code/lambda_networks/rlambda_networks.py +93 -0
  49. SR/code/loss/__init__.py +173 -0
  50. SR/code/loss/__loss__.py +0 -0
SR/README.md ADDED
@@ -0,0 +1,78 @@
+ # Pyramid Attention for Image Restoration
+ This repository is for PANet and PA-EDSR, introduced in the following paper:
+
+ [Yiqun Mei](http://yiqunm2.web.illinois.edu/), [Yuchen Fan](https://scholar.google.com/citations?user=BlfdYL0AAAAJ&hl=en), [Yulun Zhang](http://yulunzhang.com/), [Jiahui Yu](https://jiahuiyu.com/), [Yuqian Zhou](https://yzhouas.github.io/), [Ding Liu](https://scholar.google.com/citations?user=PGtHUI0AAAAJ&hl=en), [Yun Fu](http://www1.ece.neu.edu/~yunfu/), [Thomas S. Huang](http://ifp-uiuc.github.io/) and [Honghui Shi](https://www.humphreyshi.com/), "Pyramid Attention for Image Restoration", [[Arxiv]](https://arxiv.org/abs/2004.13824)
+
+ The code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) & [RNAN](https://github.com/yulunzhang/RNAN) and tested on Ubuntu 18.04 (Python 3.6, PyTorch 1.1) with Titan X/1080Ti/V100 GPUs.
+
+ ## Contents
+ 1. [Train](#train)
+ 2. [Test](#test)
+ 3. [Results](#results)
+ 4. [Citation](#citation)
+ 5. [Acknowledgements](#acknowledgements)
+
+ ## Train
+ ### Prepare training data
+
+ 1. Download the DIV2K training data (800 training + 100 validation images) from the [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar).
+
+ 2. Specify '--dir_data' to point at the directory containing the HR and LR images (see the layout sketch below).
+
+ For more information, please refer to [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch).
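+
+ A minimal sketch of the layout that '--dir_data' is assumed to point at, following the directory names hard-coded in 'srdata.py' and 'div2k.py' (the file stem '0001' and scale 2 are illustrative):
+
+ ```python
+ import os
+
+ dir_data = '/datasets'                    # value passed to --dir_data (assumed)
+ apath = os.path.join(dir_data, 'DIV2K')   # SRData.apath for the DIV2K dataset
+ hr = os.path.join(apath, 'DIV2K_train_HR', '0001.png')
+ lr = os.path.join(apath, 'DIV2K_train_LR_bicubic', 'X2', '0001x2.png')
+ ```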
+
+ ### Begin to train
+
+ 1. (Optional) All pretrained models and visual results can be downloaded from [Google Drive](https://drive.google.com/open?id=1q9iUzqYX0fVRzDu4J6fvSPRosgOZoJJE).
+
+ 2. cd to 'PANet-PyTorch/[Task]/code' and run the following scripts to train models.
+
+ **You can use the scripts in 'demo.sb' to train and test the models from our paper.**
+
+ ```bash
+ # Example usage:
+ python main.py --n_GPUs 4 --rgb_range 1 --reset --save_models --lr 1e-4 --decay 200-400-600-800 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model PAEDSR --scale 2 --patch_size 96 --save EDSR_PA_x2 --data_train DIV2K
+ ```
+ ## Test
+ ### Quick start
+
+ 1. cd to 'PANet-PyTorch/[Task]/code' and run the following scripts.
+
+ **You can use the scripts in 'demo.sb' to reproduce the results of our paper.**
+
+ ```bash
+ # No self-ensemble; use the different test sets to reproduce the results in the paper.
+ # Example usage:
+ python main.py --model PAEDSR --data_test Set5+Set14+B100+Urban100 --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1 --pre_train ../model_x2.pt --test_only --chop
+
+ ```
+
+ ### The whole test pipeline
+ 1. Prepare the benchmark datasets from [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/benchmark.tar).
+
+ 2. Conduct image SR.
+
+    See **Quick start**.
+ 3. Evaluate the results.
+
+    Run 'Evaluate_PSNR_SSIM.m' to obtain the PSNR/SSIM values reported in the paper.
+
+ ## Citation
+ If you find the code helpful in your research or work, please cite the following papers.
+ ```
+ @article{mei2020pyramid,
+   title={Pyramid Attention Networks for Image Restoration},
+   author={Mei, Yiqun and Fan, Yuchen and Zhang, Yulun and Yu, Jiahui and Zhou, Yuqian and Liu, Ding and Fu, Yun and Huang, Thomas S and Shi, Honghui},
+   journal={arXiv preprint arXiv:2004.13824},
+   year={2020}
+ }
+ @InProceedings{Lim_2017_CVPR_Workshops,
+   author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
+   title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
+   booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
+   month = {July},
+   year = {2017}
+ }
+ ```
+ ## Acknowledgements
+ This code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch), [RNAN](https://github.com/yulunzhang/RNAN) and [generative-inpainting-pytorch](https://github.com/daa233/generative-inpainting-pytorch). We thank the authors for sharing their code.
SR/code/.vscode/settings.json ADDED
@@ -0,0 +1,3 @@
+ {
+     "python.pythonPath": "D:\\Environments\\Anaconda3\\envs\\torch1.6\\python.exe"
+ }
SR/code/__init__.py ADDED
File without changes
SR/code/__pycache__/option.cpython-35.pyc ADDED
Binary file (5.42 kB)

SR/code/__pycache__/option.cpython-36.pyc ADDED
Binary file (4.76 kB)

SR/code/__pycache__/option.cpython-37.pyc ADDED
Binary file (4.9 kB)

SR/code/__pycache__/template.cpython-35.pyc ADDED
Binary file (1.13 kB)

SR/code/__pycache__/template.cpython-36.pyc ADDED
Binary file (1 kB)

SR/code/__pycache__/template.cpython-37.pyc ADDED
Binary file (993 Bytes)

SR/code/__pycache__/trainer.cpython-35.pyc ADDED
Binary file (4.46 kB)

SR/code/__pycache__/trainer.cpython-36.pyc ADDED
Binary file (3.98 kB)

SR/code/__pycache__/trainer.cpython-37.pyc ADDED
Binary file (4.84 kB)

SR/code/__pycache__/utility.cpython-35.pyc ADDED
Binary file (9.88 kB)

SR/code/__pycache__/utility.cpython-36.pyc ADDED
Binary file (9.04 kB)

SR/code/__pycache__/utility.cpython-37.pyc ADDED
Binary file (9.1 kB)
SR/code/bootstrap.sh ADDED
@@ -0,0 +1,5 @@
+ #!/bin/bash
+
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
+
+ bash ~/miniconda.sh -b -p $HOME/miniconda
SR/code/data/__init__.py ADDED
@@ -0,0 +1,52 @@
+ from importlib import import_module
+ #from dataloader import MSDataLoader
+ from torch.utils.data import dataloader
+ from torch.utils.data import ConcatDataset
+
+ # This is a simple wrapper class for ConcatDataset
+ class MyConcatDataset(ConcatDataset):
+     def __init__(self, datasets):
+         super(MyConcatDataset, self).__init__(datasets)
+         self.train = datasets[0].train
+
+     def set_scale(self, idx_scale):
+         for d in self.datasets:
+             if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
+
+ class Data:
+     def __init__(self, args):
+         self.loader_train = None
+         if not args.test_only:
+             datasets = []
+             for d in args.data_train:
+                 module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
+                 m = import_module('data.' + module_name.lower())
+                 datasets.append(getattr(m, module_name)(args, name=d))
+
+             self.loader_train = dataloader.DataLoader(
+                 MyConcatDataset(datasets),
+                 batch_size=args.batch_size,
+                 shuffle=True,
+                 pin_memory=not args.cpu,
+                 num_workers=args.n_threads,
+             )
+
+         self.loader_test = []
+         for d in args.data_test:
+             if d in ['Set5', 'Set14', 'B100', 'Urban100']:
+                 m = import_module('data.benchmark')
+                 testset = getattr(m, 'Benchmark')(args, train=False, name=d)
+             else:
+                 module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
+                 m = import_module('data.' + module_name.lower())
+                 testset = getattr(m, module_name)(args, train=False, name=d)
+
+             self.loader_test.append(
+                 dataloader.DataLoader(
+                     testset,
+                     batch_size=1,
+                     shuffle=False,
+                     pin_memory=not args.cpu,
+                     num_workers=args.n_threads,
+                 )
+             )
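A minimal sketch of how `Data` is typically driven. The field names mirror the attributes read above; `option.py` is outside this 50-file view, so treat the namespace as an assumption, and note that the datasets themselves need further fields (`dir_data`, `scale`, `ext`, ...):

```python
from types import SimpleNamespace

import data

# Hypothetical option namespace; only the fields read by Data are shown.
args = SimpleNamespace(
    test_only=False, data_train=['DIV2K'], data_test=['Set5'],
    batch_size=16, cpu=False, n_threads=4,
)
loaders = data.Data(args)  # also requires the SRData fields in practice
# loaders.loader_train yields (lr, hr, filename) batches;
# loaders.loader_test holds one batch_size=1 DataLoader per test set.
```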
SR/code/data/__pycache__/__init__.cpython-35.pyc ADDED
Binary file (1.94 kB)

SR/code/data/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (1.79 kB)

SR/code/data/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.76 kB)

SR/code/data/__pycache__/benchmark.cpython-35.pyc ADDED
Binary file (1.13 kB)

SR/code/data/__pycache__/benchmark.cpython-36.pyc ADDED
Binary file (1.07 kB)

SR/code/data/__pycache__/common.cpython-35.pyc ADDED
Binary file (3.05 kB)

SR/code/data/__pycache__/common.cpython-36.pyc ADDED
Binary file (2.76 kB)

SR/code/data/__pycache__/common.cpython-37.pyc ADDED
Binary file (2.73 kB)

SR/code/data/__pycache__/demo.cpython-36.pyc ADDED
Binary file (1.5 kB)

SR/code/data/__pycache__/div2k.cpython-35.pyc ADDED
Binary file (1.87 kB)

SR/code/data/__pycache__/div2k.cpython-36.pyc ADDED
Binary file (1.74 kB)

SR/code/data/__pycache__/div2k.cpython-37.pyc ADDED
Binary file (1.75 kB)

SR/code/data/__pycache__/srdata.cpython-35.pyc ADDED
Binary file (5.43 kB)

SR/code/data/__pycache__/srdata.cpython-36.pyc ADDED
Binary file (4.83 kB)

SR/code/data/__pycache__/srdata.cpython-37.pyc ADDED
Binary file (4.82 kB)
SR/code/data/benchmark.py ADDED
@@ -0,0 +1,25 @@
+ import os
+
+ from data import common
+ from data import srdata
+
+ import numpy as np
+
+ import torch
+ import torch.utils.data as data
+
+ class Benchmark(srdata.SRData):
+     def __init__(self, args, name='', train=True, benchmark=True):
+         super(Benchmark, self).__init__(
+             args, name=name, train=train, benchmark=True
+         )
+
+     def _set_filesystem(self, dir_data):
+         self.apath = os.path.join(dir_data, 'benchmark', self.name)
+         self.dir_hr = os.path.join(self.apath, 'HR')
+         if self.input_large:
+             self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
+         else:
+             self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
+         self.ext = ('', '.png')
+
SR/code/data/common.py ADDED
@@ -0,0 +1,72 @@
+ import random
+
+ import numpy as np
+ import skimage.color as sc
+
+ import torch
+
+ def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
+     ih, iw = args[0].shape[:2]
+
+     if not input_large:
+         p = scale if multi else 1
+         tp = p * patch_size
+         ip = tp // scale
+     else:
+         tp = patch_size
+         ip = patch_size
+
+     ix = random.randrange(0, iw - ip + 1)
+     iy = random.randrange(0, ih - ip + 1)
+
+     if not input_large:
+         tx, ty = scale * ix, scale * iy
+     else:
+         tx, ty = ix, iy
+
+     ret = [
+         args[0][iy:iy + ip, ix:ix + ip, :],
+         *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
+     ]
+
+     return ret
+
+ def set_channel(*args, n_channels=3):
+     def _set_channel(img):
+         if img.ndim == 2:
+             img = np.expand_dims(img, axis=2)
+
+         c = img.shape[2]
+         if n_channels == 1 and c == 3:
+             img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
+         elif n_channels == 3 and c == 1:
+             img = np.concatenate([img] * n_channels, 2)
+
+         return img
+
+     return [_set_channel(a) for a in args]
+
+ def np2Tensor(*args, rgb_range=255):
+     def _np2Tensor(img):
+         np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
+         tensor = torch.from_numpy(np_transpose).float()
+         tensor.mul_(rgb_range / 255)
+
+         return tensor
+
+     return [_np2Tensor(a) for a in args]
+
+ def augment(*args, hflip=True, rot=True):
+     hflip = hflip and random.random() < 0.5
+     vflip = rot and random.random() < 0.5
+     rot90 = rot and random.random() < 0.5
+
+     def _augment(img):
+         if hflip: img = img[:, ::-1, :]
+         if vflip: img = img[::-1, :, :]
+         if rot90: img = img.transpose(1, 0, 2)
+
+         return img
+
+     return [_augment(a) for a in args]
+
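A minimal sketch of how these helpers compose; `SRData.__getitem__` further below chains them the same way (shapes are illustrative):

```python
import numpy as np

from data import common

# A 48x48 LR / 96x96 HR pair at scale 2.
lr = np.random.randint(0, 256, (48, 48, 3), dtype=np.uint8)
hr = np.random.randint(0, 256, (96, 96, 3), dtype=np.uint8)

# get_patch crops aligned patches: tp=32 on the HR side, ip=tp//scale=16 on the LR side.
lr_p, hr_p = common.get_patch(lr, hr, patch_size=32, scale=2)
lr_p, hr_p = common.augment(lr_p, hr_p)                    # random flips / rotation
lr_t, hr_t = common.np2Tensor(lr_p, hr_p, rgb_range=255)   # CxHxW float tensors
```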
SR/code/data/demo.py ADDED
@@ -0,0 +1,39 @@
+ import os
+
+ from data import common
+
+ import numpy as np
+ import imageio
+
+ import torch
+ import torch.utils.data as data
+
+ class Demo(data.Dataset):
+     def __init__(self, args, name='Demo', train=False, benchmark=False):
+         self.args = args
+         self.name = name
+         self.scale = args.scale
+         self.idx_scale = 0
+         self.train = False
+         self.benchmark = benchmark
+
+         self.filelist = []
+         for f in os.listdir(args.dir_demo):
+             if f.find('.png') >= 0 or f.find('.jp') >= 0:
+                 self.filelist.append(os.path.join(args.dir_demo, f))
+         self.filelist.sort()
+
+     def __getitem__(self, idx):
+         filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
+         lr = imageio.imread(self.filelist[idx])
+         lr, = common.set_channel(lr, n_channels=self.args.n_colors)
+         lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
+
+         return lr_t, -1, filename
+
+     def __len__(self):
+         return len(self.filelist)
+
+     def set_scale(self, idx_scale):
+         self.idx_scale = idx_scale
+
SR/code/data/div2k.py ADDED
@@ -0,0 +1,32 @@
+ import os
+ from data import srdata
+
+ class DIV2K(srdata.SRData):
+     def __init__(self, args, name='DIV2K', train=True, benchmark=False):
+         data_range = [r.split('-') for r in args.data_range.split('/')]
+         if train:
+             data_range = data_range[0]
+         else:
+             if args.test_only and len(data_range) == 1:
+                 data_range = data_range[0]
+             else:
+                 data_range = data_range[1]
+
+         self.begin, self.end = list(map(lambda x: int(x), data_range))
+         super(DIV2K, self).__init__(
+             args, name=name, train=train, benchmark=benchmark
+         )
+
+     def _scan(self):
+         names_hr, names_lr = super(DIV2K, self)._scan()
+         names_hr = names_hr[self.begin - 1:self.end]
+         names_lr = [n[self.begin - 1:self.end] for n in names_lr]
+
+         return names_hr, names_lr
+
+     def _set_filesystem(self, dir_data):
+         super(DIV2K, self)._set_filesystem(dir_data)
+         self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
+         self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
+         if self.input_large: self.dir_lr += 'L'
+
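A small worked example of the 'train/test' convention that the `--data_range` parsing above implements:

```python
# 'begin-end' ranges separated by '/': the first range trains, the second evaluates.
data_range = [r.split('-') for r in '1-800/801-810'.split('/')]
train_begin, train_end = map(int, data_range[0])   # images 0001..0800
val_begin, val_end = map(int, data_range[1])       # images 0801..0810
```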
SR/code/data/div2kjpeg.py ADDED
@@ -0,0 +1,20 @@
+ import os
+ from data import srdata
+ from data import div2k
+
+ class DIV2KJPEG(div2k.DIV2K):
+     def __init__(self, args, name='', train=True, benchmark=False):
+         self.q_factor = int(name.replace('DIV2K-Q', ''))
+         super(DIV2KJPEG, self).__init__(
+             args, name=name, train=train, benchmark=benchmark
+         )
+
+     def _set_filesystem(self, dir_data):
+         self.apath = os.path.join(dir_data, 'DIV2K')
+         self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
+         self.dir_lr = os.path.join(
+             self.apath, 'DIV2K_Q{}'.format(self.q_factor)
+         )
+         if self.input_large: self.dir_lr += 'L'
+         self.ext = ('.png', '.jpg')
+
SR/code/data/sr291.py ADDED
@@ -0,0 +1,6 @@
+ from data import srdata
+
+ class SR291(srdata.SRData):
+     def __init__(self, args, name='SR291', train=True, benchmark=False):
+         super(SR291, self).__init__(args, name=name)
+
SR/code/data/srdata.py ADDED
@@ -0,0 +1,157 @@
+ import os
+ import glob
+ import random
+ import pickle
+
+ from data import common
+
+ import numpy as np
+ import imageio
+ import torch
+ import torch.utils.data as data
+
+ class SRData(data.Dataset):
+     def __init__(self, args, name='', train=True, benchmark=False):
+         self.args = args
+         self.name = name
+         self.train = train
+         self.split = 'train' if train else 'test'
+         self.do_eval = True
+         self.benchmark = benchmark
+         self.input_large = (args.model == 'VDSR')
+         self.scale = args.scale
+         self.idx_scale = 0
+
+         self._set_filesystem(args.dir_data)
+         if args.ext.find('img') < 0:
+             path_bin = os.path.join(self.apath, 'bin')
+             os.makedirs(path_bin, exist_ok=True)
+
+         list_hr, list_lr = self._scan()
+         if args.ext.find('img') >= 0 or benchmark:
+             self.images_hr, self.images_lr = list_hr, list_lr
+         elif args.ext.find('sep') >= 0:
+             os.makedirs(
+                 self.dir_hr.replace(self.apath, path_bin),
+                 exist_ok=True
+             )
+             for s in self.scale:
+                 os.makedirs(
+                     os.path.join(
+                         self.dir_lr.replace(self.apath, path_bin),
+                         'X{}'.format(s)
+                     ),
+                     exist_ok=True
+                 )
+
+             self.images_hr, self.images_lr = [], [[] for _ in self.scale]
+             for h in list_hr:
+                 b = h.replace(self.apath, path_bin)
+                 b = b.replace(self.ext[0], '.pt')
+                 self.images_hr.append(b)
+                 self._check_and_load(args.ext, h, b, verbose=True)
+             for i, ll in enumerate(list_lr):
+                 for l in ll:
+                     b = l.replace(self.apath, path_bin)
+                     b = b.replace(self.ext[1], '.pt')
+                     self.images_lr[i].append(b)
+                     self._check_and_load(args.ext, l, b, verbose=True)
+         if train:
+             n_patches = args.batch_size * args.test_every
+             n_images = len(args.data_train) * len(self.images_hr)
+             if n_images == 0:
+                 self.repeat = 0
+             else:
+                 self.repeat = max(n_patches // n_images, 1)
+
+     # The functions below are used to prepare images
+     def _scan(self):
+         names_hr = sorted(
+             glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
+         )
+         names_lr = [[] for _ in self.scale]
+         for f in names_hr:
+             filename, _ = os.path.splitext(os.path.basename(f))
+             for si, s in enumerate(self.scale):
+                 names_lr[si].append(os.path.join(
+                     self.dir_lr, 'X{}/{}x{}{}'.format(
+                         s, filename, s, self.ext[1]
+                     )
+                 ))
+
+         return names_hr, names_lr
+
+     def _set_filesystem(self, dir_data):
+         self.apath = os.path.join(dir_data, self.name)
+         self.dir_hr = os.path.join(self.apath, 'HR')
+         self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
+         if self.input_large: self.dir_lr += 'L'
+         self.ext = ('.png', '.png')
+
+     def _check_and_load(self, ext, img, f, verbose=True):
+         if not os.path.isfile(f) or ext.find('reset') >= 0:
+             if verbose:
+                 print('Making a binary: {}'.format(f))
+             with open(f, 'wb') as _f:
+                 pickle.dump(imageio.imread(img), _f)
+
+     def __getitem__(self, idx):
+         lr, hr, filename = self._load_file(idx)
+         pair = self.get_patch(lr, hr)
+         pair = common.set_channel(*pair, n_channels=self.args.n_colors)
+         pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
+
+         return pair_t[0], pair_t[1], filename
+
+     def __len__(self):
+         if self.train:
+             return len(self.images_hr) * self.repeat
+         else:
+             return len(self.images_hr)
+
+     def _get_index(self, idx):
+         if self.train:
+             return idx % len(self.images_hr)
+         else:
+             return idx
+
+     def _load_file(self, idx):
+         idx = self._get_index(idx)
+         f_hr = self.images_hr[idx]
+         f_lr = self.images_lr[self.idx_scale][idx]
+
+         filename, _ = os.path.splitext(os.path.basename(f_hr))
+         if self.args.ext == 'img' or self.benchmark:
+             hr = imageio.imread(f_hr)
+             lr = imageio.imread(f_lr)
+         elif self.args.ext.find('sep') >= 0:
+             with open(f_hr, 'rb') as _f:
+                 hr = pickle.load(_f)
+             with open(f_lr, 'rb') as _f:
+                 lr = pickle.load(_f)
+
+         return lr, hr, filename
+
+     def get_patch(self, lr, hr):
+         scale = self.scale[self.idx_scale]
+         if self.train:
+             lr, hr = common.get_patch(
+                 lr, hr,
+                 patch_size=self.args.patch_size,
+                 scale=scale,
+                 multi=(len(self.scale) > 1),
+                 input_large=self.input_large
+             )
+             if not self.args.no_augment: lr, hr = common.augment(lr, hr)
+         else:
+             ih, iw = lr.shape[:2]
+             hr = hr[0:ih * scale, 0:iw * scale]
+
+         return lr, hr
+
+     def set_scale(self, idx_scale):
+         if not self.input_large:
+             self.idx_scale = idx_scale
+         else:
+             self.idx_scale = random.randint(0, len(self.scale) - 1)
+
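A minimal sketch of the `--ext sep` caching scheme implemented by `_check_and_load` and `_load_file` above: each image is decoded once, pickled to a `.pt` file under `<apath>/bin`, and read back from the pickle afterwards (the paths are illustrative):

```python
import os
import pickle

import imageio

src = 'DIV2K/DIV2K_train_HR/0001.png'      # original image (assumed to exist)
dst = 'DIV2K/bin/DIV2K_train_HR/0001.pt'   # cached numpy array

os.makedirs(os.path.dirname(dst), exist_ok=True)
if not os.path.isfile(dst):                # same check as _check_and_load
    with open(dst, 'wb') as f:
        pickle.dump(imageio.imread(src), f)

with open(dst, 'rb') as f:                 # same load as _load_file
    hr = pickle.load(f)                    # HxWxC uint8 array
```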
SR/code/data/video.py ADDED
@@ -0,0 +1,44 @@
+ import os
+
+ from data import common
+
+ import cv2
+ import numpy as np
+ import imageio
+
+ import torch
+ import torch.utils.data as data
+
+ class Video(data.Dataset):
+     def __init__(self, args, name='Video', train=False, benchmark=False):
+         self.args = args
+         self.name = name
+         self.scale = args.scale
+         self.idx_scale = 0
+         self.train = False
+         self.do_eval = False
+         self.benchmark = benchmark
+
+         self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
+         self.vidcap = cv2.VideoCapture(args.dir_demo)
+         self.n_frames = 0
+         self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+     def __getitem__(self, idx):
+         success, lr = self.vidcap.read()
+         if success:
+             self.n_frames += 1
+             lr, = common.set_channel(lr, n_channels=self.args.n_colors)
+             lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
+
+             return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
+         else:
+             self.vidcap.release()
+             return None
+
+     def __len__(self):
+         return self.total_frames
+
+     def set_scale(self, idx_scale):
+         self.idx_scale = idx_scale
+
SR/code/dataloader.py ADDED
@@ -0,0 +1,158 @@
+ import sys
+ import threading
+ import random
+
+ import torch
+ import torch.multiprocessing as multiprocessing
+ from torch.utils.data import DataLoader
+ from torch.utils.data import SequentialSampler
+ from torch.utils.data import RandomSampler
+ from torch.utils.data import BatchSampler
+ from torch.utils.data import _utils
+ from torch.utils.data.dataloader import _DataLoaderIter
+
+ from torch.utils.data._utils import collate
+ from torch.utils.data._utils import signal_handling
+ from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
+ from torch.utils.data._utils import ExceptionWrapper
+ from torch.utils.data._utils import IS_WINDOWS
+ from torch.utils.data._utils.worker import ManagerWatchdog
+
+ from torch._six import queue
+
+ def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
+     try:
+         collate._use_shared_memory = True
+         signal_handling._set_worker_signal_handlers()
+
+         torch.set_num_threads(1)
+         random.seed(seed)
+         torch.manual_seed(seed)
+
+         data_queue.cancel_join_thread()
+
+         if init_fn is not None:
+             init_fn(worker_id)
+
+         watchdog = ManagerWatchdog()
+
+         while watchdog.is_alive():
+             try:
+                 r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+             except queue.Empty:
+                 continue
+
+             if r is None:
+                 assert done_event.is_set()
+                 return
+             elif done_event.is_set():
+                 continue
+
+             idx, batch_indices = r
+             try:
+                 idx_scale = 0
+                 if len(scale) > 1 and dataset.train:
+                     idx_scale = random.randrange(0, len(scale))
+                     dataset.set_scale(idx_scale)
+
+                 samples = collate_fn([dataset[i] for i in batch_indices])
+                 samples.append(idx_scale)
+             except Exception:
+                 data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
+             else:
+                 data_queue.put((idx, samples))
+                 del samples
+
+     except KeyboardInterrupt:
+         pass
+
+ class _MSDataLoaderIter(_DataLoaderIter):
+
+     def __init__(self, loader):
+         self.dataset = loader.dataset
+         self.scale = loader.scale
+         self.collate_fn = loader.collate_fn
+         self.batch_sampler = loader.batch_sampler
+         self.num_workers = loader.num_workers
+         self.pin_memory = loader.pin_memory and torch.cuda.is_available()
+         self.timeout = loader.timeout
+
+         self.sample_iter = iter(self.batch_sampler)
+
+         base_seed = torch.LongTensor(1).random_().item()
+
+         if self.num_workers > 0:
+             self.worker_init_fn = loader.worker_init_fn
+             self.worker_queue_idx = 0
+             self.worker_result_queue = multiprocessing.Queue()
+             self.batches_outstanding = 0
+             self.worker_pids_set = False
+             self.shutdown = False
+             self.send_idx = 0
+             self.rcvd_idx = 0
+             self.reorder_dict = {}
+             self.done_event = multiprocessing.Event()
+
+             base_seed = torch.LongTensor(1).random_()[0]
+
+             self.index_queues = []
+             self.workers = []
+             for i in range(self.num_workers):
+                 index_queue = multiprocessing.Queue()
+                 index_queue.cancel_join_thread()
+                 w = multiprocessing.Process(
+                     target=_ms_loop,
+                     args=(
+                         self.dataset,
+                         index_queue,
+                         self.worker_result_queue,
+                         self.done_event,
+                         self.collate_fn,
+                         self.scale,
+                         base_seed + i,
+                         self.worker_init_fn,
+                         i
+                     )
+                 )
+                 w.daemon = True
+                 w.start()
+                 self.index_queues.append(index_queue)
+                 self.workers.append(w)
+
+             if self.pin_memory:
+                 self.data_queue = queue.Queue()
+                 pin_memory_thread = threading.Thread(
+                     target=_utils.pin_memory._pin_memory_loop,
+                     args=(
+                         self.worker_result_queue,
+                         self.data_queue,
+                         torch.cuda.current_device(),
+                         self.done_event
+                     )
+                 )
+                 pin_memory_thread.daemon = True
+                 pin_memory_thread.start()
+                 self.pin_memory_thread = pin_memory_thread
+             else:
+                 self.data_queue = self.worker_result_queue
+
+             _utils.signal_handling._set_worker_pids(
+                 id(self), tuple(w.pid for w in self.workers)
+             )
+             _utils.signal_handling._set_SIGCHLD_handler()
+             self.worker_pids_set = True
+
+             for _ in range(2 * self.num_workers):
+                 self._put_indices()
+
+ class MSDataLoader(DataLoader):
+
+     def __init__(self, cfg, *args, **kwargs):
+         super(MSDataLoader, self).__init__(
+             *args, **kwargs, num_workers=cfg.n_threads
+         )
+         self.scale = cfg.scale
+
+     def __iter__(self):
+         return _MSDataLoaderIter(self)
+
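For reference, a minimal usage sketch. `_DataLoaderIter` and `torch._six` are PyTorch internals from roughly the 1.1 era named in the README, so this loader is not expected to import on recent PyTorch versions; the `args` fields here are assumptions:

```python
# Sketch only: requires PyTorch ~1.1 and an SRData-style dataset.
loader = MSDataLoader(
    args,                         # cfg: supplies n_threads and scale
    train_dataset,                # dataset exposing .train and set_scale()
    batch_size=args.batch_size,
    shuffle=True,
    pin_memory=not args.cpu,
)
# _ms_loop appends the sampled scale index to every collated batch:
for lr, hr, filename, idx_scale in loader:
    pass
```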
SR/code/demo.sb ADDED
@@ -0,0 +1,5 @@
+ #!/bin/bash
+ # Train x2
+ python main.py --n_GPUs 4 --rgb_range 1 --reset --save_models --lr 1e-4 --decay 200-400-600-800 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model PAEDSR --scale 2 --patch_size 96 --save EDSR_PA_x2 --data_train DIV2K
+ # Test
+ python main.py --model PAEDSR --data_test Set5+Set14+B100+Urban100 --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1 --pre_train ../model_x2.pt --test_only --chop
SR/code/lambda_networks/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from lambda_networks.lambda_networks import LambdaLayer
+ from lambda_networks.lambda_networks import Recursion
+ from lambda_networks.rlambda_networks import RLambdaLayer
+ λLayer = LambdaLayer
SR/code/lambda_networks/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (337 Bytes)

SR/code/lambda_networks/__pycache__/lambda_networks.cpython-37.pyc ADDED
Binary file (4.78 kB)

SR/code/lambda_networks/__pycache__/rlambda_networks.cpython-37.pyc ADDED
Binary file (2.94 kB)
SR/code/lambda_networks/lambda_networks.py ADDED
@@ -0,0 +1,137 @@
+ import torch
+ from torch import nn, einsum
+ import torch.nn.functional as F
+ from einops import rearrange
+
+ # helper functions
+
+ def exists(val):
+     return val is not None
+
+ def default(val, d):
+     return val if exists(val) else d
+
+ # lambda layer
+
+ class LambdaLayer(nn.Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         dim_k,
+         n = None,
+         r = None,
+         heads = 4,
+         dim_out = None,
+         dim_u = 1):
+         super().__init__()
+         dim_out = default(dim_out, dim)
+         self.u = dim_u  # intra-depth dimension
+         self.heads = heads
+
+         assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
+         dim_v = dim_out // heads
+
+         self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
+         self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
+         self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)
+
+         self.norm_q = nn.BatchNorm2d(dim_k * heads)
+         self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
+
+         self.local_contexts = exists(r)
+         if exists(r):
+             assert (r % 2) == 1, 'Receptive kernel size should be odd'
+             self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
+         else:
+             assert exists(n), 'You must specify the total sequence length (h x w)'
+             self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
+
+     def forward(self, x):
+         b, c, hh, ww, u, h = *x.shape, self.u, self.heads
+
+         q = self.to_q(x)
+         k = self.to_k(x)
+         v = self.to_v(x)
+
+         q = self.norm_q(q)
+         v = self.norm_v(v)
+
+         q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
+         k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
+         v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)
+
+         k = k.softmax(dim=-1)
+
+         λc = einsum('b u k m, b u v m -> b k v', k, v)
+         Yc = einsum('b h k n, b k v -> b h v n', q, λc)
+
+         if self.local_contexts:
+             v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
+             λp = self.pos_conv(v)
+             Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
+         else:
+             λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
+             Yp = einsum('b h k n, b n k v -> b h v n', q, λp)
+
+         Y = Yc + Yp
+         out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
+         return out
+
+ # i'm not sure whether this will work or not
+ class Recursion(nn.Module):
+     def __init__(self, N: int, hidden_dim: int = 64):
+         super(Recursion, self).__init__()
+         self.N = N
+         self.lambdaNxN_identity = LambdaLayer(dim=hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
+         # merge upstream information here
+         self.lambdaNxN_merge = LambdaLayer(dim=2 * hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
+         self.downscale_conv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=N, stride=N)
+         self.upscale_conv = nn.Conv2d(hidden_dim, hidden_dim * N * N, kernel_size=3, padding=1)
+         self.pixel_shuffle = nn.PixelShuffle(N)
+
+     def forward(self, x: torch.Tensor):
+         N = self.N
+
+         def to_patch(blocks: torch.Tensor):
+             shape = blocks.shape
+             blocks_patch = F.unfold(blocks, kernel_size=N, stride=N)
+             blocks_patch = blocks_patch.view(shape[0], shape[1], N, N, -1)
+             num_patch = blocks_patch.shape[-1]
+             blocks_patch = blocks_patch.permute(0, 4, 1, 2, 3).reshape(-1, shape[1], N, N).contiguous()
+             return blocks_patch, num_patch
+
+         def combine_patch(processed_patch, shape, num_patch):
+             processed_patch = processed_patch.reshape(shape[0], num_patch, shape[1], N, N)
+             processed_patch = processed_patch.permute(0, 2, 3, 4, 1).reshape(shape[0], shape[1] * N * N, num_patch).contiguous()
+             processed = F.fold(processed_patch, output_size=(shape[-2], shape[-1]), kernel_size=N, stride=N)
+             return processed
+
+         def process(blocks: torch.Tensor) -> torch.Tensor:
+             shape = blocks.shape
+             if blocks.shape[-1] == N:
+                 processed = self.lambdaNxN_identity(blocks)
+                 return processed
+             # to NxN patches
+             blocks_patch, num_patch = to_patch(blocks)
+             # pass through identity
+             processed_patch = self.lambdaNxN_identity(blocks_patch)
+             # back to HxW
+             processed = combine_patch(processed_patch, shape, num_patch)
+             # get feedback
+             feedback = process(self.downscale_conv(processed))
+             # upscale feedback
+             upscale_feedback = self.upscale_conv(feedback)
+             upscale_feedback = self.pixel_shuffle(upscale_feedback)
+             # combine results
+             combined = torch.cat([processed, upscale_feedback], dim=1)
+             combined_shape = combined.shape
+             combined_patch, num_patch = to_patch(combined)
+             combined_patch_reduced = self.lambdaNxN_merge(combined_patch)
+             ret_shape = (combined_shape[0], combined_shape[1] // 2, combined_shape[2], combined_shape[3])
+             ret = combine_patch(combined_patch_reduced, ret_shape, num_patch)
+             return ret
+
+         return process(x)
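A minimal usage sketch of `LambdaLayer` (the hyperparameters are illustrative; with local context `r` must be odd, otherwise `n` must equal the number of spatial positions h*w):

```python
import torch

from lambda_networks import LambdaLayer

layer = LambdaLayer(dim=32, dim_out=32, r=23, dim_k=16, heads=4, dim_u=1)
x = torch.randn(1, 32, 64, 64)   # (batch, channels, height, width)
y = layer(x)                     # -> torch.Size([1, 32, 64, 64])
```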
SR/code/lambda_networks/rlambda_networks.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ from torch import nn, einsum
+ import torch.nn.functional as F
+ from einops import rearrange
+
+
+ # helper functions
+
+ def exists(val):
+     return val is not None
+
+
+ def default(val, d):
+     return val if exists(val) else d
+
+
+ # lambda layer
+
+ class RLambdaLayer(nn.Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         dim_k,
+         n=None,
+         r=None,
+         heads=4,
+         dim_out=None,
+         dim_u=1,
+         recurrence=None
+     ):
+         super().__init__()
+         dim_out = default(dim_out, dim)
+         self.u = dim_u  # intra-depth dimension
+         self.heads = heads
+
+         assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
+         dim_v = dim_out // heads
+
+         self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias=False)
+         self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias=False)
+         self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias=False)
+
+         self.norm_q = nn.BatchNorm2d(dim_k * heads)
+         self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
+
+         self.local_contexts = exists(r)
+         self.recurrence = recurrence
+         # fail fast: forward() iterates range(self.recurrence)
+         assert exists(recurrence), 'You must specify the number of recurrent applications'
+         if exists(r):
+             assert (r % 2) == 1, 'Receptive kernel size should be odd'
+             self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2))
+         else:
+             assert exists(n), 'You must specify the total sequence length (h x w)'
+             self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
+
+     def apply_lambda(self, lambda_c, lambda_p, x):
+         b, c, hh, ww, u, h = *x.shape, self.u, self.heads
+         q = self.to_q(x)
+         q = self.norm_q(q)
+         q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h=h)
+         Yc = einsum('b h k n, b k v -> b h v n', q, lambda_c)
+         if self.local_contexts:
+             Yp = einsum('b h k n, b k v n -> b h v n', q, lambda_p.flatten(3))
+         else:
+             Yp = einsum('b h k n, b n k v -> b h v n', q, lambda_p)
+         Y = Yc + Yp
+         out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh=hh, ww=ww)
+         return out
+
+     def forward(self, x):
+         b, c, hh, ww, u, h = *x.shape, self.u, self.heads
+
+         k = self.to_k(x)
+         v = self.to_v(x)
+
+         v = self.norm_v(v)
+
+         k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u=u)
+         v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u=u)
+
+         k = k.softmax(dim=-1)
+
+         λc = einsum('b u k m, b u v m -> b k v', k, v)
+
+         if self.local_contexts:
+             v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)
+             λp = self.pos_conv(v)
+         else:
+             λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
+         out = x
+         for i in range(self.recurrence):
+             out = self.apply_lambda(λc, λp, out)
+         return out
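A matching sketch for the recurrent variant: one (λc, λp) pair is computed from the input and re-applied `recurrence` times, which requires `dim_out == dim` (the default):

```python
import torch

from lambda_networks import RLambdaLayer

layer = RLambdaLayer(dim=32, dim_k=16, r=15, heads=4, dim_u=1, recurrence=3)
x = torch.randn(1, 32, 64, 64)
y = layer(x)                     # -> torch.Size([1, 32, 64, 64])
```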
SR/code/loss/__init__.py ADDED
@@ -0,0 +1,173 @@
+ import os
+ from importlib import import_module
+
+ import matplotlib
+ matplotlib.use('Agg')
+ import matplotlib.pyplot as plt
+
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ def sequence_loss(sr, hr, loss_func, gamma=0.8, max_val=None):
+     """Loss function defined over a sequence of predictions."""
+
+     n_recurrence = len(sr)
+     total_loss = 0.0
+     buffer = [0.0] * n_recurrence
+     # exclude invalid pixels and extremely large displacements
+     for i in range(n_recurrence):
+         i_weight = gamma ** (n_recurrence - i - 1)
+         i_loss = loss_func(sr[i], hr)
+         buffer[i] = i_loss.item()
+         # total_loss += i_weight * (valid[:, None] * i_loss).mean()
+         total_loss += i_weight * i_loss
+     return total_loss, buffer
+
+ class Loss(nn.modules.loss._Loss):
+     def __init__(self, args, ckp):
+         super(Loss, self).__init__()
+         print('Preparing loss function:')
+         self.buffer = [0.0] * args.recurrence
+         self.n_GPUs = args.n_GPUs
+         self.loss = []
+         self.loss_module = nn.ModuleList()
+         for loss in args.loss.split('+'):
+             weight, loss_type = loss.split('*')
+             if loss_type == 'MSE':
+                 loss_function = nn.MSELoss()
+             elif loss_type == 'L1':
+                 loss_function = nn.L1Loss()
+             elif loss_type.find('VGG') >= 0:
+                 module = import_module('loss.vgg')
+                 loss_function = getattr(module, 'VGG')(
+                     loss_type[3:],
+                     rgb_range=args.rgb_range
+                 )
+             elif loss_type.find('GAN') >= 0:
+                 module = import_module('loss.adversarial')
+                 loss_function = getattr(module, 'Adversarial')(
+                     args,
+                     loss_type
+                 )
+
+             self.loss.append({
+                 'type': loss_type,
+                 'weight': float(weight),
+                 'function': loss_function}
+             )
+             if loss_type.find('GAN') >= 0:
+                 self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
+
+         if len(self.loss) > 1:
+             self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
+
+         for l in self.loss:
+             if l['function'] is not None:
+                 print('{:.3f} * {}'.format(l['weight'], l['type']))
+                 # self.loss_module.append(l['function'])
+
+         self.log = torch.Tensor()
+
+         device = torch.device('cpu' if args.cpu else 'cuda')
+         self.loss_module.to(device)
+         if args.precision == 'half': self.loss_module.half()
+         if not args.cpu and args.n_GPUs > 1:
+             self.loss_module = nn.DataParallel(
+                 self.loss_module, range(args.n_GPUs)
+             )
+
+         if args.load != '': self.load(ckp.dir, cpu=args.cpu)
+
+     def forward(self, sr, hr):
+         losses = []
+         for i, l in enumerate(self.loss):
+             if l['function'] is not None:
+                 if isinstance(sr, list):
+                     # weights=[0.32,0.08,0.02,0.01,0.005]
+                     # weights=weights[::-1]
+                     # weights=[0.01,0.02,0.08,0.32]
+                     # self.buffer=[]
+                     effective_loss, buffer_lst = sequence_loss(sr, hr, l['function'])
+                     # for k in range(len(sr)):
+                     #     loss = l['function'](sr[k], hr)
+                     #     self.buffer.append(loss.item())
+                     #     effective_loss = loss * weights[k] * l['weight']
+                     losses.append(effective_loss)
+                     self.buffer = buffer_lst
+                     self.log[-1, i] += effective_loss.item()
+                 else:
+                     loss = l['function'](sr, hr)
+                     effective_loss = l['weight'] * loss
+                     losses.append(effective_loss)
+                     self.buffer[0] = effective_loss.item()
+                     self.log[-1, i] += effective_loss.item()
+             elif l['type'] == 'DIS':
+                 self.log[-1, i] += self.loss[i - 1]['function'].loss
+
+         loss_sum = sum(losses)
+         if len(self.loss) > 1:
+             self.log[-1, -1] += loss_sum.item()
+
+         return loss_sum
+
+     def step(self):
+         for l in self.get_loss_module():
+             if hasattr(l, 'scheduler'):
+                 l.scheduler.step()
+
+     def start_log(self):
+         self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
+
+     def end_log(self, n_batches):
+         self.log[-1].div_(n_batches)
+
+     def display_loss(self, batch):
+         n_samples = batch + 1
+         log = []
+         for l, c in zip(self.loss, self.log[-1]):
+             log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
+
+         return ''.join(log)
+
+     def plot_loss(self, apath, epoch):
+         axis = np.linspace(1, epoch, epoch)
+         for i, l in enumerate(self.loss):
+             label = '{} Loss'.format(l['type'])
+             fig = plt.figure()
+             plt.title(label)
+             plt.plot(axis, self.log[:, i].numpy(), label=label)
+             plt.legend()
+             plt.xlabel('Epochs')
+             plt.ylabel('Loss')
+             plt.grid(True)
+             plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
+             plt.close(fig)
+
+     def get_loss_module(self):
+         if self.n_GPUs == 1:
+             return self.loss_module
+         else:
+             return self.loss_module.module
+
+     def save(self, apath):
+         torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
+         torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
+
+     def load(self, apath, cpu=False):
+         if cpu:
+             kwargs = {'map_location': lambda storage, loc: storage}
+         else:
+             kwargs = {}
+
+         self.load_state_dict(torch.load(
+             os.path.join(apath, 'loss.pt'),
+             **kwargs
+         ))
+         self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
+         for l in self.get_loss_module():
+             if hasattr(l, 'scheduler'):
+                 for _ in range(len(self.log)): l.scheduler.step()
+
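The exponential weighting in `sequence_loss` above is easiest to see with a small example: the weight `gamma**(n_recurrence - i - 1)` makes the final prediction in the sequence dominate.

```python
gamma, n_recurrence = 0.8, 3
weights = [gamma ** (n_recurrence - i - 1) for i in range(n_recurrence)]
print(weights)   # approximately [0.64, 0.8, 1.0]
```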
SR/code/loss/__loss__.py ADDED
File without changes