hyliu committed on
Commit
3ef0208
1 Parent(s): 8ec10cf

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. CAR/README.md +103 -0
  2. CAR/code/LICENSE +21 -0
  3. CAR/code/__init__.py +0 -0
  4. CAR/code/__pycache__/option.cpython-36.pyc +0 -0
  5. CAR/code/__pycache__/option.cpython-37.pyc +0 -0
  6. CAR/code/__pycache__/template.cpython-36.pyc +0 -0
  7. CAR/code/__pycache__/template.cpython-37.pyc +0 -0
  8. CAR/code/__pycache__/trainer.cpython-36.pyc +0 -0
  9. CAR/code/__pycache__/trainer.cpython-37.pyc +0 -0
  10. CAR/code/__pycache__/utility.cpython-36.pyc +0 -0
  11. CAR/code/__pycache__/utility.cpython-37.pyc +0 -0
  12. CAR/code/data/__init__.py +52 -0
  13. CAR/code/data/__pycache__/__init__.cpython-36.pyc +0 -0
  14. CAR/code/data/__pycache__/__init__.cpython-37.pyc +0 -0
  15. CAR/code/data/__pycache__/benchmark.cpython-36.pyc +0 -0
  16. CAR/code/data/__pycache__/common.cpython-36.pyc +0 -0
  17. CAR/code/data/__pycache__/common.cpython-37.pyc +0 -0
  18. CAR/code/data/__pycache__/div2k.cpython-36.pyc +0 -0
  19. CAR/code/data/__pycache__/div2k.cpython-37.pyc +0 -0
  20. CAR/code/data/__pycache__/srdata.cpython-36.pyc +0 -0
  21. CAR/code/data/__pycache__/srdata.cpython-37.pyc +0 -0
  22. CAR/code/data/benchmark.py +25 -0
  23. CAR/code/data/common.py +79 -0
  24. CAR/code/data/demo.py +39 -0
  25. CAR/code/data/div2k.py +32 -0
  26. CAR/code/data/div2kjpeg.py +20 -0
  27. CAR/code/data/sr291.py +6 -0
  28. CAR/code/data/srdata.py +166 -0
  29. CAR/code/data/video.py +44 -0
  30. CAR/code/dataloader.py +158 -0
  31. CAR/code/demo.sb +6 -0
  32. CAR/code/demo.sh +56 -0
  33. CAR/code/lambda_networks/__init__.py +4 -0
  34. CAR/code/lambda_networks/__pycache__/__init__.cpython-37.pyc +0 -0
  35. CAR/code/lambda_networks/__pycache__/lambda_networks.cpython-37.pyc +0 -0
  36. CAR/code/lambda_networks/__pycache__/rlambda_networks.cpython-37.pyc +0 -0
  37. CAR/code/lambda_networks/lambda_networks.py +161 -0
  38. CAR/code/lambda_networks/rlambda_networks.py +93 -0
  39. CAR/code/loss/__init__.py +173 -0
  40. CAR/code/loss/__pycache__/__init__.cpython-36.pyc +0 -0
  41. CAR/code/loss/__pycache__/__init__.cpython-37.pyc +0 -0
  42. CAR/code/loss/adversarial.py +112 -0
  43. CAR/code/loss/discriminator.py +55 -0
  44. CAR/code/loss/vgg.py +36 -0
  45. CAR/code/main.py +35 -0
  46. CAR/code/model/LICENSE +21 -0
  47. CAR/code/model/__init__.py +190 -0
  48. CAR/code/model/__pycache__/__init__.cpython-36.pyc +0 -0
  49. CAR/code/model/__pycache__/__init__.cpython-37.pyc +0 -0
  50. CAR/code/model/__pycache__/attention.cpython-36.pyc +0 -0
CAR/README.md ADDED
@@ -0,0 +1,103 @@
+ # Pyramid Attention for Image Restoration
+ This repository is for PANet and PA-EDSR, introduced in the following paper:
+
+ [Yiqun Mei](http://yiqunm2.web.illinois.edu/), [Yuchen Fan](https://scholar.google.com/citations?user=BlfdYL0AAAAJ&hl=en), [Yulun Zhang](http://yulunzhang.com/), [Jiahui Yu](https://jiahuiyu.com/), [Yuqian Zhou](https://yzhouas.github.io/), [Ding Liu](https://scholar.google.com/citations?user=PGtHUI0AAAAJ&hl=en), [Yun Fu](http://www1.ece.neu.edu/~yunfu/), [Thomas S. Huang](http://ifp-uiuc.github.io/) and [Honghui Shi](https://www.humphreyshi.com/), "Pyramid Attention for Image Restoration", [[Arxiv]](https://arxiv.org/abs/2004.13824)
+
+ The code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) & [RNAN](https://github.com/yulunzhang/RNAN) and tested on Ubuntu 18.04 (Python 3.6, PyTorch 1.1) with Titan X/1080Ti/V100 GPUs.
+
+ ## Contents
+ 1. [Train](#train)
+ 2. [Test](#test)
+ 3. [Results](#results)
+ 4. [Citation](#citation)
+ 5. [Acknowledgements](#acknowledgements)
+
+ ## Train
+ ### Prepare training data
+
+ 1. Download the DIV2K training data (800 training + 100 validation images) from the [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar).
+
+ 2. Specify '--dir_data' as the directory that contains the HR and LR images.
+
+ 3. Organize the training data like:
+ ```bash
+ DIV2K/
+ ├── DIV2K_train_HR
+ ├── DIV2K_train_LR_bicubic
+ │   ├── X10
+ │   ├── X20
+ │   ├── X30
+ │   └── X40
+ ├── DIV2K_valid_HR
+ └── DIV2K_valid_LR_bicubic
+     ├── X10
+     ├── X20
+     ├── X30
+     └── X40
+ ```
+ For more information, please refer to [EDSR(PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch).
+
+ ### Begin to train
+
+ 1. (optional) All the pretrained models and visual results can be downloaded from [Google Drive](https://drive.google.com/open?id=1q9iUzqYX0fVRzDu4J6fvSPRosgOZoJJE).
+
+ 2. cd to 'PANet-PyTorch/[Task]/code' and run the following scripts to train the models.
+
+    **You can use the scripts in 'demo.sb' to train and test the models from our paper.**
+
+ ```bash
+ # Example usage: Q=10
+ python main.py --n_GPUs 2 --batch_size 16 --lr 1e-4 --decay 200-400-600-800 --save_models --n_resblocks 80 --model PANET --scale 10 --patch_size 48 --save PANET_Q10 --n_feats 64 --data_train DIV2K --chop
+ ```
+ ## Test
+ ### Quick start
+
+ 1. cd to 'PANet-PyTorch/[Task]/code' and run the following scripts.
+
+    **You can use the scripts in 'demo.sb' to reproduce the results from our paper.**
+
+ ```bash
+ # No self-ensemble; use the different test sets (Classic5, LIVE1) to reproduce the results in the paper.
+ # Example usage: Q=40
+ python main.py --model PANET --save_results --n_GPUs 1 --chop --data_test classic5+LIVE1 --scale 40 --n_resblocks 80 --n_feats 64 --pre_train ../Q40.pt --test_only
+ ```
+
+ ### The whole test pipeline
+ 1. Prepare the test data. Organize it like:
+ ```bash
+ benchmark/
+ ├── testset1
+ │   ├── HR
+ │   └── LR_bicubic
+ │       ├── X10
+ │       └── ..
+ ├── testset2
+ ```
+
+ 2. Conduct image CAR.
+
+    See **Quick start**.
+ 3. Evaluate the results.
+
+    Run 'Evaluate_PSNR_SSIM.m' to obtain the PSNR/SSIM values reported in the paper.
+
+ ## Citation
+ If you find the code helpful in your research or work, please cite the following papers.
+ ```
+ @article{mei2020pyramid,
+   title={Pyramid Attention Networks for Image Restoration},
+   author={Mei, Yiqun and Fan, Yuchen and Zhang, Yulun and Yu, Jiahui and Zhou, Yuqian and Liu, Ding and Fu, Yun and Huang, Thomas S and Shi, Honghui},
+   journal={arXiv preprint arXiv:2004.13824},
+   year={2020}
+ }
+ @InProceedings{Lim_2017_CVPR_Workshops,
+   author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
+   title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
+   booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
+   month = {July},
+   year = {2017}
+ }
+ ```
+ ## Acknowledgements
+ This code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch), [RNAN](https://github.com/yulunzhang/RNAN) and [generative-inpainting-pytorch](https://github.com/daa233/generative-inpainting-pytorch). We thank the authors for sharing their codes.
CAR/code/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2018 Sanghyun Son
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
CAR/code/__init__.py ADDED
File without changes
CAR/code/__pycache__/option.cpython-36.pyc ADDED
Binary file (4.63 kB).
CAR/code/__pycache__/option.cpython-37.pyc ADDED
Binary file (4.85 kB).
CAR/code/__pycache__/template.cpython-36.pyc ADDED
Binary file (1.01 kB).
CAR/code/__pycache__/template.cpython-37.pyc ADDED
Binary file (994 Bytes).
CAR/code/__pycache__/trainer.cpython-36.pyc ADDED
Binary file (3.99 kB).
CAR/code/__pycache__/trainer.cpython-37.pyc ADDED
Binary file (5.07 kB).
CAR/code/__pycache__/utility.cpython-36.pyc ADDED
Binary file (9.04 kB).
CAR/code/__pycache__/utility.cpython-37.pyc ADDED
Binary file (9.1 kB).
CAR/code/data/__init__.py ADDED
@@ -0,0 +1,52 @@
+ from importlib import import_module
+ #from dataloader import MSDataLoader
+ from torch.utils.data import dataloader
+ from torch.utils.data import ConcatDataset
+
+ # This is a simple wrapper class for ConcatDataset
+ class MyConcatDataset(ConcatDataset):
+     def __init__(self, datasets):
+         super(MyConcatDataset, self).__init__(datasets)
+         self.train = datasets[0].train
+
+     def set_scale(self, idx_scale):
+         for d in self.datasets:
+             if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
+
+ class Data:
+     def __init__(self, args):
+         self.loader_train = None
+         if not args.test_only:
+             datasets = []
+             for d in args.data_train:
+                 module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
+                 m = import_module('data.' + module_name.lower())
+                 datasets.append(getattr(m, module_name)(args, name=d))
+
+             self.loader_train = dataloader.DataLoader(
+                 MyConcatDataset(datasets),
+                 batch_size=args.batch_size,
+                 shuffle=True,
+                 pin_memory=not args.cpu,
+                 num_workers=args.n_threads,
+             )
+
+         self.loader_test = []
+         for d in args.data_test:
+             if d in ['CBSD68', 'classic5', 'LIVE1', 'Kodak24', 'Set5', 'Set14', 'B100', 'Urban100']:
+                 m = import_module('data.benchmark')
+                 testset = getattr(m, 'Benchmark')(args, train=False, name=d)
+             else:
+                 module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
+                 m = import_module('data.' + module_name.lower())
+                 testset = getattr(m, module_name)(args, train=False, name=d)
+
+             self.loader_test.append(
+                 dataloader.DataLoader(
+                     testset,
+                     batch_size=1,
+                     shuffle=False,
+                     pin_memory=not args.cpu,
+                     num_workers=args.n_threads,
+                 )
+             )
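The dataset name decides which module under `data/` serves it. A standalone restatement of the routing rule used in `Data.__init__` above (names are illustrative):

```python
# Any name matching 'DIV2K-Q<q>' is routed to the DIV2KJPEG class
# (JPEG-compressed DIV2K at quality q); every other name maps to a
# module of the same name, lowercased, e.g. data.div2k.
def resolve(name: str) -> str:
    module_name = name if name.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
    return 'data.' + module_name.lower()

print(resolve('DIV2K'))      # data.div2k
print(resolve('DIV2K-Q10'))  # data.div2kjpeg
```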
CAR/code/data/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (1.84 kB).
CAR/code/data/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.79 kB).
CAR/code/data/__pycache__/benchmark.cpython-36.pyc ADDED
Binary file (1.08 kB).
CAR/code/data/__pycache__/common.cpython-36.pyc ADDED
Binary file (2.77 kB).
CAR/code/data/__pycache__/common.cpython-37.pyc ADDED
Binary file (2.71 kB).
CAR/code/data/__pycache__/div2k.cpython-36.pyc ADDED
Binary file (1.75 kB).
CAR/code/data/__pycache__/div2k.cpython-37.pyc ADDED
Binary file (1.75 kB).
CAR/code/data/__pycache__/srdata.cpython-36.pyc ADDED
Binary file (4.83 kB).
CAR/code/data/__pycache__/srdata.cpython-37.pyc ADDED
Binary file (4.91 kB).
CAR/code/data/benchmark.py ADDED
@@ -0,0 +1,25 @@
+ import os
+
+ from data import common
+ from data import srdata
+
+ import numpy as np
+
+ import torch
+ import torch.utils.data as data
+
+ class Benchmark(srdata.SRData):
+     def __init__(self, args, name='', train=True, benchmark=True):
+         super(Benchmark, self).__init__(
+             args, name=name, train=train, benchmark=True
+         )
+
+     def _set_filesystem(self, dir_data):
+         self.apath = os.path.join(dir_data, 'benchmark', self.name)
+         self.dir_hr = os.path.join(self.apath, 'HR')
+         if self.input_large:
+             self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
+         else:
+             self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
+         self.ext = ('', '.jpg')
CAR/code/data/common.py ADDED
@@ -0,0 +1,79 @@
+ import random
+
+ import numpy as np
+ import skimage.color as sc
+
+ import torch
+
+ def get_patch(*args, patch_size=96, scale=1, multi=False, input_large=False):
+     ih, iw = args[0].shape[:2]
+
+     # The scale factor is fixed to 1 for CAR, so the input and the target
+     # patches always share the same size regardless of 'multi'/'input_large'.
+     tp = patch_size
+     ip = patch_size
+
+     ix = random.randrange(0, iw - ip + 1)
+     iy = random.randrange(0, ih - ip + 1)
+     tx, ty = ix, iy
+
+     ret = [
+         args[0][iy:iy + ip, ix:ix + ip],
+         *[a[ty:ty + tp, tx:tx + tp] for a in args[1:]]
+     ]
+
+     return ret
+
+ def set_channel(*args, n_channels=3):
+     def _set_channel(img):
+         if img.ndim == 2:
+             img = np.expand_dims(img, axis=2)
+
+         c = img.shape[2]
+         if n_channels == 1 and c == 3:
+             img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
+         elif n_channels == 3 and c == 1:
+             img = np.concatenate([img] * n_channels, 2)
+
+         return img
+
+     return [_set_channel(a) for a in args]
+
+ def np2Tensor(*args, rgb_range=255):
+     def _np2Tensor(img):
+         np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
+         tensor = torch.from_numpy(np_transpose).float()
+         tensor.mul_(rgb_range / 255)
+
+         return tensor
+
+     return [_np2Tensor(a) for a in args]
+
+ def augment(*args, hflip=True, rot=True):
+     hflip = hflip and random.random() < 0.5
+     vflip = rot and random.random() < 0.5
+     rot90 = rot and random.random() < 0.5
+
+     def _augment(img):
+         if hflip: img = img[:, ::-1]
+         if vflip: img = img[::-1, :]
+         if rot90:
+             # Swap the spatial axes; keep the channel axis (if any) in place.
+             img = img.transpose(1, 0, 2) if img.ndim == 3 else img.transpose(1, 0)
+
+         return img
+
+     return [_augment(a) for a in args]
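A quick sanity check of the helpers above; a minimal sketch, assuming it is run from `CAR/code` so that `data.common` is importable (shapes are arbitrary):

```python
import numpy as np
from data import common

hr = np.random.randint(0, 256, (120, 160, 3), dtype=np.uint8)
lr = hr.copy()  # the scale factor is 1 for CAR, so LR and HR match in size

lr_p, hr_p = common.get_patch(lr, hr, patch_size=48)      # aligned 48x48 crops
lr_p, hr_p = common.augment(lr_p, hr_p)                   # random flip/rotation
lr_t, hr_t = common.np2Tensor(lr_p, hr_p, rgb_range=255)  # CHW float tensors
print(lr_t.shape, hr_t.shape)  # torch.Size([3, 48, 48]) twice
```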
CAR/code/data/demo.py ADDED
@@ -0,0 +1,39 @@
+ import os
+
+ from data import common
+
+ import numpy as np
+ import imageio
+
+ import torch
+ import torch.utils.data as data
+
+ class Demo(data.Dataset):
+     def __init__(self, args, name='Demo', train=False, benchmark=False):
+         self.args = args
+         self.name = name
+         self.scale = args.scale
+         self.idx_scale = 0
+         self.train = False
+         self.benchmark = benchmark
+
+         self.filelist = []
+         for f in os.listdir(args.dir_demo):
+             if f.find('.png') >= 0 or f.find('.jp') >= 0:
+                 self.filelist.append(os.path.join(args.dir_demo, f))
+         self.filelist.sort()
+
+     def __getitem__(self, idx):
+         filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
+         lr = imageio.imread(self.filelist[idx])
+         lr, = common.set_channel(lr, n_channels=self.args.n_colors)
+         lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
+
+         return lr_t, -1, filename
+
+     def __len__(self):
+         return len(self.filelist)
+
+     def set_scale(self, idx_scale):
+         self.idx_scale = idx_scale
CAR/code/data/div2k.py ADDED
@@ -0,0 +1,32 @@
+ import os
+ from data import srdata
+
+ class DIV2K(srdata.SRData):
+     def __init__(self, args, name='DIV2K', train=True, benchmark=False):
+         data_range = [r.split('-') for r in args.data_range.split('/')]
+         if train:
+             data_range = data_range[0]
+         else:
+             if args.test_only and len(data_range) == 1:
+                 data_range = data_range[0]
+             else:
+                 data_range = data_range[1]
+
+         self.begin, self.end = list(map(lambda x: int(x), data_range))
+         super(DIV2K, self).__init__(
+             args, name=name, train=train, benchmark=benchmark
+         )
+
+     def _scan(self):
+         names_hr, names_lr = super(DIV2K, self)._scan()
+         names_hr = names_hr[self.begin - 1:self.end]
+         names_lr = [n[self.begin - 1:self.end] for n in names_lr]
+
+         return names_hr, names_lr
+
+     def _set_filesystem(self, dir_data):
+         super(DIV2K, self)._set_filesystem(dir_data)
+         self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
+         self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
+         if self.input_large: self.dir_lr += 'L'
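The `--data_range` option encodes the train and validation index ranges as 'begin-end/begin-end'; the split logic in `DIV2K.__init__` reduces to (values illustrative):

```python
data_range = [r.split('-') for r in '1-800/801-810'.split('/')]
print(data_range)                     # [['1', '800'], ['801', '810']]
begin, end = map(int, data_range[0])  # training slice: images 1..800
print(begin, end)                     # 1 800
```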
CAR/code/data/div2kjpeg.py ADDED
@@ -0,0 +1,20 @@
+ import os
+ from data import srdata
+ from data import div2k
+
+ class DIV2KJPEG(div2k.DIV2K):
+     def __init__(self, args, name='', train=True, benchmark=False):
+         self.q_factor = int(name.replace('DIV2K-Q', ''))
+         super(DIV2KJPEG, self).__init__(
+             args, name=name, train=train, benchmark=benchmark
+         )
+
+     def _set_filesystem(self, dir_data):
+         self.apath = os.path.join(dir_data, 'DIV2K')
+         self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
+         self.dir_lr = os.path.join(
+             self.apath, 'DIV2K_Q{}'.format(self.q_factor)
+         )
+         if self.input_large: self.dir_lr += 'L'
+         self.ext = ('.png', '.jpg')
CAR/code/data/sr291.py ADDED
@@ -0,0 +1,6 @@
+ from data import srdata
+
+ class SR291(srdata.SRData):
+     def __init__(self, args, name='SR291', train=True, benchmark=False):
+         super(SR291, self).__init__(args, name=name)
CAR/code/data/srdata.py ADDED
@@ -0,0 +1,166 @@
+ import os
+ import glob
+ import random
+ import pickle
+
+ from data import common
+
+ import numpy as np
+ import imageio
+ import torch
+ import torch.utils.data as data
+
+ class SRData(data.Dataset):
+     def __init__(self, args, name='', train=True, benchmark=False):
+         self.args = args
+         self.name = name
+         self.train = train
+         self.split = 'train' if train else 'test'
+         self.do_eval = True
+         self.benchmark = benchmark
+         self.input_large = (args.model == 'VDSR')
+         self.scale = args.scale
+         self.idx_scale = 0
+
+         self._set_filesystem(args.dir_data)
+         if args.ext.find('img') < 0:
+             path_bin = os.path.join(self.apath, 'bin')
+             os.makedirs(path_bin, exist_ok=True)
+
+         list_hr, list_lr = self._scan()
+         if args.ext.find('img') >= 0 or benchmark:
+             self.images_hr, self.images_lr = list_hr, list_lr
+         elif args.ext.find('sep') >= 0:
+             os.makedirs(
+                 self.dir_hr.replace(self.apath, path_bin),
+                 exist_ok=True
+             )
+             for s in self.scale:
+                 os.makedirs(
+                     os.path.join(
+                         self.dir_lr.replace(self.apath, path_bin),
+                         'X{}'.format(s)
+                     ),
+                     exist_ok=True
+                 )
+
+             self.images_hr, self.images_lr = [], [[] for _ in self.scale]
+             for h in list_hr:
+                 b = h.replace(self.apath, path_bin)
+                 b = b.replace(self.ext[0], '.pt')
+                 self.images_hr.append(b)
+                 self._check_and_load(args.ext, h, b, verbose=True)
+             for i, ll in enumerate(list_lr):
+                 for l in ll:
+                     b = l.replace(self.apath, path_bin)
+                     b = b.replace(self.ext[1], '.pt')
+                     self.images_lr[i].append(b)
+                     self._check_and_load(args.ext, l, b, verbose=True)
+         if train:
+             n_patches = args.batch_size * args.test_every
+             n_images = len(args.data_train) * len(self.images_hr)
+             if n_images == 0:
+                 self.repeat = 0
+             else:
+                 self.repeat = max(n_patches // n_images, 1)
+
+     # The functions below are used to prepare the images
+     def _scan(self):
+         names_hr = sorted(
+             glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
+         )
+         names_lr = [[] for _ in self.scale]
+         for f in names_hr:
+             filename, _ = os.path.splitext(os.path.basename(f))
+             for si, s in enumerate(self.scale):
+                 names_lr[si].append(os.path.join(
+                     self.dir_lr, 'X{}/{}{}'.format(
+                         s, filename, self.ext[1]
+                     )
+                 ))
+
+         return names_hr, names_lr
+
+     def _set_filesystem(self, dir_data):
+         self.apath = os.path.join(dir_data, self.name)
+         self.dir_hr = os.path.join(self.apath, 'HR')
+         self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
+         if self.input_large: self.dir_lr += 'L'
+         self.ext = ('.png', '.jpg')
+
+     def _check_and_load(self, ext, img, f, verbose=True):
+         if not os.path.isfile(f) or ext.find('reset') >= 0:
+             if verbose:
+                 print('Making a binary: {}'.format(f))
+             with open(f, 'wb') as _f:
+                 pickle.dump(imageio.imread(img), _f)
+
+     def __getitem__(self, idx):
+         lr, hr, filename = self._load_file(idx)
+         pair = self.get_patch(lr, hr)
+         pair = common.set_channel(*pair, n_channels=self.args.n_colors)
+         pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
+
+         return pair_t[0], pair_t[1], filename
+
+     def __len__(self):
+         if self.train:
+             return len(self.images_hr) * self.repeat
+         else:
+             return len(self.images_hr)
+
+     def _get_index(self, idx):
+         if self.train:
+             return idx % len(self.images_hr)
+         else:
+             return idx
+
+     def _load_file(self, idx):
+         idx = self._get_index(idx)
+         f_hr = self.images_hr[idx]
+         f_lr = self.images_lr[self.idx_scale][idx]
+
+         filename, _ = os.path.splitext(os.path.basename(f_hr))
+         if self.args.ext == 'img' or self.benchmark:
+             hr = imageio.imread(f_hr)
+             lr = imageio.imread(f_lr)
+         elif self.args.ext.find('sep') >= 0:
+             with open(f_hr, 'rb') as _f:
+                 hr = pickle.load(_f)
+             with open(f_lr, 'rb') as _f:
+                 lr = pickle.load(_f)
+
+         return lr, hr, filename
+
+     def get_patch(self, lr, hr):
+         scale = self.scale[self.idx_scale]
+         if self.train:
+             lr, hr = common.get_patch(
+                 lr, hr,
+                 patch_size=self.args.patch_size,
+                 scale=scale,
+                 multi=(len(self.scale) > 1),
+                 input_large=self.input_large
+             )
+             if not self.args.no_augment: lr, hr = common.augment(lr, hr)
+         else:
+             if self.args.model == "RECURSIONNET":
+                 lr, hr = common.get_patch(
+                     lr, hr,
+                     patch_size=self.args.patch_size,
+                     scale=scale,
+                     multi=(len(self.scale) > 1),
+                     input_large=self.input_large
+                 )
+             else:
+                 ih, iw = lr.shape[:2]
+                 hr = hr[0:ih * scale, 0:iw * scale]
+
+         return lr, hr
+
+     def set_scale(self, idx_scale):
+         if not self.input_large:
+             self.idx_scale = idx_scale
+         else:
+             self.idx_scale = random.randint(0, len(self.scale) - 1)
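With `--ext sep`, SRData decodes each image once and caches it as a pickle under `<apath>/bin`, mirroring the original directory layout. The path mapping is just two string replacements (paths below are illustrative):

```python
apath = '/data/DIV2K'                      # dataset root used by SRData
path_bin = '/data/DIV2K/bin'               # cache root created in __init__
h = '/data/DIV2K/DIV2K_train_HR/0001.png'  # an HR image found by _scan()
b = h.replace(apath, path_bin).replace('.png', '.pt')
print(b)  # /data/DIV2K/bin/DIV2K_train_HR/0001.pt
```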
CAR/code/data/video.py ADDED
@@ -0,0 +1,44 @@
+ import os
+
+ from data import common
+
+ import cv2
+ import numpy as np
+ import imageio
+
+ import torch
+ import torch.utils.data as data
+
+ class Video(data.Dataset):
+     def __init__(self, args, name='Video', train=False, benchmark=False):
+         self.args = args
+         self.name = name
+         self.scale = args.scale
+         self.idx_scale = 0
+         self.train = False
+         self.do_eval = False
+         self.benchmark = benchmark
+
+         self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
+         self.vidcap = cv2.VideoCapture(args.dir_demo)
+         self.n_frames = 0
+         self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+     def __getitem__(self, idx):
+         success, lr = self.vidcap.read()
+         if success:
+             self.n_frames += 1
+             lr, = common.set_channel(lr, n_channels=self.args.n_colors)
+             lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
+
+             return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
+         else:
+             self.vidcap.release()
+             return None
+
+     def __len__(self):
+         return self.total_frames
+
+     def set_scale(self, idx_scale):
+         self.idx_scale = idx_scale
CAR/code/dataloader.py ADDED
@@ -0,0 +1,158 @@
+ # NOTE: this loader subclasses private PyTorch internals (_DataLoaderIter,
+ # _utils, torch._six) and was written against the PyTorch 1.1 era listed in
+ # the README; it will not import on recent PyTorch releases.
+ import sys
+ import threading
+ import random
+
+ import torch
+ import torch.multiprocessing as multiprocessing
+ from torch.utils.data import DataLoader
+ from torch.utils.data import SequentialSampler
+ from torch.utils.data import RandomSampler
+ from torch.utils.data import BatchSampler
+ from torch.utils.data import _utils
+ from torch.utils.data.dataloader import _DataLoaderIter
+
+ from torch.utils.data._utils import collate
+ from torch.utils.data._utils import signal_handling
+ from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
+ from torch.utils.data._utils import ExceptionWrapper
+ from torch.utils.data._utils import IS_WINDOWS
+ from torch.utils.data._utils.worker import ManagerWatchdog
+
+ from torch._six import queue
+
+ def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
+     try:
+         collate._use_shared_memory = True
+         signal_handling._set_worker_signal_handlers()
+
+         torch.set_num_threads(1)
+         random.seed(seed)
+         torch.manual_seed(seed)
+
+         data_queue.cancel_join_thread()
+
+         if init_fn is not None:
+             init_fn(worker_id)
+
+         watchdog = ManagerWatchdog()
+
+         while watchdog.is_alive():
+             try:
+                 r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+             except queue.Empty:
+                 continue
+
+             if r is None:
+                 assert done_event.is_set()
+                 return
+             elif done_event.is_set():
+                 continue
+
+             idx, batch_indices = r
+             try:
+                 # Pick a random scale per batch so that every sample in the
+                 # batch is cropped at the same scale.
+                 idx_scale = 0
+                 if len(scale) > 1 and dataset.train:
+                     idx_scale = random.randrange(0, len(scale))
+                     dataset.set_scale(idx_scale)
+
+                 samples = collate_fn([dataset[i] for i in batch_indices])
+                 samples.append(idx_scale)
+             except Exception:
+                 data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
+             else:
+                 data_queue.put((idx, samples))
+                 del samples
+
+     except KeyboardInterrupt:
+         pass
+
+ class _MSDataLoaderIter(_DataLoaderIter):
+
+     def __init__(self, loader):
+         self.dataset = loader.dataset
+         self.scale = loader.scale
+         self.collate_fn = loader.collate_fn
+         self.batch_sampler = loader.batch_sampler
+         self.num_workers = loader.num_workers
+         self.pin_memory = loader.pin_memory and torch.cuda.is_available()
+         self.timeout = loader.timeout
+
+         self.sample_iter = iter(self.batch_sampler)
+
+         base_seed = torch.LongTensor(1).random_().item()
+
+         if self.num_workers > 0:
+             self.worker_init_fn = loader.worker_init_fn
+             self.worker_queue_idx = 0
+             self.worker_result_queue = multiprocessing.Queue()
+             self.batches_outstanding = 0
+             self.worker_pids_set = False
+             self.shutdown = False
+             self.send_idx = 0
+             self.rcvd_idx = 0
+             self.reorder_dict = {}
+             self.done_event = multiprocessing.Event()
+
+             base_seed = torch.LongTensor(1).random_()[0]
+
+             self.index_queues = []
+             self.workers = []
+             for i in range(self.num_workers):
+                 index_queue = multiprocessing.Queue()
+                 index_queue.cancel_join_thread()
+                 w = multiprocessing.Process(
+                     target=_ms_loop,
+                     args=(
+                         self.dataset,
+                         index_queue,
+                         self.worker_result_queue,
+                         self.done_event,
+                         self.collate_fn,
+                         self.scale,
+                         base_seed + i,
+                         self.worker_init_fn,
+                         i
+                     )
+                 )
+                 w.daemon = True
+                 w.start()
+                 self.index_queues.append(index_queue)
+                 self.workers.append(w)
+
+             if self.pin_memory:
+                 self.data_queue = queue.Queue()
+                 pin_memory_thread = threading.Thread(
+                     target=_utils.pin_memory._pin_memory_loop,
+                     args=(
+                         self.worker_result_queue,
+                         self.data_queue,
+                         torch.cuda.current_device(),
+                         self.done_event
+                     )
+                 )
+                 pin_memory_thread.daemon = True
+                 pin_memory_thread.start()
+                 self.pin_memory_thread = pin_memory_thread
+             else:
+                 self.data_queue = self.worker_result_queue
+
+             _utils.signal_handling._set_worker_pids(
+                 id(self), tuple(w.pid for w in self.workers)
+             )
+             _utils.signal_handling._set_SIGCHLD_handler()
+             self.worker_pids_set = True
+
+             for _ in range(2 * self.num_workers):
+                 self._put_indices()
+
+ class MSDataLoader(DataLoader):
+
+     def __init__(self, cfg, *args, **kwargs):
+         super(MSDataLoader, self).__init__(
+             *args, **kwargs, num_workers=cfg.n_threads
+         )
+         self.scale = cfg.scale
+
+     def __iter__(self):
+         return _MSDataLoaderIter(self)
CAR/code/demo.sb ADDED
@@ -0,0 +1,6 @@
+ #!/bin/bash
+ # PANet Train
+ #python main.py --n_GPUs 2 --batch_size 16 --lr 1e-4 --decay 200-400-600-800 --save_models --model PANET --scale 10 --patch_size 48 --save PANET_Q10 --n_feats 64 --data_train DIV2K --chop
+
+ # Test
+ python main.py --model PANET --save_results --n_GPUs 1 --chop --data_test classic5+LIVE1 --scale 40 --n_resblocks 80 --n_feats 64 --pre_train ../Q40.pt --test_only
CAR/code/demo.sh ADDED
@@ -0,0 +1,56 @@
+ # EDSR baseline model (x2) + JPEG augmentation
+ python main.py --n_GPUs 2 --batch_size 16 --reset --save_models --model EDSRL1T --scale 1 --patch_size 48 --save EDSR_R16F64P48_CSAL1TK20 --n_feats 64 --data_train DIV2K --chop
+ #python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75
+
+ # EDSR baseline model (x3) - from EDSR baseline model (x2)
+ #python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
+
+ # EDSR baseline model (x4) - from EDSR baseline model (x2)
+ #python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
+
+ # EDSR in the paper (x2)
+ #python main.py --model EDSR --scale 2 --save edsr_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset
+
+ # EDSR in the paper (x3) - from EDSR (x2)
+ #python main.py --model EDSR --scale 3 --save edsr_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]
+
+ # EDSR in the paper (x4) - from EDSR (x2)
+ #python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]
+
+ # MDSR baseline model
+ #python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models
+
+ # MDSR in the paper
+ #python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models
+
+ # Standard benchmarks (Ex. EDSR_baseline_x4)
+ #python main.py --model EDSRL1 --save_results --n_GPUs 2 --chop --data_test Kodak24+CBSD68+Urban100 --scale 1 --pre_train ../experiment/EDSR_R16F64P48_CSALK20/model/model_best.pt --test_only
+
+ #python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
+
+ # Test your own images
+ #python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results
+
+ # Advanced - Test with JPEG images
+ #python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results
+
+ # Advanced - Training with adversarial loss
+ #python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download
+
+ # RDN BI model (x2)
+ #python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset
+ # RDN BI model (x3)
+ #python3.6 main.py --scale 3 --save RDN_D16C8G64_BIx3 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 96 --reset
+ # RDN BI model (x4)
+ #python3.6 main.py --scale 4 --save RDN_D16C8G64_BIx4 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 128 --reset
+
+ # RCAN_BIX2_G10R20P48, input=48x48, output=96x96
+ # pretrained model can be downloaded from https://www.dropbox.com/s/mjbcqkd4nwhr6nu/models_ECCV2018RCAN.zip?dl=0
+ #python main.py --template RCAN --save RCAN_BIX2_G10R20P48 --scale 2 --reset --save_results --patch_size 96
+ # RCAN_BIX3_G10R20P48, input=48x48, output=144x144
+ #python main.py --template RCAN --save RCAN_BIX3_G10R20P48 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt
+ # RCAN_BIX4_G10R20P48, input=48x48, output=192x192
+ #python main.py --template RCAN --save RCAN_BIX4_G10R20P48 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt
+ # RCAN_BIX8_G10R20P48, input=48x48, output=384x384
+ #python main.py --template RCAN --save RCAN_BIX8_G10R20P48 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt
CAR/code/lambda_networks/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from lambda_networks.lambda_networks import LambdaLayer
+ from lambda_networks.lambda_networks import Recursion
+ from lambda_networks.rlambda_networks import RLambdaLayer
+ λLayer = LambdaLayer
CAR/code/lambda_networks/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (338 Bytes).
CAR/code/lambda_networks/__pycache__/lambda_networks.cpython-37.pyc ADDED
Binary file (5.75 kB).
CAR/code/lambda_networks/__pycache__/rlambda_networks.cpython-37.pyc ADDED
Binary file (2.92 kB).
CAR/code/lambda_networks/lambda_networks.py ADDED
@@ -0,0 +1,161 @@
+ import torch
+ from torch import nn, einsum
+ import torch.nn.functional as F
+ from einops import rearrange
+
+ # my layer normalization
+
+ class LayerNorm(nn.Module):
+     def __init__(self, eps=1e-5):
+         super(LayerNorm, self).__init__()
+         self.eps = eps
+
+     def forward(self, input):
+         shape = tuple(input.size()[1:])
+         return F.layer_norm(input, shape, eps=self.eps)
+
+     def extra_repr(self):
+         return f'eps={self.eps}'
+
+ # helper functions
+
+ def exists(val):
+     return val is not None
+
+ def default(val, d):
+     return val if exists(val) else d
+
+ # lambda layer
+
+ class LambdaLayer(nn.Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         dim_k,
+         n=None,
+         r=None,
+         heads=4,
+         dim_out=None,
+         dim_u=1,
+         normalization="batch"):
+         super().__init__()
+         dim_out = default(dim_out, dim)
+         self.u = dim_u  # intra-depth dimension
+         self.heads = heads
+
+         assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
+         dim_v = dim_out // heads
+
+         self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias=False)
+         self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias=False)
+         self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias=False)
+         print(f"using {normalization} normalization in lambda layer")
+         if normalization == "none":
+             self.norm_q = nn.Identity()
+             self.norm_v = nn.Identity()
+         elif normalization == "instance":
+             self.norm_q = nn.InstanceNorm2d(dim_k * heads)
+             self.norm_v = nn.InstanceNorm2d(dim_v * dim_u)
+         elif normalization == "layer":
+             self.norm_q = LayerNorm()
+             self.norm_v = LayerNorm()
+         else:
+             self.norm_q = nn.BatchNorm2d(dim_k * heads)
+             self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
+
+         self.local_contexts = exists(r)
+         if exists(r):
+             assert (r % 2) == 1, 'Receptive kernel size should be odd'
+             self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2))
+         else:
+             assert exists(n), 'You must specify the total sequence length (h x w)'
+             self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
+
+     def forward(self, x):
+         b, c, hh, ww, u, h = *x.shape, self.u, self.heads
+
+         q = self.to_q(x)
+         k = self.to_k(x)
+         v = self.to_v(x)
+
+         q = self.norm_q(q)
+         v = self.norm_v(v)
+
+         q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h=h)
+         k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u=u)
+         v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u=u)
+
+         k = k.softmax(dim=-1)
+
+         λc = einsum('b u k m, b u v m -> b k v', k, v)
+         Yc = einsum('b h k n, b k v -> b h v n', q, λc)
+
+         if self.local_contexts:
+             v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)
+             λp = self.pos_conv(v)
+             Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
+         else:
+             λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
+             Yp = einsum('b h k n, b n k v -> b h v n', q, λp)
+
+         Y = Yc + Yp
+         out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh=hh, ww=ww)
+         return out
+
+ # i'm not sure whether this will work or not
+ class Recursion(nn.Module):
+     def __init__(self, N: int, hidden_dim: int = 64):
+         super(Recursion, self).__init__()
+         self.N = N
+         self.lambdaNxN_identity = LambdaLayer(dim=hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
+         # merge upstream information here
+         self.lambdaNxN_merge = LambdaLayer(dim=2 * hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
+         self.downscale_conv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=N, stride=N)
+         self.upscale_conv = nn.Conv2d(hidden_dim, hidden_dim * N * N, kernel_size=3, padding=1)
+         self.pixel_shuffle = nn.PixelShuffle(N)
+
+     def forward(self, x: torch.Tensor):
+         N = self.N
+
+         def to_patch(blocks: torch.Tensor):
+             shape = blocks.shape
+             blocks_patch = F.unfold(blocks, kernel_size=N, stride=N)
+             blocks_patch = blocks_patch.view(shape[0], shape[1], N, N, -1)
+             num_patch = blocks_patch.shape[-1]
+             blocks_patch = blocks_patch.permute(0, 4, 1, 2, 3).reshape(-1, shape[1], N, N).contiguous()
+             return blocks_patch, num_patch
+
+         def combine_patch(processed_patch, shape, num_patch):
+             processed_patch = processed_patch.reshape(shape[0], num_patch, shape[1], N, N)
+             processed_patch = processed_patch.permute(0, 2, 3, 4, 1).reshape(shape[0], shape[1] * N * N, num_patch).contiguous()
+             processed = F.fold(processed_patch, output_size=(shape[-2], shape[-1]), kernel_size=N, stride=N)
+             return processed
+
+         def process(blocks: torch.Tensor) -> torch.Tensor:
+             shape = blocks.shape
+             if blocks.shape[-1] == N:
+                 processed = self.lambdaNxN_identity(blocks)
+                 return processed
+             # to NxN patches
+             blocks_patch, num_patch = to_patch(blocks)
+             # pass through identity
+             processed_patch = self.lambdaNxN_identity(blocks_patch)
+             # back to HxW
+             processed = combine_patch(processed_patch, shape, num_patch)
+             # get feedback from the recursively processed, downscaled map
+             feedback = process(self.downscale_conv(processed))
+             # upscale feedback
+             upscale_feedback = self.upscale_conv(feedback)
+             upscale_feedback = self.pixel_shuffle(upscale_feedback)
+             # combine results
+             combined = torch.cat([processed, upscale_feedback], dim=1)
+             combined_shape = combined.shape
+             combined_patch, num_patch = to_patch(combined)
+             combined_patch_reduced = self.lambdaNxN_merge(combined_patch)
+             ret_shape = (combined_shape[0], combined_shape[1] // 2, combined_shape[2], combined_shape[3])
+             ret = combine_patch(combined_patch_reduced, ret_shape, num_patch)
+             return ret
+
+         return process(x)
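A minimal usage sketch for `LambdaLayer` (requires torch and einops; the numbers are arbitrary). With `r` set, positional lambdas come from the local Conv3d path; without it, `n` must equal the flattened spatial size:

```python
import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim=32,       # input channels
    dim_out=32,   # output channels; must be divisible by heads
    r=7,          # odd local receptive field -> Conv3d positional lambdas
    dim_k=16,
    heads=2,
    dim_u=1,
)
x = torch.randn(2, 32, 48, 48)
y = layer(x)
print(y.shape)  # torch.Size([2, 32, 48, 48])
```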
CAR/code/lambda_networks/rlambda_networks.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ from torch import nn, einsum
+ import torch.nn.functional as F
+ from einops import rearrange
+
+
+ # helper functions
+
+ def exists(val):
+     return val is not None
+
+
+ def default(val, d):
+     return val if exists(val) else d
+
+
+ # recurrent lambda layer
+
+ class RLambdaLayer(nn.Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         dim_k,
+         n=None,
+         r=None,
+         heads=4,
+         dim_out=None,
+         dim_u=1,
+         recurrence=None
+     ):
+         super().__init__()
+         dim_out = default(dim_out, dim)
+         self.u = dim_u  # intra-depth dimension
+         self.heads = heads
+
+         assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
+         dim_v = dim_out // heads
+
+         self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias=False)
+         self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias=False)
+         self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias=False)
+
+         self.norm_q = nn.BatchNorm2d(dim_k * heads)
+         self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
+
+         self.local_contexts = exists(r)
+         # number of times the computed lambdas are re-applied in forward()
+         self.recurrence = recurrence
+         if exists(r):
+             assert (r % 2) == 1, 'Receptive kernel size should be odd'
+             self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2))
+         else:
+             assert exists(n), 'You must specify the total sequence length (h x w)'
+             self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
+
+     def apply_lambda(self, lambda_c, lambda_p, x):
+         b, c, hh, ww, u, h = *x.shape, self.u, self.heads
+         q = self.to_q(x)
+         q = self.norm_q(q)
+         q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h=h)
+         Yc = einsum('b h k n, b k v -> b h v n', q, lambda_c)
+         if self.local_contexts:
+             Yp = einsum('b h k n, b k v n -> b h v n', q, lambda_p.flatten(3))
+         else:
+             Yp = einsum('b h k n, b n k v -> b h v n', q, lambda_p)
+         Y = Yc + Yp
+         out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh=hh, ww=ww)
+         return out
+
+     def forward(self, x):
+         b, c, hh, ww, u, h = *x.shape, self.u, self.heads
+
+         k = self.to_k(x)
+         v = self.to_v(x)
+
+         v = self.norm_v(v)
+
+         k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u=u)
+         v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u=u)
+
+         k = k.softmax(dim=-1)
+
+         λc = einsum('b u k m, b u v m -> b k v', k, v)
+
+         if self.local_contexts:
+             v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)
+             λp = self.pos_conv(v)
+         else:
+             λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
+         out = x
+         for i in range(self.recurrence):
+             out = self.apply_lambda(λc, λp, out)
+         return out
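`RLambdaLayer` computes the content and positional lambdas once and re-applies them `recurrence` times; note that `recurrence` must be given, since the default `None` cannot be used in `range()`. A small sketch with arbitrary sizes:

```python
import torch
from lambda_networks import RLambdaLayer

layer = RLambdaLayer(dim=32, dim_k=16, r=5, heads=2, recurrence=3)
y = layer(torch.randn(2, 32, 24, 24))
print(y.shape)  # torch.Size([2, 32, 24, 24])
```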
CAR/code/loss/__init__.py ADDED
@@ -0,0 +1,173 @@
+ import os
+ from importlib import import_module
+
+ import matplotlib
+ matplotlib.use('Agg')
+ import matplotlib.pyplot as plt
+
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ def sequence_loss(sr, hr, loss_func, gamma=0.8, max_val=None):
+     """Loss defined over a sequence of predictions, weighted so that
+     later (deeper) recurrence steps count more."""
+     n_recurrence = len(sr)
+     total_loss = 0.0
+     buffer = [0.0] * n_recurrence
+     for i in range(n_recurrence):
+         i_weight = gamma**(n_recurrence - i - 1)
+         i_loss = loss_func(sr[i], hr)
+         buffer[i] = i_loss.item()
+         total_loss += i_weight * i_loss
+     return total_loss, buffer
+
+ class Loss(nn.modules.loss._Loss):
+     def __init__(self, args, ckp):
+         super(Loss, self).__init__()
+         print('Preparing loss function:')
+         self.buffer = [0.0] * args.recurrence
+         self.n_GPUs = args.n_GPUs
+         self.loss = []
+         self.loss_module = nn.ModuleList()
+         for loss in args.loss.split('+'):
+             weight, loss_type = loss.split('*')
+             if loss_type == 'MSE':
+                 loss_function = nn.MSELoss()
+             elif loss_type == 'L1':
+                 loss_function = nn.L1Loss()
+             elif loss_type.find('VGG') >= 0:
+                 module = import_module('loss.vgg')
+                 loss_function = getattr(module, 'VGG')(
+                     loss_type[3:],
+                     rgb_range=args.rgb_range
+                 )
+             elif loss_type.find('GAN') >= 0:
+                 module = import_module('loss.adversarial')
+                 loss_function = getattr(module, 'Adversarial')(
+                     args,
+                     loss_type
+                 )
+
+             self.loss.append({
+                 'type': loss_type,
+                 'weight': float(weight),
+                 'function': loss_function}
+             )
+             if loss_type.find('GAN') >= 0:
+                 self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
+
+         if len(self.loss) > 1:
+             self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
+
+         for l in self.loss:
+             if l['function'] is not None:
+                 print('{:.3f} * {}'.format(l['weight'], l['type']))
+                 # self.loss_module.append(l['function'])
+
+         self.log = torch.Tensor()
+
+         device = torch.device('cpu' if args.cpu else 'cuda')
+         self.loss_module.to(device)
+         if args.precision == 'half': self.loss_module.half()
+         if not args.cpu and args.n_GPUs > 1:
+             self.loss_module = nn.DataParallel(
+                 self.loss_module, range(args.n_GPUs)
+             )
+
+         if args.load != '': self.load(ckp.dir, cpu=args.cpu)
+
+     def forward(self, sr, hr):
+         losses = []
+         for i, l in enumerate(self.loss):
+             if l['function'] is not None:
+                 if isinstance(sr, list):
+                     # a list of predictions: weight each recurrence step
+                     effective_loss, buffer_lst = sequence_loss(sr, hr, l['function'])
+                     losses.append(effective_loss)
+                     self.buffer = buffer_lst
+                     self.log[-1, i] += effective_loss.item()
+                 else:
+                     loss = l['function'](sr, hr)
+                     effective_loss = l['weight'] * loss
+                     losses.append(effective_loss)
+                     self.buffer[0] = effective_loss.item()
+                     self.log[-1, i] += effective_loss.item()
+             elif l['type'] == 'DIS':
+                 self.log[-1, i] += self.loss[i - 1]['function'].loss
+
+         loss_sum = sum(losses)
+         if len(self.loss) > 1:
+             self.log[-1, -1] += loss_sum.item()
+
+         return loss_sum
+
+     def step(self):
+         for l in self.get_loss_module():
+             if hasattr(l, 'scheduler'):
+                 l.scheduler.step()
+
+     def start_log(self):
+         self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
+
+     def end_log(self, n_batches):
+         self.log[-1].div_(n_batches)
+
+     def display_loss(self, batch):
+         n_samples = batch + 1
+         log = []
+         for l, c in zip(self.loss, self.log[-1]):
+             log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
+
+         return ''.join(log)
+
+     def plot_loss(self, apath, epoch):
+         axis = np.linspace(1, epoch, epoch)
+         for i, l in enumerate(self.loss):
+             label = '{} Loss'.format(l['type'])
+             fig = plt.figure()
+             plt.title(label)
+             plt.plot(axis, self.log[:, i].numpy(), label=label)
+             plt.legend()
+             plt.xlabel('Epochs')
+             plt.ylabel('Loss')
+             plt.grid(True)
+             plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
+             plt.close(fig)
+
+     def get_loss_module(self):
+         if self.n_GPUs == 1:
+             return self.loss_module
+         else:
+             return self.loss_module.module
+
+     def save(self, apath):
+         torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
+         torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
+
+     def load(self, apath, cpu=False):
+         if cpu:
+             kwargs = {'map_location': lambda storage, loc: storage}
+         else:
+             kwargs = {}
+
+         self.load_state_dict(torch.load(
+             os.path.join(apath, 'loss.pt'),
+             **kwargs
+         ))
+         self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
+         for l in self.get_loss_module():
+             if hasattr(l, 'scheduler'):
+                 for _ in range(len(self.log)): l.scheduler.step()
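`Loss` is configured through the '--loss' option, a '+'-separated list of 'weight*TYPE' terms (e.g. '5*VGG54+0.15*GAN' as in demo.sh). The parsing in `Loss.__init__` boils down to:

```python
spec = '1*L1+0.1*VGG54'
for term in spec.split('+'):
    weight, loss_type = term.split('*')
    print(float(weight), loss_type)
# 1.0 L1
# 0.1 VGG54
```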
CAR/code/loss/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (4.57 kB).
CAR/code/loss/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (5.13 kB).
CAR/code/loss/adversarial.py ADDED
@@ -0,0 +1,112 @@
+ import utility
+ from types import SimpleNamespace
+
+ from model import common
+ from loss import discriminator
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+
+ class Adversarial(nn.Module):
+     def __init__(self, args, gan_type):
+         super(Adversarial, self).__init__()
+         self.gan_type = gan_type
+         self.gan_k = args.gan_k
+         self.dis = discriminator.Discriminator(args)
+         if gan_type == 'WGAN_GP':
+             # see https://arxiv.org/pdf/1704.00028.pdf pp.4
+             optim_dict = {
+                 'optimizer': 'ADAM',
+                 'betas': (0, 0.9),
+                 'epsilon': 1e-8,
+                 'lr': 1e-5,
+                 'weight_decay': args.weight_decay,
+                 'decay': args.decay,
+                 'gamma': args.gamma
+             }
+             optim_args = SimpleNamespace(**optim_dict)
+         else:
+             optim_args = args
+
+         self.optimizer = utility.make_optimizer(optim_args, self.dis)
+
+     def forward(self, fake, real):
+         # updating discriminator...
+         self.loss = 0
+         fake_detach = fake.detach()  # do not backpropagate through G
+         for _ in range(self.gan_k):
+             self.optimizer.zero_grad()
+             # d: B x 1 tensor
+             d_fake = self.dis(fake_detach)
+             d_real = self.dis(real)
+             retain_graph = False
+             if self.gan_type == 'GAN':
+                 loss_d = self.bce(d_real, d_fake)
+             elif self.gan_type.find('WGAN') >= 0:
+                 loss_d = (d_fake - d_real).mean()
+                 if self.gan_type.find('GP') >= 0:
+                     # one interpolation weight per sample, broadcast over C x H x W
+                     epsilon = torch.rand(fake.size(0), device=fake.device).view(-1, 1, 1, 1)
+                     hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
+                     hat.requires_grad = True
+                     d_hat = self.dis(hat)
+                     gradients = torch.autograd.grad(
+                         outputs=d_hat.sum(), inputs=hat,
+                         retain_graph=True, create_graph=True, only_inputs=True
+                     )[0]
+                     gradients = gradients.view(gradients.size(0), -1)
+                     gradient_norm = gradients.norm(2, dim=1)
+                     gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
+                     loss_d += gradient_penalty
+             # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
+             elif self.gan_type == 'RGAN':
+                 better_real = d_real - d_fake.mean(dim=0, keepdim=True)
+                 better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
+                 loss_d = self.bce(better_real, better_fake)
+                 retain_graph = True
+
+             # Discriminator update
+             self.loss += loss_d.item()
+             loss_d.backward(retain_graph=retain_graph)
+             self.optimizer.step()
+
+             if self.gan_type == 'WGAN':
+                 for p in self.dis.parameters():
+                     p.data.clamp_(-1, 1)
+
+         self.loss /= self.gan_k
+
+         # updating generator...
+         d_fake_bp = self.dis(fake)  # for backpropagation, use fake as it is
+         if self.gan_type == 'GAN':
+             label_real = torch.ones_like(d_fake_bp)
+             loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
+         elif self.gan_type.find('WGAN') >= 0:
+             loss_g = -d_fake_bp.mean()
+         elif self.gan_type == 'RGAN':
+             better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
+             better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
+             loss_g = self.bce(better_fake, better_real)
+
+         # Generator loss
+         return loss_g
+
+     def state_dict(self, *args, **kwargs):
+         state_discriminator = self.dis.state_dict(*args, **kwargs)
+         state_optimizer = self.optimizer.state_dict()
+
+         return dict(**state_discriminator, **state_optimizer)
+
+     def bce(self, real, fake):
+         label_real = torch.ones_like(real)
+         label_fake = torch.zeros_like(fake)
+         bce_real = F.binary_cross_entropy_with_logits(real, label_real)
+         bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
+         bce_loss = bce_real + bce_fake
+         return bce_loss
+
+ # Some references
+ # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
+ # OR
+ # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
CAR/code/loss/discriminator.py ADDED
@@ -0,0 +1,55 @@
+ from model import common
+
+ import torch.nn as nn
+
+ class Discriminator(nn.Module):
+     '''
+     output is not normalized
+     '''
+     def __init__(self, args):
+         super(Discriminator, self).__init__()
+
+         in_channels = args.n_colors
+         out_channels = 64
+         depth = 7
+
+         def _block(_in_channels, _out_channels, stride=1):
+             return nn.Sequential(
+                 nn.Conv2d(
+                     _in_channels,
+                     _out_channels,
+                     3,
+                     padding=1,
+                     stride=stride,
+                     bias=False
+                 ),
+                 nn.BatchNorm2d(_out_channels),
+                 nn.LeakyReLU(negative_slope=0.2, inplace=True)
+             )
+
+         m_features = [_block(in_channels, out_channels)]
+         for i in range(depth):
+             in_channels = out_channels
+             if i % 2 == 1:
+                 stride = 1
+                 out_channels *= 2
+             else:
+                 stride = 2
+             m_features.append(_block(in_channels, out_channels, stride=stride))
+
+         patch_size = args.patch_size // (2**((depth + 1) // 2))
+         m_classifier = [
+             nn.Linear(out_channels * patch_size**2, 1024),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+             nn.Linear(1024, 1)
+         ]
+
+         self.features = nn.Sequential(*m_features)
+         self.classifier = nn.Sequential(*m_classifier)
+
+     def forward(self, x):
+         features = self.features(x)
+         output = self.classifier(features.view(features.size(0), -1))
+
+         return output
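The classifier's input size depends on how many feature blocks stride by 2: with depth = 7, blocks 0, 2, 4 and 6 downsample, so the spatial size shrinks by 2**4 = 16 and `args.patch_size` must be divisible by 16. A quick check of the arithmetic above:

```python
depth = 7
n_stride2 = sum(1 for i in range(depth) if i % 2 == 0)  # blocks 0, 2, 4, 6
print(n_stride2)                                # 4 downsampling blocks
patch_size = 48
print(patch_size // (2 ** ((depth + 1) // 2)))  # 3 -> classifier sees 3x3 maps
```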
CAR/code/loss/vgg.py ADDED
@@ -0,0 +1,36 @@
+from model import common
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+
+class VGG(nn.Module):
+    def __init__(self, conv_index, rgb_range=1):
+        super(VGG, self).__init__()
+        vgg_features = models.vgg19(pretrained=True).features
+        modules = [m for m in vgg_features]
+        # '22' keeps layers up to conv2_2; '54' keeps layers up to conv5_4
+        if conv_index.find('22') >= 0:
+            self.vgg = nn.Sequential(*modules[:8])
+        elif conv_index.find('54') >= 0:
+            self.vgg = nn.Sequential(*modules[:35])
+
+        vgg_mean = (0.485, 0.456, 0.406)
+        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
+        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
+        # the feature extractor is frozen; only the generator is trained
+        for p in self.parameters():
+            p.requires_grad = False
+
+    def forward(self, sr, hr):
+        def _forward(x):
+            x = self.sub_mean(x)
+            x = self.vgg(x)
+            return x
+
+        vgg_sr = _forward(sr)
+        with torch.no_grad():
+            vgg_hr = _forward(hr.detach())
+
+        loss = F.mse_loss(vgg_sr, vgg_hr)
+
+        return loss
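A hedged usage sketch for the perceptual loss above: `conv_index` picks the depth of the frozen VGG-19 feature extractor, and the forward pass returns a scalar MSE in feature space. The shapes and `rgb_range` below are illustrative; the first run downloads torchvision's VGG-19 weights.

```python
import torch
from loss.vgg import VGG

perceptual = VGG(conv_index='54', rgb_range=255)   # deep features, 0-255 pixels

sr = torch.rand(1, 3, 96, 96) * 255   # restored output (would carry gradients)
hr = torch.rand(1, 3, 96, 96) * 255   # ground truth (features computed no-grad)
print(perceptual(sr, hr).item())      # scalar feature-space MSE
```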
CAR/code/main.py ADDED
@@ -0,0 +1,35 @@
+import torch
+
+import utility
+import data
+import model
+import loss
+from option import args
+from trainer import Trainer
+
+torch.manual_seed(args.seed)
+checkpoint = utility.checkpoint(args)
+
+def main():
+    global model
+    if args.data_test == ['video']:
+        from videotester import VideoTester
+        model = model.Model(args, checkpoint)
+        print('total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))
+        t = VideoTester(args, model, checkpoint)
+        t.test()
+    else:
+        if checkpoint.ok:
+            loader = data.Data(args)
+            _model = model.Model(args, checkpoint)
+            print('total params: %.5fM' % (sum(p.numel() for p in _model.parameters()) / 1000000.0))
+            _loss = loss.Loss(args, checkpoint) if not args.test_only else None
+            t = Trainer(args, loader, _model, _loss, checkpoint)
+            while not t.terminate():
+                t.train()
+                t.test()
+
+            checkpoint.done()
+
+if __name__ == '__main__':
+    main()
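`main.py` is normally driven via the flags in `demo.sh`, but the same flow can be sketched programmatically. A hypothetical test-only run, assuming `option.py` has already parsed suitable arguments (the attribute names below are the ones `main.py` itself reads):

```python
import utility, data, model, loss
from option import args
from trainer import Trainer

args.test_only = True                # evaluation only: no loss, no training loop
ckp = utility.checkpoint(args)
if ckp.ok:
    net = model.Model(args, ckp)     # imports model/<args.model>.py, loads weights
    t = Trainer(args, data.Data(args), net, None, ckp)
    t.test()
    ckp.done()
```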
CAR/code/model/LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2018 Sanghyun Son
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
CAR/code/model/__init__.py ADDED
@@ -0,0 +1,190 @@
+import os
+from importlib import import_module
+
+import torch
+import torch.nn as nn
+
+class Model(nn.Module):
+    def __init__(self, args, ckp):
+        super(Model, self).__init__()
+        print('Making model...')
+
+        self.scale = args.scale
+        self.idx_scale = 0
+        self.self_ensemble = args.self_ensemble
+        self.chop = args.chop
+        self.precision = args.precision
+        self.cpu = args.cpu
+        self.device = torch.device('cpu' if args.cpu else 'cuda')
+        self.n_GPUs = args.n_GPUs
+        self.save_models = args.save_models
+
+        module = import_module('model.' + args.model.lower())
+        self.model = module.make_model(args).to(self.device)
+        if args.precision == 'half': self.model.half()
+
+        if not args.cpu and args.n_GPUs > 1:
+            self.model = nn.DataParallel(self.model, range(args.n_GPUs))
+
+        self.load(
+            ckp.dir,
+            pre_train=args.pre_train,
+            resume=args.resume,
+            cpu=args.cpu
+        )
+        print(self.model, file=ckp.log_file)
+
+    def forward(self, x, idx_scale):
+        self.idx_scale = idx_scale
+        target = self.get_model()
+        if hasattr(target, 'set_scale'):
+            target.set_scale(idx_scale)
+
+        if self.self_ensemble and not self.training:
+            if self.chop:
+                forward_function = self.forward_chop
+            else:
+                forward_function = self.model.forward
+
+            return self.forward_x8(x, forward_function)
+        elif self.chop and not self.training:
+            return self.forward_chop(x)
+        else:
+            return self.model(x)
+
+    def get_model(self):
+        if self.n_GPUs == 1:
+            return self.model
+        else:
+            return self.model.module
+
+    def state_dict(self, **kwargs):
+        target = self.get_model()
+        return target.state_dict(**kwargs)
+
+    def save(self, apath, epoch, is_best=False):
+        target = self.get_model()
+        torch.save(
+            target.state_dict(),
+            os.path.join(apath, 'model_latest.pt')
+        )
+        if is_best:
+            torch.save(
+                target.state_dict(),
+                os.path.join(apath, 'model_best.pt')
+            )
+
+        if self.save_models:
+            torch.save(
+                target.state_dict(),
+                os.path.join(apath, 'model_{}.pt'.format(epoch))
+            )
+
+    def load(self, apath, pre_train='.', resume=-1, cpu=False):
+        if cpu:
+            kwargs = {'map_location': lambda storage, loc: storage}
+        else:
+            kwargs = {}
+
+        if resume == -1:
+            self.get_model().load_state_dict(
+                torch.load(
+                    os.path.join(apath, 'model', 'model_latest.pt'),
+                    **kwargs
+                ),
+                strict=False
+            )
+        elif resume == 0:
+            if pre_train != '.':
+                print('Loading model from {}'.format(pre_train))
+                self.get_model().load_state_dict(
+                    torch.load(pre_train, **kwargs),
+                    strict=False
+                )
+        else:
+            self.get_model().load_state_dict(
+                torch.load(
+                    os.path.join(apath, 'model', 'model_{}.pt'.format(resume)),
+                    **kwargs
+                ),
+                strict=False
+            )
+
+    def forward_chop(self, x, shave=10, min_size=6400):
+        # CAR restores at the input resolution, so the merge below runs at
+        # scale 1 (the multi-scale bookkeeping is kept for SR compatibility)
+        scale = self.scale[self.idx_scale]
+        scale = 1
+        n_GPUs = min(self.n_GPUs, 4)
+        b, c, h, w = x.size()
+        h_half, w_half = h // 2, w // 2
+        h_size, w_size = h_half + shave, w_half + shave
+        # split the input into four overlapping quadrants
+        lr_list = [
+            x[:, :, 0:h_size, 0:w_size],
+            x[:, :, 0:h_size, (w - w_size):w],
+            x[:, :, (h - h_size):h, 0:w_size],
+            x[:, :, (h - h_size):h, (w - w_size):w]]
+
+        if w_size * h_size < min_size:
+            sr_list = []
+            for i in range(0, 4, n_GPUs):
+                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
+                sr_batch = self.model(lr_batch)
+                sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
+        else:
+            # quadrants still too large: recurse until they fit in memory
+            sr_list = [
+                self.forward_chop(patch, shave=shave, min_size=min_size) \
+                for patch in lr_list
+            ]
+
+        h, w = scale * h, scale * w
+        h_half, w_half = scale * h_half, scale * w_half
+        h_size, w_size = scale * h_size, scale * w_size
+        shave *= scale
+
+        # stitch the four restored quadrants back together, discarding overlap
+        output = x.new(b, c, h, w)
+        output[:, :, 0:h_half, 0:w_half] \
+            = sr_list[0][:, :, 0:h_half, 0:w_half]
+        output[:, :, 0:h_half, w_half:w] \
+            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
+        output[:, :, h_half:h, 0:w_half] \
+            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
+        output[:, :, h_half:h, w_half:w] \
+            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
+
+        return output
+
+    def forward_x8(self, x, forward_function):
+        def _transform(v, op):
+            if self.precision != 'single': v = v.float()
+
+            v2np = v.data.cpu().numpy()
+            if op == 'v':
+                tfnp = v2np[:, :, :, ::-1].copy()
+            elif op == 'h':
+                tfnp = v2np[:, :, ::-1, :].copy()
+            elif op == 't':
+                tfnp = v2np.transpose((0, 1, 3, 2)).copy()
+
+            ret = torch.Tensor(tfnp).to(self.device)
+            if self.precision == 'half': ret = ret.half()
+
+            return ret
+
+        # build all 8 flip/transpose augmentations of the input
+        lr_list = [x]
+        for tf in 'v', 'h', 't':
+            lr_list.extend([_transform(t, tf) for t in lr_list])
+
+        sr_list = [forward_function(aug) for aug in lr_list]
+        # undo each augmentation so all outputs share the original orientation
+        for i in range(len(sr_list)):
+            if i > 3:
+                sr_list[i] = _transform(sr_list[i], 't')
+            if i % 4 > 1:
+                sr_list[i] = _transform(sr_list[i], 'h')
+            if (i % 4) % 2 == 1:
+                sr_list[i] = _transform(sr_list[i], 'v')
+
+        output_cat = torch.cat(sr_list, dim=0)
+        output = output_cat.mean(dim=0, keepdim=True)
+
+        return output
+
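The index arithmetic in `forward_x8` is easy to mis-read, so here is a self-contained check of its geometry, independent of any model: the doubling loop yields all 8 flip/transpose variants, and with an identity "model" the inverse pass returns every variant to the original orientation, so the ensemble mean reproduces the input exactly.

```python
import torch

def transform(v, op):
    # Same ops as _transform above, minus the numpy/precision plumbing.
    if op == 'v': return torch.flip(v, dims=[3])   # flip width
    if op == 'h': return torch.flip(v, dims=[2])   # flip height
    if op == 't': return v.transpose(2, 3)         # swap H and W

x = torch.randn(1, 3, 5, 7)
lr_list = [x]
for tf in 'v', 'h', 't':
    lr_list.extend(transform(t, tf) for t in lr_list)
print(len(lr_list))                  # 8: the list doubles per op (1 -> 2 -> 4 -> 8)

sr_list = [aug.clone() for aug in lr_list]          # identity "model"
for i in range(len(sr_list)):
    if i > 3:            sr_list[i] = transform(sr_list[i], 't')
    if i % 4 > 1:        sr_list[i] = transform(sr_list[i], 'h')
    if (i % 4) % 2 == 1: sr_list[i] = transform(sr_list[i], 'v')

out = torch.cat(sr_list, dim=0).mean(dim=0, keepdim=True)
print(torch.allclose(out, x))        # True: every augmentation is undone
```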
CAR/code/model/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (5.47 kB).
 
CAR/code/model/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (5.42 kB).
 
CAR/code/model/__pycache__/attention.cpython-36.pyc ADDED
Binary file (12.8 kB).