Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- CAR/README.md +103 -0
- CAR/code/LICENSE +21 -0
- CAR/code/__init__.py +0 -0
- CAR/code/__pycache__/option.cpython-36.pyc +0 -0
- CAR/code/__pycache__/option.cpython-37.pyc +0 -0
- CAR/code/__pycache__/template.cpython-36.pyc +0 -0
- CAR/code/__pycache__/template.cpython-37.pyc +0 -0
- CAR/code/__pycache__/trainer.cpython-36.pyc +0 -0
- CAR/code/__pycache__/trainer.cpython-37.pyc +0 -0
- CAR/code/__pycache__/utility.cpython-36.pyc +0 -0
- CAR/code/__pycache__/utility.cpython-37.pyc +0 -0
- CAR/code/data/__init__.py +52 -0
- CAR/code/data/__pycache__/__init__.cpython-36.pyc +0 -0
- CAR/code/data/__pycache__/__init__.cpython-37.pyc +0 -0
- CAR/code/data/__pycache__/benchmark.cpython-36.pyc +0 -0
- CAR/code/data/__pycache__/common.cpython-36.pyc +0 -0
- CAR/code/data/__pycache__/common.cpython-37.pyc +0 -0
- CAR/code/data/__pycache__/div2k.cpython-36.pyc +0 -0
- CAR/code/data/__pycache__/div2k.cpython-37.pyc +0 -0
- CAR/code/data/__pycache__/srdata.cpython-36.pyc +0 -0
- CAR/code/data/__pycache__/srdata.cpython-37.pyc +0 -0
- CAR/code/data/benchmark.py +25 -0
- CAR/code/data/common.py +79 -0
- CAR/code/data/demo.py +39 -0
- CAR/code/data/div2k.py +32 -0
- CAR/code/data/div2kjpeg.py +20 -0
- CAR/code/data/sr291.py +6 -0
- CAR/code/data/srdata.py +166 -0
- CAR/code/data/video.py +44 -0
- CAR/code/dataloader.py +158 -0
- CAR/code/demo.sb +6 -0
- CAR/code/demo.sh +56 -0
- CAR/code/lambda_networks/__init__.py +4 -0
- CAR/code/lambda_networks/__pycache__/__init__.cpython-37.pyc +0 -0
- CAR/code/lambda_networks/__pycache__/lambda_networks.cpython-37.pyc +0 -0
- CAR/code/lambda_networks/__pycache__/rlambda_networks.cpython-37.pyc +0 -0
- CAR/code/lambda_networks/lambda_networks.py +161 -0
- CAR/code/lambda_networks/rlambda_networks.py +93 -0
- CAR/code/loss/__init__.py +173 -0
- CAR/code/loss/__pycache__/__init__.cpython-36.pyc +0 -0
- CAR/code/loss/__pycache__/__init__.cpython-37.pyc +0 -0
- CAR/code/loss/adversarial.py +112 -0
- CAR/code/loss/discriminator.py +55 -0
- CAR/code/loss/vgg.py +36 -0
- CAR/code/main.py +35 -0
- CAR/code/model/LICENSE +21 -0
- CAR/code/model/__init__.py +190 -0
- CAR/code/model/__pycache__/__init__.cpython-36.pyc +0 -0
- CAR/code/model/__pycache__/__init__.cpython-37.pyc +0 -0
- CAR/code/model/__pycache__/attention.cpython-36.pyc +0 -0
CAR/README.md
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Pyramid Attention for Image Restoration
|
2 |
+
This repository is for PANet and PA-EDSR introduced in the following paper
|
3 |
+
|
4 |
+
[Yiqun Mei](http://yiqunm2.web.illinois.edu/), [Yuchen Fan](https://scholar.google.com/citations?user=BlfdYL0AAAAJ&hl=en), [Yulun Zhang](http://yulunzhang.com/), [Jiahui Yu](https://jiahuiyu.com/), [Yuqian Zhou](https://yzhouas.github.io/), [Ding Liu](https://scholar.google.com/citations?user=PGtHUI0AAAAJ&hl=en), [Yun Fu](http://www1.ece.neu.edu/~yunfu/), [Thomas S. Huang](http://ifp-uiuc.github.io/) and [Honghui Shi](https://www.humphreyshi.com/) "Pyramid Attention for Image Restoration", [[Arxiv]](https://arxiv.org/abs/2004.13824)
|
5 |
+
|
6 |
+
The code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) & [RNAN](https://github.com/yulunzhang/RNAN) and tested on Ubuntu 18.04 environment (Python3.6, PyTorch_1.1) with Titan X/1080Ti/V100 GPUs.
|
7 |
+
|
8 |
+
## Contents
|
9 |
+
1. [Train](#train)
|
10 |
+
2. [Test](#test)
|
11 |
+
3. [Results](#results)
|
12 |
+
4. [Citation](#citation)
|
13 |
+
5. [Acknowledgements](#acknowledgements)
|
14 |
+
|
15 |
+
## Train
|
16 |
+
### Prepare training data
|
17 |
+
|
18 |
+
1. Download DIV2K training data (800 training + 100 validtion images) from [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar).
|
19 |
+
|
20 |
+
2. Specify '--dir_data' based on the HR and LR images path.
|
21 |
+
|
22 |
+
3. Organize training data like:
|
23 |
+
```bash
|
24 |
+
DIV2K/
|
25 |
+
├── DIV2K_train_HR
|
26 |
+
├── DIV2K_train_LR_bicubic
|
27 |
+
│ └── X10
|
28 |
+
│ └── X20
|
29 |
+
│ └── X30
|
30 |
+
│ └── X40
|
31 |
+
├── DIV2K_valid_HR
|
32 |
+
└── DIV2K_valid_LR_bicubic
|
33 |
+
└── X10
|
34 |
+
└── X20
|
35 |
+
└── X30
|
36 |
+
└── X40
|
37 |
+
```
|
38 |
+
For more informaiton, please refer to [EDSR(PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch).
|
39 |
+
|
40 |
+
### Begin to train
|
41 |
+
|
42 |
+
1. (optional) All the pretrained models and visual results can be downloaded from [Google Drive](https://drive.google.com/open?id=1q9iUzqYX0fVRzDu4J6fvSPRosgOZoJJE).
|
43 |
+
|
44 |
+
2. Cd to 'PANet-PyTorch/[Task]/code', run the following scripts to train models.
|
45 |
+
|
46 |
+
**You can use scripts in file 'demo.sb' to train and test models for our paper.**
|
47 |
+
|
48 |
+
```bash
|
49 |
+
# Example Usage: Q=10
|
50 |
+
python main.py --n_GPUs 2 --batch_size 16 --lr 1e-4 --decay 200-400-600-800 --save_models --n_resblocks 80 --model PANET --scale 10 --patch_size 48 --save PANET_Q10 --n_feats 64 --data_train DIV2K --chop
|
51 |
+
```
|
52 |
+
## Test
|
53 |
+
### Quick start
|
54 |
+
|
55 |
+
1. Cd to 'PANet-PyTorch/[Task]/code', run the following scripts.
|
56 |
+
|
57 |
+
**You can use scripts in file 'demo.sb' to produce results for our paper.**
|
58 |
+
|
59 |
+
```bash
|
60 |
+
# No self-ensemble, use different testsets (Classic5, LIVE1) to reproduce the results in the paper.
|
61 |
+
# Example Usage: Q=40
|
62 |
+
python main.py --model PANET --save_results --n_GPUs 1 --chop --data_test classic5+LIVE1 --scale 40 --n_resblocks 80 --n_feats 64 --pre_train ../Q40.pt --test_only
|
63 |
+
|
64 |
+
```
|
65 |
+
|
66 |
+
### The whole test pipeline
|
67 |
+
1. Prepare test data. Organize training data like:
|
68 |
+
```bash
|
69 |
+
benchmark/
|
70 |
+
├── testset1
|
71 |
+
│ └── HR
|
72 |
+
│ └── LR_bicubic
|
73 |
+
│ └── X10
|
74 |
+
│ └── ..
|
75 |
+
├── testset2
|
76 |
+
```
|
77 |
+
|
78 |
+
2. Conduct image CAR.
|
79 |
+
|
80 |
+
See **Quick start**
|
81 |
+
3. Evaluate the results.
|
82 |
+
|
83 |
+
Run 'Evaluate_PSNR_SSIM.m' to obtain PSNR/SSIM values for paper.
|
84 |
+
|
85 |
+
## Citation
|
86 |
+
If you find the code helpful in your resarch or work, please cite the following papers.
|
87 |
+
```
|
88 |
+
@article{mei2020pyramid,
|
89 |
+
title={Pyramid Attention Networks for Image Restoration},
|
90 |
+
author={Mei, Yiqun and Fan, Yuchen and Zhang, Yulun and Yu, Jiahui and Zhou, Yuqian and Liu, Ding and Fu, Yun and Huang, Thomas S and Shi, Honghui},
|
91 |
+
journal={arXiv preprint arXiv:2004.13824},
|
92 |
+
year={2020}
|
93 |
+
}
|
94 |
+
@InProceedings{Lim_2017_CVPR_Workshops,
|
95 |
+
author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
|
96 |
+
title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
|
97 |
+
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
|
98 |
+
month = {July},
|
99 |
+
year = {2017}
|
100 |
+
}
|
101 |
+
```
|
102 |
+
## Acknowledgements
|
103 |
+
This code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch), [RNAN](https://github.com/yulunzhang/RNAN) and [generative-inpainting-pytorch](https://github.com/daa233/generative-inpainting-pytorch). We thank the authors for sharing their codes.
|
CAR/code/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2018 Sanghyun Son
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
CAR/code/__init__.py
ADDED
File without changes
|
CAR/code/__pycache__/option.cpython-36.pyc
ADDED
Binary file (4.63 kB). View file
|
|
CAR/code/__pycache__/option.cpython-37.pyc
ADDED
Binary file (4.85 kB). View file
|
|
CAR/code/__pycache__/template.cpython-36.pyc
ADDED
Binary file (1.01 kB). View file
|
|
CAR/code/__pycache__/template.cpython-37.pyc
ADDED
Binary file (994 Bytes). View file
|
|
CAR/code/__pycache__/trainer.cpython-36.pyc
ADDED
Binary file (3.99 kB). View file
|
|
CAR/code/__pycache__/trainer.cpython-37.pyc
ADDED
Binary file (5.07 kB). View file
|
|
CAR/code/__pycache__/utility.cpython-36.pyc
ADDED
Binary file (9.04 kB). View file
|
|
CAR/code/__pycache__/utility.cpython-37.pyc
ADDED
Binary file (9.1 kB). View file
|
|
CAR/code/data/__init__.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from importlib import import_module
|
2 |
+
#from dataloader import MSDataLoader
|
3 |
+
from torch.utils.data import dataloader
|
4 |
+
from torch.utils.data import ConcatDataset
|
5 |
+
|
6 |
+
# This is a simple wrapper function for ConcatDataset
|
7 |
+
class MyConcatDataset(ConcatDataset):
|
8 |
+
def __init__(self, datasets):
|
9 |
+
super(MyConcatDataset, self).__init__(datasets)
|
10 |
+
self.train = datasets[0].train
|
11 |
+
|
12 |
+
def set_scale(self, idx_scale):
|
13 |
+
for d in self.datasets:
|
14 |
+
if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
|
15 |
+
|
16 |
+
class Data:
|
17 |
+
def __init__(self, args):
|
18 |
+
self.loader_train = None
|
19 |
+
if not args.test_only:
|
20 |
+
datasets = []
|
21 |
+
for d in args.data_train:
|
22 |
+
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
|
23 |
+
m = import_module('data.' + module_name.lower())
|
24 |
+
datasets.append(getattr(m, module_name)(args, name=d))
|
25 |
+
|
26 |
+
self.loader_train = dataloader.DataLoader(
|
27 |
+
MyConcatDataset(datasets),
|
28 |
+
batch_size=args.batch_size,
|
29 |
+
shuffle=True,
|
30 |
+
pin_memory=not args.cpu,
|
31 |
+
num_workers=args.n_threads,
|
32 |
+
)
|
33 |
+
|
34 |
+
self.loader_test = []
|
35 |
+
for d in args.data_test:
|
36 |
+
if d in ['CBSD68','classic5','LIVE1','Kodak24','Set5', 'Set14', 'B100', 'Urban100']:
|
37 |
+
m = import_module('data.benchmark')
|
38 |
+
testset = getattr(m, 'Benchmark')(args, train=False, name=d)
|
39 |
+
else:
|
40 |
+
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
|
41 |
+
m = import_module('data.' + module_name.lower())
|
42 |
+
testset = getattr(m, module_name)(args, train=False, name=d)
|
43 |
+
|
44 |
+
self.loader_test.append(
|
45 |
+
dataloader.DataLoader(
|
46 |
+
testset,
|
47 |
+
batch_size=1,
|
48 |
+
shuffle=False,
|
49 |
+
pin_memory=not args.cpu,
|
50 |
+
num_workers=args.n_threads,
|
51 |
+
)
|
52 |
+
)
|
CAR/code/data/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (1.84 kB). View file
|
|
CAR/code/data/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (1.79 kB). View file
|
|
CAR/code/data/__pycache__/benchmark.cpython-36.pyc
ADDED
Binary file (1.08 kB). View file
|
|
CAR/code/data/__pycache__/common.cpython-36.pyc
ADDED
Binary file (2.77 kB). View file
|
|
CAR/code/data/__pycache__/common.cpython-37.pyc
ADDED
Binary file (2.71 kB). View file
|
|
CAR/code/data/__pycache__/div2k.cpython-36.pyc
ADDED
Binary file (1.75 kB). View file
|
|
CAR/code/data/__pycache__/div2k.cpython-37.pyc
ADDED
Binary file (1.75 kB). View file
|
|
CAR/code/data/__pycache__/srdata.cpython-36.pyc
ADDED
Binary file (4.83 kB). View file
|
|
CAR/code/data/__pycache__/srdata.cpython-37.pyc
ADDED
Binary file (4.91 kB). View file
|
|
CAR/code/data/benchmark.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from data import common
|
4 |
+
from data import srdata
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.utils.data as data
|
10 |
+
|
11 |
+
class Benchmark(srdata.SRData):
|
12 |
+
def __init__(self, args, name='', train=True, benchmark=True):
|
13 |
+
super(Benchmark, self).__init__(
|
14 |
+
args, name=name, train=train, benchmark=True
|
15 |
+
)
|
16 |
+
|
17 |
+
def _set_filesystem(self, dir_data):
|
18 |
+
self.apath = os.path.join(dir_data, 'benchmark', self.name)
|
19 |
+
self.dir_hr = os.path.join(self.apath, 'HR')
|
20 |
+
if self.input_large:
|
21 |
+
self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
|
22 |
+
else:
|
23 |
+
self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
|
24 |
+
self.ext = ('','.jpg')
|
25 |
+
|
CAR/code/data/common.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import skimage.color as sc
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
def get_patch(*args, patch_size=96, scale=1, multi=False, input_large=False):
|
9 |
+
ih, iw = args[0].shape[:2]
|
10 |
+
# print('heelo')
|
11 |
+
# print(args[0].shape)
|
12 |
+
|
13 |
+
if not input_large:
|
14 |
+
p = 1 if multi else 1
|
15 |
+
tp = p * patch_size
|
16 |
+
ip = tp // 1
|
17 |
+
else:
|
18 |
+
tp = patch_size
|
19 |
+
ip = patch_size
|
20 |
+
|
21 |
+
ix = random.randrange(0, iw - ip + 1)
|
22 |
+
iy = random.randrange(0, ih - ip + 1)
|
23 |
+
|
24 |
+
if not input_large:
|
25 |
+
tx, ty = 1 * ix, 1 * iy
|
26 |
+
else:
|
27 |
+
tx, ty = ix, iy
|
28 |
+
|
29 |
+
ret = [
|
30 |
+
args[0][iy:iy + ip, ix:ix + ip],
|
31 |
+
*[a[ty:ty + tp, tx:tx + tp] for a in args[1:]]
|
32 |
+
]
|
33 |
+
|
34 |
+
return ret
|
35 |
+
|
36 |
+
def set_channel(*args, n_channels=3):
|
37 |
+
def _set_channel(img):
|
38 |
+
if img.ndim == 2:
|
39 |
+
img = np.expand_dims(img, axis=2)
|
40 |
+
|
41 |
+
c = img.shape[2]
|
42 |
+
if n_channels == 1 and c == 3:
|
43 |
+
img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
|
44 |
+
elif n_channels == 3 and c == 1:
|
45 |
+
img = np.concatenate([img] * n_channels, 2)
|
46 |
+
|
47 |
+
return img
|
48 |
+
|
49 |
+
return [_set_channel(a) for a in args]
|
50 |
+
|
51 |
+
def np2Tensor(*args, rgb_range=255):
|
52 |
+
def _np2Tensor(img):
|
53 |
+
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
|
54 |
+
tensor = torch.from_numpy(np_transpose).float()
|
55 |
+
tensor.mul_(rgb_range / 255)
|
56 |
+
|
57 |
+
return tensor
|
58 |
+
|
59 |
+
return [_np2Tensor(a) for a in args]
|
60 |
+
|
61 |
+
def augment(*args, hflip=True, rot=True):
|
62 |
+
hflip = hflip and random.random() < 0.5
|
63 |
+
vflip = rot and random.random() < 0.5
|
64 |
+
rot90 = rot and random.random() < 0.5
|
65 |
+
# hflip=True
|
66 |
+
# vflip=True
|
67 |
+
# rot90=True
|
68 |
+
|
69 |
+
def _augment(img):
|
70 |
+
if hflip: img = img[:, ::-1]
|
71 |
+
if vflip: img = img[::-1, :]
|
72 |
+
if rot90:
|
73 |
+
# print(img.shape)
|
74 |
+
img = img.transpose(1, 0)
|
75 |
+
|
76 |
+
return img
|
77 |
+
|
78 |
+
return [_augment(a) for a in args]
|
79 |
+
|
CAR/code/data/demo.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from data import common
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import imageio
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.utils.data as data
|
10 |
+
|
11 |
+
class Demo(data.Dataset):
|
12 |
+
def __init__(self, args, name='Demo', train=False, benchmark=False):
|
13 |
+
self.args = args
|
14 |
+
self.name = name
|
15 |
+
self.scale = args.scale
|
16 |
+
self.idx_scale = 0
|
17 |
+
self.train = False
|
18 |
+
self.benchmark = benchmark
|
19 |
+
|
20 |
+
self.filelist = []
|
21 |
+
for f in os.listdir(args.dir_demo):
|
22 |
+
if f.find('.png') >= 0 or f.find('.jp') >= 0:
|
23 |
+
self.filelist.append(os.path.join(args.dir_demo, f))
|
24 |
+
self.filelist.sort()
|
25 |
+
|
26 |
+
def __getitem__(self, idx):
|
27 |
+
filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
|
28 |
+
lr = imageio.imread(self.filelist[idx])
|
29 |
+
lr, = common.set_channel(lr, n_channels=self.args.n_colors)
|
30 |
+
lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
|
31 |
+
|
32 |
+
return lr_t, -1, filename
|
33 |
+
|
34 |
+
def __len__(self):
|
35 |
+
return len(self.filelist)
|
36 |
+
|
37 |
+
def set_scale(self, idx_scale):
|
38 |
+
self.idx_scale = idx_scale
|
39 |
+
|
CAR/code/data/div2k.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from data import srdata
|
3 |
+
|
4 |
+
class DIV2K(srdata.SRData):
|
5 |
+
def __init__(self, args, name='DIV2K', train=True, benchmark=False):
|
6 |
+
data_range = [r.split('-') for r in args.data_range.split('/')]
|
7 |
+
if train:
|
8 |
+
data_range = data_range[0]
|
9 |
+
else:
|
10 |
+
if args.test_only and len(data_range) == 1:
|
11 |
+
data_range = data_range[0]
|
12 |
+
else:
|
13 |
+
data_range = data_range[1]
|
14 |
+
|
15 |
+
self.begin, self.end = list(map(lambda x: int(x), data_range))
|
16 |
+
super(DIV2K, self).__init__(
|
17 |
+
args, name=name, train=train, benchmark=benchmark
|
18 |
+
)
|
19 |
+
|
20 |
+
def _scan(self):
|
21 |
+
names_hr, names_lr = super(DIV2K, self)._scan()
|
22 |
+
names_hr = names_hr[self.begin - 1:self.end]
|
23 |
+
names_lr = [n[self.begin - 1:self.end] for n in names_lr]
|
24 |
+
|
25 |
+
return names_hr, names_lr
|
26 |
+
|
27 |
+
def _set_filesystem(self, dir_data):
|
28 |
+
super(DIV2K, self)._set_filesystem(dir_data)
|
29 |
+
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
|
30 |
+
self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
|
31 |
+
if self.input_large: self.dir_lr += 'L'
|
32 |
+
|
CAR/code/data/div2kjpeg.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from data import srdata
|
3 |
+
from data import div2k
|
4 |
+
|
5 |
+
class DIV2KJPEG(div2k.DIV2K):
|
6 |
+
def __init__(self, args, name='', train=True, benchmark=False):
|
7 |
+
self.q_factor = int(name.replace('DIV2K-Q', ''))
|
8 |
+
super(DIV2KJPEG, self).__init__(
|
9 |
+
args, name=name, train=train, benchmark=benchmark
|
10 |
+
)
|
11 |
+
|
12 |
+
def _set_filesystem(self, dir_data):
|
13 |
+
self.apath = os.path.join(dir_data, 'DIV2K')
|
14 |
+
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
|
15 |
+
self.dir_lr = os.path.join(
|
16 |
+
self.apath, 'DIV2K_Q{}'.format(self.q_factor)
|
17 |
+
)
|
18 |
+
if self.input_large: self.dir_lr += 'L'
|
19 |
+
self.ext = ('.png', '.jpg')
|
20 |
+
|
CAR/code/data/sr291.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from data import srdata
|
2 |
+
|
3 |
+
class SR291(srdata.SRData):
|
4 |
+
def __init__(self, args, name='SR291', train=True, benchmark=False):
|
5 |
+
super(SR291, self).__init__(args, name=name)
|
6 |
+
|
CAR/code/data/srdata.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import random
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
from data import common
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import imageio
|
10 |
+
import torch
|
11 |
+
import torch.utils.data as data
|
12 |
+
|
13 |
+
class SRData(data.Dataset):
|
14 |
+
def __init__(self, args, name='', train=True, benchmark=False):
|
15 |
+
self.args = args
|
16 |
+
self.name = name
|
17 |
+
self.train = train
|
18 |
+
self.split = 'train' if train else 'test'
|
19 |
+
self.do_eval = True
|
20 |
+
self.benchmark = benchmark
|
21 |
+
self.input_large = (args.model == 'VDSR')
|
22 |
+
self.scale = args.scale
|
23 |
+
self.idx_scale = 0
|
24 |
+
|
25 |
+
self._set_filesystem(args.dir_data)
|
26 |
+
if args.ext.find('img') < 0:
|
27 |
+
path_bin = os.path.join(self.apath, 'bin')
|
28 |
+
os.makedirs(path_bin, exist_ok=True)
|
29 |
+
|
30 |
+
list_hr, list_lr = self._scan()
|
31 |
+
if args.ext.find('img') >= 0 or benchmark:
|
32 |
+
self.images_hr, self.images_lr = list_hr, list_lr
|
33 |
+
elif args.ext.find('sep') >= 0:
|
34 |
+
os.makedirs(
|
35 |
+
self.dir_hr.replace(self.apath, path_bin),
|
36 |
+
exist_ok=True
|
37 |
+
)
|
38 |
+
for s in self.scale:
|
39 |
+
os.makedirs(
|
40 |
+
os.path.join(
|
41 |
+
self.dir_lr.replace(self.apath, path_bin),
|
42 |
+
'X{}'.format(s)
|
43 |
+
),
|
44 |
+
exist_ok=True
|
45 |
+
)
|
46 |
+
|
47 |
+
self.images_hr, self.images_lr = [], [[] for _ in self.scale]
|
48 |
+
for h in list_hr:
|
49 |
+
b = h.replace(self.apath, path_bin)
|
50 |
+
b = b.replace(self.ext[0], '.pt')
|
51 |
+
self.images_hr.append(b)
|
52 |
+
self._check_and_load(args.ext, h, b, verbose=True)
|
53 |
+
for i, ll in enumerate(list_lr):
|
54 |
+
for l in ll:
|
55 |
+
b = l.replace(self.apath, path_bin)
|
56 |
+
b = b.replace(self.ext[1], '.pt')
|
57 |
+
self.images_lr[i].append(b)
|
58 |
+
self._check_and_load(args.ext, l, b, verbose=True)
|
59 |
+
if train:
|
60 |
+
n_patches = args.batch_size * args.test_every
|
61 |
+
n_images = len(args.data_train) * len(self.images_hr)
|
62 |
+
if n_images == 0:
|
63 |
+
self.repeat = 0
|
64 |
+
else:
|
65 |
+
self.repeat = max(n_patches // n_images, 1)
|
66 |
+
|
67 |
+
# Below functions as used to prepare images
|
68 |
+
def _scan(self):
|
69 |
+
names_hr = sorted(
|
70 |
+
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
|
71 |
+
)
|
72 |
+
names_lr = [[] for _ in self.scale]
|
73 |
+
for f in names_hr:
|
74 |
+
filename, _ = os.path.splitext(os.path.basename(f))
|
75 |
+
for si, s in enumerate(self.scale):
|
76 |
+
names_lr[si].append(os.path.join(
|
77 |
+
self.dir_lr, 'X{}/{}{}'.format(
|
78 |
+
s, filename, self.ext[1]
|
79 |
+
)
|
80 |
+
))
|
81 |
+
|
82 |
+
return names_hr, names_lr
|
83 |
+
|
84 |
+
def _set_filesystem(self, dir_data):
|
85 |
+
self.apath = os.path.join(dir_data, self.name)
|
86 |
+
self.dir_hr = os.path.join(self.apath, 'HR')
|
87 |
+
self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
|
88 |
+
if self.input_large: self.dir_lr += 'L'
|
89 |
+
self.ext = ('.png', '.jpg')
|
90 |
+
|
91 |
+
def _check_and_load(self, ext, img, f, verbose=True):
|
92 |
+
if not os.path.isfile(f) or ext.find('reset') >= 0:
|
93 |
+
if verbose:
|
94 |
+
print('Making a binary: {}'.format(f))
|
95 |
+
with open(f, 'wb') as _f:
|
96 |
+
pickle.dump(imageio.imread(img), _f)
|
97 |
+
|
98 |
+
def __getitem__(self, idx):
|
99 |
+
lr, hr, filename = self._load_file(idx)
|
100 |
+
pair = self.get_patch(lr, hr)
|
101 |
+
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
|
102 |
+
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
|
103 |
+
|
104 |
+
return pair_t[0], pair_t[1], filename
|
105 |
+
|
106 |
+
def __len__(self):
|
107 |
+
if self.train:
|
108 |
+
return len(self.images_hr) * self.repeat
|
109 |
+
else:
|
110 |
+
return len(self.images_hr)
|
111 |
+
|
112 |
+
def _get_index(self, idx):
|
113 |
+
if self.train:
|
114 |
+
return idx % len(self.images_hr)
|
115 |
+
else:
|
116 |
+
return idx
|
117 |
+
|
118 |
+
def _load_file(self, idx):
|
119 |
+
idx = self._get_index(idx)
|
120 |
+
f_hr = self.images_hr[idx]
|
121 |
+
f_lr = self.images_lr[self.idx_scale][idx]
|
122 |
+
|
123 |
+
filename, _ = os.path.splitext(os.path.basename(f_hr))
|
124 |
+
if self.args.ext == 'img' or self.benchmark:
|
125 |
+
hr = imageio.imread(f_hr)
|
126 |
+
lr = imageio.imread(f_lr)
|
127 |
+
elif self.args.ext.find('sep') >= 0:
|
128 |
+
with open(f_hr, 'rb') as _f:
|
129 |
+
hr = pickle.load(_f)
|
130 |
+
with open(f_lr, 'rb') as _f:
|
131 |
+
lr = pickle.load(_f)
|
132 |
+
|
133 |
+
return lr, hr, filename
|
134 |
+
|
135 |
+
def get_patch(self, lr, hr):
|
136 |
+
scale = self.scale[self.idx_scale]
|
137 |
+
if self.train:
|
138 |
+
lr, hr = common.get_patch(
|
139 |
+
lr, hr,
|
140 |
+
patch_size=self.args.patch_size,
|
141 |
+
scale=scale,
|
142 |
+
multi=(len(self.scale) > 1),
|
143 |
+
input_large=self.input_large
|
144 |
+
)
|
145 |
+
if not self.args.no_augment: lr, hr = common.augment(lr, hr)
|
146 |
+
else:
|
147 |
+
if self.args.model=="RECURSIONNET":
|
148 |
+
lr, hr = common.get_patch(
|
149 |
+
lr, hr,
|
150 |
+
patch_size=self.args.patch_size,
|
151 |
+
scale=scale,
|
152 |
+
multi=(len(self.scale) > 1),
|
153 |
+
input_large=self.input_large
|
154 |
+
)
|
155 |
+
else:
|
156 |
+
ih, iw = lr.shape[:2]
|
157 |
+
hr = hr[0:ih * scale, 0:iw * scale]
|
158 |
+
|
159 |
+
return lr, hr
|
160 |
+
|
161 |
+
def set_scale(self, idx_scale):
|
162 |
+
if not self.input_large:
|
163 |
+
self.idx_scale = idx_scale
|
164 |
+
else:
|
165 |
+
self.idx_scale = random.randint(0, len(self.scale) - 1)
|
166 |
+
|
CAR/code/data/video.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from data import common
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import imageio
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.utils.data as data
|
11 |
+
|
12 |
+
class Video(data.Dataset):
|
13 |
+
def __init__(self, args, name='Video', train=False, benchmark=False):
|
14 |
+
self.args = args
|
15 |
+
self.name = name
|
16 |
+
self.scale = args.scale
|
17 |
+
self.idx_scale = 0
|
18 |
+
self.train = False
|
19 |
+
self.do_eval = False
|
20 |
+
self.benchmark = benchmark
|
21 |
+
|
22 |
+
self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
|
23 |
+
self.vidcap = cv2.VideoCapture(args.dir_demo)
|
24 |
+
self.n_frames = 0
|
25 |
+
self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
26 |
+
|
27 |
+
def __getitem__(self, idx):
|
28 |
+
success, lr = self.vidcap.read()
|
29 |
+
if success:
|
30 |
+
self.n_frames += 1
|
31 |
+
lr, = common.set_channel(lr, n_channels=self.args.n_colors)
|
32 |
+
lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
|
33 |
+
|
34 |
+
return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
|
35 |
+
else:
|
36 |
+
vidcap.release()
|
37 |
+
return None
|
38 |
+
|
39 |
+
def __len__(self):
|
40 |
+
return self.total_frames
|
41 |
+
|
42 |
+
def set_scale(self, idx_scale):
|
43 |
+
self.idx_scale = idx_scale
|
44 |
+
|
CAR/code/dataloader.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.multiprocessing as multiprocessing
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from torch.utils.data import SequentialSampler
|
8 |
+
from torch.utils.data import RandomSampler
|
9 |
+
from torch.utils.data import BatchSampler
|
10 |
+
from torch.utils.data import _utils
|
11 |
+
from torch.utils.data.dataloader import _DataLoaderIter
|
12 |
+
|
13 |
+
from torch.utils.data._utils import collate
|
14 |
+
from torch.utils.data._utils import signal_handling
|
15 |
+
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
|
16 |
+
from torch.utils.data._utils import ExceptionWrapper
|
17 |
+
from torch.utils.data._utils import IS_WINDOWS
|
18 |
+
from torch.utils.data._utils.worker import ManagerWatchdog
|
19 |
+
|
20 |
+
from torch._six import queue
|
21 |
+
|
22 |
+
def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
|
23 |
+
try:
|
24 |
+
collate._use_shared_memory = True
|
25 |
+
signal_handling._set_worker_signal_handlers()
|
26 |
+
|
27 |
+
torch.set_num_threads(1)
|
28 |
+
random.seed(seed)
|
29 |
+
torch.manual_seed(seed)
|
30 |
+
|
31 |
+
data_queue.cancel_join_thread()
|
32 |
+
|
33 |
+
if init_fn is not None:
|
34 |
+
init_fn(worker_id)
|
35 |
+
|
36 |
+
watchdog = ManagerWatchdog()
|
37 |
+
|
38 |
+
while watchdog.is_alive():
|
39 |
+
try:
|
40 |
+
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
|
41 |
+
except queue.Empty:
|
42 |
+
continue
|
43 |
+
|
44 |
+
if r is None:
|
45 |
+
assert done_event.is_set()
|
46 |
+
return
|
47 |
+
elif done_event.is_set():
|
48 |
+
continue
|
49 |
+
|
50 |
+
idx, batch_indices = r
|
51 |
+
try:
|
52 |
+
idx_scale = 0
|
53 |
+
if len(scale) > 1 and dataset.train:
|
54 |
+
idx_scale = random.randrange(0, len(scale))
|
55 |
+
dataset.set_scale(idx_scale)
|
56 |
+
|
57 |
+
samples = collate_fn([dataset[i] for i in batch_indices])
|
58 |
+
samples.append(idx_scale)
|
59 |
+
except Exception:
|
60 |
+
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
|
61 |
+
else:
|
62 |
+
data_queue.put((idx, samples))
|
63 |
+
del samples
|
64 |
+
|
65 |
+
except KeyboardInterrupt:
|
66 |
+
pass
|
67 |
+
|
68 |
+
class _MSDataLoaderIter(_DataLoaderIter):
|
69 |
+
|
70 |
+
def __init__(self, loader):
|
71 |
+
self.dataset = loader.dataset
|
72 |
+
self.scale = loader.scale
|
73 |
+
self.collate_fn = loader.collate_fn
|
74 |
+
self.batch_sampler = loader.batch_sampler
|
75 |
+
self.num_workers = loader.num_workers
|
76 |
+
self.pin_memory = loader.pin_memory and torch.cuda.is_available()
|
77 |
+
self.timeout = loader.timeout
|
78 |
+
|
79 |
+
self.sample_iter = iter(self.batch_sampler)
|
80 |
+
|
81 |
+
base_seed = torch.LongTensor(1).random_().item()
|
82 |
+
|
83 |
+
if self.num_workers > 0:
|
84 |
+
self.worker_init_fn = loader.worker_init_fn
|
85 |
+
self.worker_queue_idx = 0
|
86 |
+
self.worker_result_queue = multiprocessing.Queue()
|
87 |
+
self.batches_outstanding = 0
|
88 |
+
self.worker_pids_set = False
|
89 |
+
self.shutdown = False
|
90 |
+
self.send_idx = 0
|
91 |
+
self.rcvd_idx = 0
|
92 |
+
self.reorder_dict = {}
|
93 |
+
self.done_event = multiprocessing.Event()
|
94 |
+
|
95 |
+
base_seed = torch.LongTensor(1).random_()[0]
|
96 |
+
|
97 |
+
self.index_queues = []
|
98 |
+
self.workers = []
|
99 |
+
for i in range(self.num_workers):
|
100 |
+
index_queue = multiprocessing.Queue()
|
101 |
+
index_queue.cancel_join_thread()
|
102 |
+
w = multiprocessing.Process(
|
103 |
+
target=_ms_loop,
|
104 |
+
args=(
|
105 |
+
self.dataset,
|
106 |
+
index_queue,
|
107 |
+
self.worker_result_queue,
|
108 |
+
self.done_event,
|
109 |
+
self.collate_fn,
|
110 |
+
self.scale,
|
111 |
+
base_seed + i,
|
112 |
+
self.worker_init_fn,
|
113 |
+
i
|
114 |
+
)
|
115 |
+
)
|
116 |
+
w.daemon = True
|
117 |
+
w.start()
|
118 |
+
self.index_queues.append(index_queue)
|
119 |
+
self.workers.append(w)
|
120 |
+
|
121 |
+
if self.pin_memory:
|
122 |
+
self.data_queue = queue.Queue()
|
123 |
+
pin_memory_thread = threading.Thread(
|
124 |
+
target=_utils.pin_memory._pin_memory_loop,
|
125 |
+
args=(
|
126 |
+
self.worker_result_queue,
|
127 |
+
self.data_queue,
|
128 |
+
torch.cuda.current_device(),
|
129 |
+
self.done_event
|
130 |
+
)
|
131 |
+
)
|
132 |
+
pin_memory_thread.daemon = True
|
133 |
+
pin_memory_thread.start()
|
134 |
+
self.pin_memory_thread = pin_memory_thread
|
135 |
+
else:
|
136 |
+
self.data_queue = self.worker_result_queue
|
137 |
+
|
138 |
+
_utils.signal_handling._set_worker_pids(
|
139 |
+
id(self), tuple(w.pid for w in self.workers)
|
140 |
+
)
|
141 |
+
_utils.signal_handling._set_SIGCHLD_handler()
|
142 |
+
self.worker_pids_set = True
|
143 |
+
|
144 |
+
for _ in range(2 * self.num_workers):
|
145 |
+
self._put_indices()
|
146 |
+
|
147 |
+
|
148 |
+
class MSDataLoader(DataLoader):
|
149 |
+
|
150 |
+
def __init__(self, cfg, *args, **kwargs):
|
151 |
+
super(MSDataLoader, self).__init__(
|
152 |
+
*args, **kwargs, num_workers=cfg.n_threads
|
153 |
+
)
|
154 |
+
self.scale = cfg.scale
|
155 |
+
|
156 |
+
def __iter__(self):
|
157 |
+
return _MSDataLoaderIter(self)
|
158 |
+
|
CAR/code/demo.sb
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
# PANet Train
|
3 |
+
#python main.py --n_GPUs 2 --batch_size 16 --lr 1e-4 --decay 200-400-600-800 ---save_models --model PANET --scale 10 --patch_size 48 --save PANET_Q10 --n_feats 64 --data_train DIV2K --chop
|
4 |
+
|
5 |
+
# Test
|
6 |
+
python main.py --model PANET --save_results --n_GPUs 1 --chop --data_test classic5+LIVE1 --scale 40 --n_resblocks 80 --n_feats 64 --pre_train ../Q40.pt --test_only
|
CAR/code/demo.sh
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# EDSR baseline model (x2) + JPEG augmentation
|
2 |
+
python main.py --n_GPUs 2 --batch_size 16 --reset --save_models --model EDSRL1T --scale 1 --patch_size 48 --save EDSR_R16F64P48_CSAL1TK20 --n_feats 64 --data_train DIV2K --chop
|
3 |
+
#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75
|
4 |
+
|
5 |
+
# EDSR baseline model (x3) - from EDSR baseline model (x2)
|
6 |
+
#python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]
|
7 |
+
|
8 |
+
# EDSR baseline model (x4) - from EDSR baseline model (x2)
|
9 |
+
#python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train pre-trained EDSR_baseline_x2 model dir
|
10 |
+
|
11 |
+
# EDSR in the paper (x2)
|
12 |
+
#python main.py --model EDSR --scale 2 --save edsr_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset
|
13 |
+
|
14 |
+
# EDSR in the paper (x3) - from EDSR (x2)
|
15 |
+
#python main.py --model EDSR --scale 3 --save edsr_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR model dir]
|
16 |
+
|
17 |
+
# EDSR in the paper (x4) - from EDSR (x2)
|
18 |
+
#python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]
|
19 |
+
|
20 |
+
# MDSR baseline model
|
21 |
+
#python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models
|
22 |
+
|
23 |
+
# MDSR in the paper
|
24 |
+
#python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models
|
25 |
+
|
26 |
+
# Standard benchmarks (Ex. EDSR_baseline_x4)
|
27 |
+
#python main.py --model EDSRL1 --save_results --n_GPUs 2 --chop --data_test Kodak24+CBSD68+Urban100 --scale 1 --pre_train ../experiment/EDSR_R16F64P48_CSALK20/model/model_best.pt --test_only
|
28 |
+
|
29 |
+
#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble
|
30 |
+
|
31 |
+
# Test your own images
|
32 |
+
#python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results
|
33 |
+
|
34 |
+
# Advanced - Test with JPEG images
|
35 |
+
#python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results
|
36 |
+
|
37 |
+
# Advanced - Training with adversarial loss
|
38 |
+
#python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download
|
39 |
+
|
40 |
+
# RDN BI model (x2)
|
41 |
+
#python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset
|
42 |
+
# RDN BI model (x3)
|
43 |
+
#python3.6 main.py --scale 3 --save RDN_D16C8G64_BIx3 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 96 --reset
|
44 |
+
# RDN BI model (x4)
|
45 |
+
#python3.6 main.py --scale 4 --save RDN_D16C8G64_BIx4 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 128 --reset
|
46 |
+
|
47 |
+
# RCAN_BIX2_G10R20P48, input=48x48, output=96x96
|
48 |
+
# pretrained model can be downloaded from https://www.dropbox.com/s/mjbcqkd4nwhr6nu/models_ECCV2018RCAN.zip?dl=0
|
49 |
+
#python main.py --template RCAN --save RCAN_BIX2_G10R20P48 --scale 2 --reset --save_results --patch_size 96
|
50 |
+
# RCAN_BIX3_G10R20P48, input=48x48, output=144x144
|
51 |
+
#python main.py --template RCAN --save RCAN_BIX3_G10R20P48 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt
|
52 |
+
# RCAN_BIX4_G10R20P48, input=48x48, output=192x192
|
53 |
+
#python main.py --template RCAN --save RCAN_BIX4_G10R20P48 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt
|
54 |
+
# RCAN_BIX8_G10R20P48, input=48x48, output=384x384
|
55 |
+
#python main.py --template RCAN --save RCAN_BIX8_G10R20P48 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt
|
56 |
+
|
CAR/code/lambda_networks/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from lambda_networks.lambda_networks import LambdaLayer
|
2 |
+
from lambda_networks.lambda_networks import Recursion
|
3 |
+
from lambda_networks.rlambda_networks import RLambdaLayer
|
4 |
+
λLayer = LambdaLayer
|
CAR/code/lambda_networks/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (338 Bytes). View file
|
|
CAR/code/lambda_networks/__pycache__/lambda_networks.cpython-37.pyc
ADDED
Binary file (5.75 kB). View file
|
|
CAR/code/lambda_networks/__pycache__/rlambda_networks.cpython-37.pyc
ADDED
Binary file (2.92 kB). View file
|
|
CAR/code/lambda_networks/lambda_networks.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, einsum
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
# my layer normalization
|
7 |
+
|
8 |
+
class LayerNorm(nn.Module):
|
9 |
+
def __init__(self, eps= 1e-5):
|
10 |
+
super(LayerNorm, self).__init__()
|
11 |
+
self.eps = eps
|
12 |
+
def forward(self, input):
|
13 |
+
shape=tuple(input.size()[1:])
|
14 |
+
return F.layer_norm(input, shape, eps=self.eps)
|
15 |
+
def extra_repr(self):
|
16 |
+
return f'eps={self.eps}'
|
17 |
+
|
18 |
+
# helpers functions
|
19 |
+
|
20 |
+
def exists(val):
|
21 |
+
return val is not None
|
22 |
+
|
23 |
+
def default(val, d):
|
24 |
+
return val if exists(val) else d
|
25 |
+
|
26 |
+
# lambda layer
|
27 |
+
|
28 |
+
class LambdaLayer(nn.Module):
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
dim,
|
32 |
+
*,
|
33 |
+
dim_k,
|
34 |
+
n = None,
|
35 |
+
r = None,
|
36 |
+
heads = 4,
|
37 |
+
dim_out = None,
|
38 |
+
dim_u = 1,
|
39 |
+
normalization="batch"):
|
40 |
+
super().__init__()
|
41 |
+
dim_out = default(dim_out, dim)
|
42 |
+
self.u = dim_u # intra-depth dimension
|
43 |
+
self.heads = heads
|
44 |
+
|
45 |
+
assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
|
46 |
+
dim_v = dim_out // heads
|
47 |
+
|
48 |
+
self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
|
49 |
+
self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
|
50 |
+
self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)
|
51 |
+
print(f"using {normalization} in lambda layer")
|
52 |
+
if normalization=="none":
|
53 |
+
self.norm_q = nn.Identity(dim_k * heads)
|
54 |
+
self.norm_v = nn.Identity(dim_v * dim_u)
|
55 |
+
elif normalization=="instance":
|
56 |
+
self.norm_q = nn.InstanceNorm2d(dim_k * heads)
|
57 |
+
self.norm_v = nn.InstanceNorm2d(dim_v * dim_u)
|
58 |
+
elif normalization=="layer":
|
59 |
+
self.norm_q = LayerNorm()
|
60 |
+
self.norm_v = LayerNorm()
|
61 |
+
else:
|
62 |
+
self.norm_q = nn.BatchNorm2d(dim_k * heads)
|
63 |
+
self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
|
64 |
+
print(f"using BN in lambda layer?")
|
65 |
+
|
66 |
+
self.local_contexts = exists(r)
|
67 |
+
if exists(r):
|
68 |
+
assert (r % 2) == 1, 'Receptive kernel size should be odd'
|
69 |
+
self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
|
70 |
+
else:
|
71 |
+
assert exists(n), 'You must specify the total sequence length (h x w)'
|
72 |
+
self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
|
73 |
+
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
b, c, hh, ww, u, h = *x.shape, self.u, self.heads
|
77 |
+
|
78 |
+
q = self.to_q(x)
|
79 |
+
k = self.to_k(x)
|
80 |
+
v = self.to_v(x)
|
81 |
+
|
82 |
+
q = self.norm_q(q)
|
83 |
+
v = self.norm_v(v)
|
84 |
+
|
85 |
+
q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
|
86 |
+
k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
|
87 |
+
v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)
|
88 |
+
|
89 |
+
k = k.softmax(dim=-1)
|
90 |
+
|
91 |
+
λc = einsum('b u k m, b u v m -> b k v', k, v)
|
92 |
+
Yc = einsum('b h k n, b k v -> b h v n', q, λc)
|
93 |
+
|
94 |
+
if self.local_contexts:
|
95 |
+
v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
|
96 |
+
λp = self.pos_conv(v)
|
97 |
+
Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
|
98 |
+
else:
|
99 |
+
λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
|
100 |
+
Yp = einsum('b h k n, b n k v -> b h v n', q, λp)
|
101 |
+
|
102 |
+
Y = Yc + Yp
|
103 |
+
out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
|
104 |
+
return out
|
105 |
+
|
106 |
+
|
107 |
+
# i'm not sure whether this will work or not
|
108 |
+
class Recursion(nn.Module):
|
109 |
+
def __init__(self, N: int, hidden_dim:int=64):
|
110 |
+
super(Recursion,self).__init__()
|
111 |
+
self.N = N
|
112 |
+
self.lambdaNxN_identity = LambdaLayer(dim=hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
|
113 |
+
# merge upstream information here
|
114 |
+
self.lambdaNxN_merge = LambdaLayer(dim=2*hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
|
115 |
+
self.downscale_conv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=N, stride=N)
|
116 |
+
self.upscale_conv = nn.Conv2d(hidden_dim, hidden_dim * N * N, kernel_size=3,padding=1)
|
117 |
+
self.pixel_shuffle = nn.PixelShuffle(N)
|
118 |
+
|
119 |
+
def forward(self, x: torch.Tensor):
|
120 |
+
N = self.N
|
121 |
+
|
122 |
+
def to_patch(blocks:torch.Tensor)->torch.Tensor:
|
123 |
+
shape = blocks.shape
|
124 |
+
blocks_patch = F.unfold(blocks, kernel_size=N, stride=N)
|
125 |
+
blocks_patch = blocks_patch.view(shape[0], shape[1], N, N, -1)
|
126 |
+
num_patch = blocks_patch.shape[-1]
|
127 |
+
blocks_patch = blocks_patch.permute(0, 4, 1, 2, 3).reshape(-1, shape[1], N, N).contiguous()
|
128 |
+
return blocks_patch, num_patch
|
129 |
+
|
130 |
+
def combine_patch(processed_patch,shape,num_patch):
|
131 |
+
processed_patch = processed_patch.reshape(shape[0], num_patch, shape[1], N, N)
|
132 |
+
processed_patch=processed_patch.permute(0, 2, 3, 4, 1).reshape(shape[0],shape[1] * N * N,num_patch).contiguous()
|
133 |
+
processed=F.fold(processed_patch,output_size=(shape[-2],shape[-1]),kernel_size=N,stride=N)
|
134 |
+
return processed
|
135 |
+
|
136 |
+
def process(blocks:torch.Tensor)->torch.Tensor:
|
137 |
+
shape = blocks.shape
|
138 |
+
if blocks.shape[-1] == N:
|
139 |
+
processed = self.lambdaNxN_identity(blocks)
|
140 |
+
return processed
|
141 |
+
# to NxN patchs
|
142 |
+
blocks_patch,num_patch=to_patch(blocks)
|
143 |
+
# pass through identity
|
144 |
+
processed_patch = self.lambdaNxN_identity(blocks_patch)
|
145 |
+
# back to HxW
|
146 |
+
processed=combine_patch(processed_patch,shape,num_patch)
|
147 |
+
# get feedback
|
148 |
+
feedback = process(self.downscale_conv(processed))
|
149 |
+
# upscale feedback
|
150 |
+
upscale_feedback = self.upscale_conv(feedback)
|
151 |
+
upscale_feedback=self.pixel_shuffle(upscale_feedback)
|
152 |
+
# combine results
|
153 |
+
combined = torch.cat([processed, upscale_feedback], dim=1)
|
154 |
+
combined_shape=combined.shape
|
155 |
+
combined_patch,num_patch=to_patch(combined)
|
156 |
+
combined_patch_reduced = self.lambdaNxN_merge(combined_patch)
|
157 |
+
ret_shape=(combined_shape[0],combined_shape[1]//2,combined_shape[2],combined_shape[3])
|
158 |
+
ret=combine_patch(combined_patch_reduced,ret_shape,num_patch)
|
159 |
+
return ret
|
160 |
+
|
161 |
+
return process(x)
|
CAR/code/lambda_networks/rlambda_networks.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, einsum
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
|
7 |
+
# helpers functions
|
8 |
+
|
9 |
+
def exists(val):
|
10 |
+
return val is not None
|
11 |
+
|
12 |
+
|
13 |
+
def default(val, d):
|
14 |
+
return val if exists(val) else d
|
15 |
+
|
16 |
+
|
17 |
+
# lambda layer
|
18 |
+
|
19 |
+
class RLambdaLayer(nn.Module):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
dim,
|
23 |
+
*,
|
24 |
+
dim_k,
|
25 |
+
n=None,
|
26 |
+
r=None,
|
27 |
+
heads=4,
|
28 |
+
dim_out=None,
|
29 |
+
dim_u=1,
|
30 |
+
recurrence=None
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
dim_out = default(dim_out, dim)
|
34 |
+
self.u = dim_u # intra-depth dimension
|
35 |
+
self.heads = heads
|
36 |
+
|
37 |
+
assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
|
38 |
+
dim_v = dim_out // heads
|
39 |
+
|
40 |
+
self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias=False)
|
41 |
+
self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias=False)
|
42 |
+
self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias=False)
|
43 |
+
|
44 |
+
self.norm_q = nn.BatchNorm2d(dim_k * heads)
|
45 |
+
self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
|
46 |
+
|
47 |
+
self.local_contexts = exists(r)
|
48 |
+
self.recurrence = recurrence
|
49 |
+
if exists(r):
|
50 |
+
assert (r % 2) == 1, 'Receptive kernel size should be odd'
|
51 |
+
self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2))
|
52 |
+
else:
|
53 |
+
assert exists(n), 'You must specify the total sequence length (h x w)'
|
54 |
+
self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
|
55 |
+
|
56 |
+
def apply_lambda(self, lambda_c, lambda_p, x):
|
57 |
+
b, c, hh, ww, u, h = *x.shape, self.u, self.heads
|
58 |
+
q = self.to_q(x)
|
59 |
+
q = self.norm_q(q)
|
60 |
+
q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h=h)
|
61 |
+
Yc = einsum('b h k n, b k v -> b h v n', q, lambda_c)
|
62 |
+
if self.local_contexts:
|
63 |
+
Yp = einsum('b h k n, b k v n -> b h v n', q, lambda_p.flatten(3))
|
64 |
+
else:
|
65 |
+
Yp = einsum('b h k n, b n k v -> b h v n', q, lambda_p)
|
66 |
+
Y = Yc + Yp
|
67 |
+
out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh=hh, ww=ww)
|
68 |
+
return out
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
b, c, hh, ww, u, h = *x.shape, self.u, self.heads
|
72 |
+
|
73 |
+
k = self.to_k(x)
|
74 |
+
v = self.to_v(x)
|
75 |
+
|
76 |
+
v = self.norm_v(v)
|
77 |
+
|
78 |
+
k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u=u)
|
79 |
+
v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u=u)
|
80 |
+
|
81 |
+
k = k.softmax(dim=-1)
|
82 |
+
|
83 |
+
λc = einsum('b u k m, b u v m -> b k v', k, v)
|
84 |
+
|
85 |
+
if self.local_contexts:
|
86 |
+
v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)
|
87 |
+
λp = self.pos_conv(v)
|
88 |
+
else:
|
89 |
+
λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
|
90 |
+
out = x
|
91 |
+
for i in range(self.recurrence):
|
92 |
+
out = self.apply_lambda(λc, λp, out)
|
93 |
+
return out
|
CAR/code/loss/__init__.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from importlib import import_module
|
3 |
+
|
4 |
+
import matplotlib
|
5 |
+
matplotlib.use('Agg')
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
def sequence_loss(sr, hr, loss_func, gamma=0.8, max_val=None):
|
15 |
+
""" Loss function defined over sequence of flow predictions """
|
16 |
+
|
17 |
+
n_recurrence = len(sr)
|
18 |
+
total_loss = 0.0
|
19 |
+
buffer=[0.0]*n_recurrence
|
20 |
+
# exlude invalid pixels and extremely large diplacements
|
21 |
+
for i in range(n_recurrence):
|
22 |
+
i_weight = gamma**(n_recurrence - i - 1)
|
23 |
+
i_loss = loss_func(sr[i],hr)
|
24 |
+
buffer[i]=i_loss.item()
|
25 |
+
# total_loss += i_weight * (valid[:, None] * i_loss).mean()
|
26 |
+
total_loss += i_weight * (i_loss)
|
27 |
+
return total_loss,buffer
|
28 |
+
|
29 |
+
class Loss(nn.modules.loss._Loss):
|
30 |
+
def __init__(self, args, ckp):
|
31 |
+
super(Loss, self).__init__()
|
32 |
+
print('Preparing loss function:')
|
33 |
+
self.buffer=[0.0]*args.recurrence
|
34 |
+
self.n_GPUs = args.n_GPUs
|
35 |
+
self.loss = []
|
36 |
+
self.loss_module = nn.ModuleList()
|
37 |
+
for loss in args.loss.split('+'):
|
38 |
+
weight, loss_type = loss.split('*')
|
39 |
+
if loss_type == 'MSE':
|
40 |
+
loss_function = nn.MSELoss()
|
41 |
+
elif loss_type == 'L1':
|
42 |
+
loss_function = nn.L1Loss()
|
43 |
+
elif loss_type.find('VGG') >= 0:
|
44 |
+
module = import_module('loss.vgg')
|
45 |
+
loss_function = getattr(module, 'VGG')(
|
46 |
+
loss_type[3:],
|
47 |
+
rgb_range=args.rgb_range
|
48 |
+
)
|
49 |
+
elif loss_type.find('GAN') >= 0:
|
50 |
+
module = import_module('loss.adversarial')
|
51 |
+
loss_function = getattr(module, 'Adversarial')(
|
52 |
+
args,
|
53 |
+
loss_type
|
54 |
+
)
|
55 |
+
|
56 |
+
self.loss.append({
|
57 |
+
'type': loss_type,
|
58 |
+
'weight': float(weight),
|
59 |
+
'function': loss_function}
|
60 |
+
)
|
61 |
+
if loss_type.find('GAN') >= 0:
|
62 |
+
self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
|
63 |
+
|
64 |
+
if len(self.loss) > 1:
|
65 |
+
self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
|
66 |
+
|
67 |
+
for l in self.loss:
|
68 |
+
if l['function'] is not None:
|
69 |
+
print('{:.3f} * {}'.format(l['weight'], l['type']))
|
70 |
+
# self.loss_module.append(l['function'])
|
71 |
+
|
72 |
+
self.log = torch.Tensor()
|
73 |
+
|
74 |
+
device = torch.device('cpu' if args.cpu else 'cuda')
|
75 |
+
self.loss_module.to(device)
|
76 |
+
if args.precision == 'half': self.loss_module.half()
|
77 |
+
if not args.cpu and args.n_GPUs > 1:
|
78 |
+
self.loss_module = nn.DataParallel(
|
79 |
+
self.loss_module, range(args.n_GPUs)
|
80 |
+
)
|
81 |
+
|
82 |
+
if args.load != '': self.load(ckp.dir, cpu=args.cpu)
|
83 |
+
|
84 |
+
def forward(self, sr, hr):
|
85 |
+
losses = []
|
86 |
+
for i, l in enumerate(self.loss):
|
87 |
+
if l['function'] is not None:
|
88 |
+
if isinstance(sr,list):
|
89 |
+
# weights=[0.32,0.08,0.02,0.01,0.005]
|
90 |
+
# weights=weights[::-1]
|
91 |
+
# weights=[0.01,0.02,0.08,0.32]
|
92 |
+
# self.buffer=[]
|
93 |
+
effective_loss,buffer_lst=sequence_loss(sr,hr,l['function'])
|
94 |
+
# for k in range(len(sr)):
|
95 |
+
# loss=l['function'](sr[k], hr)
|
96 |
+
# self.buffer.append(loss.item())
|
97 |
+
# effective_loss=loss*weights[k]*l['weight']
|
98 |
+
losses.append(effective_loss)
|
99 |
+
self.buffer=buffer_lst
|
100 |
+
self.log[-1, i] += effective_loss.item()
|
101 |
+
else:
|
102 |
+
loss = l['function'](sr, hr)
|
103 |
+
effective_loss = l['weight'] * loss
|
104 |
+
losses.append(effective_loss)
|
105 |
+
self.buffer[0]=effective_loss.item()
|
106 |
+
self.log[-1, i] += effective_loss.item()
|
107 |
+
elif l['type'] == 'DIS':
|
108 |
+
self.log[-1, i] += self.loss[i - 1]['function'].loss
|
109 |
+
|
110 |
+
loss_sum = sum(losses)
|
111 |
+
if len(self.loss) > 1:
|
112 |
+
self.log[-1, -1] += loss_sum.item()
|
113 |
+
|
114 |
+
return loss_sum
|
115 |
+
|
116 |
+
def step(self):
|
117 |
+
for l in self.get_loss_module():
|
118 |
+
if hasattr(l, 'scheduler'):
|
119 |
+
l.scheduler.step()
|
120 |
+
|
121 |
+
def start_log(self):
|
122 |
+
self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
|
123 |
+
|
124 |
+
def end_log(self, n_batches):
|
125 |
+
self.log[-1].div_(n_batches)
|
126 |
+
|
127 |
+
def display_loss(self, batch):
|
128 |
+
n_samples = batch + 1
|
129 |
+
log = []
|
130 |
+
for l, c in zip(self.loss, self.log[-1]):
|
131 |
+
log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
|
132 |
+
|
133 |
+
return ''.join(log)
|
134 |
+
|
135 |
+
def plot_loss(self, apath, epoch):
|
136 |
+
axis = np.linspace(1, epoch, epoch)
|
137 |
+
for i, l in enumerate(self.loss):
|
138 |
+
label = '{} Loss'.format(l['type'])
|
139 |
+
fig = plt.figure()
|
140 |
+
plt.title(label)
|
141 |
+
plt.plot(axis, self.log[:, i].numpy(), label=label)
|
142 |
+
plt.legend()
|
143 |
+
plt.xlabel('Epochs')
|
144 |
+
plt.ylabel('Loss')
|
145 |
+
plt.grid(True)
|
146 |
+
plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
|
147 |
+
plt.close(fig)
|
148 |
+
|
149 |
+
def get_loss_module(self):
|
150 |
+
if self.n_GPUs == 1:
|
151 |
+
return self.loss_module
|
152 |
+
else:
|
153 |
+
return self.loss_module.module
|
154 |
+
|
155 |
+
def save(self, apath):
|
156 |
+
torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
|
157 |
+
torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
|
158 |
+
|
159 |
+
def load(self, apath, cpu=False):
|
160 |
+
if cpu:
|
161 |
+
kwargs = {'map_location': lambda storage, loc: storage}
|
162 |
+
else:
|
163 |
+
kwargs = {}
|
164 |
+
|
165 |
+
self.load_state_dict(torch.load(
|
166 |
+
os.path.join(apath, 'loss.pt'),
|
167 |
+
**kwargs
|
168 |
+
))
|
169 |
+
self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
|
170 |
+
for l in self.get_loss_module():
|
171 |
+
if hasattr(l, 'scheduler'):
|
172 |
+
for _ in range(len(self.log)): l.scheduler.step()
|
173 |
+
|
CAR/code/loss/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (4.57 kB). View file
|
|
CAR/code/loss/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (5.13 kB). View file
|
|
CAR/code/loss/adversarial.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import utility
|
2 |
+
from types import SimpleNamespace
|
3 |
+
|
4 |
+
from model import common
|
5 |
+
from loss import discriminator
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.optim as optim
|
11 |
+
|
12 |
+
class Adversarial(nn.Module):
|
13 |
+
def __init__(self, args, gan_type):
|
14 |
+
super(Adversarial, self).__init__()
|
15 |
+
self.gan_type = gan_type
|
16 |
+
self.gan_k = args.gan_k
|
17 |
+
self.dis = discriminator.Discriminator(args)
|
18 |
+
if gan_type == 'WGAN_GP':
|
19 |
+
# see https://arxiv.org/pdf/1704.00028.pdf pp.4
|
20 |
+
optim_dict = {
|
21 |
+
'optimizer': 'ADAM',
|
22 |
+
'betas': (0, 0.9),
|
23 |
+
'epsilon': 1e-8,
|
24 |
+
'lr': 1e-5,
|
25 |
+
'weight_decay': args.weight_decay,
|
26 |
+
'decay': args.decay,
|
27 |
+
'gamma': args.gamma
|
28 |
+
}
|
29 |
+
optim_args = SimpleNamespace(**optim_dict)
|
30 |
+
else:
|
31 |
+
optim_args = args
|
32 |
+
|
33 |
+
self.optimizer = utility.make_optimizer(optim_args, self.dis)
|
34 |
+
|
35 |
+
def forward(self, fake, real):
|
36 |
+
# updating discriminator...
|
37 |
+
self.loss = 0
|
38 |
+
fake_detach = fake.detach() # do not backpropagate through G
|
39 |
+
for _ in range(self.gan_k):
|
40 |
+
self.optimizer.zero_grad()
|
41 |
+
# d: B x 1 tensor
|
42 |
+
d_fake = self.dis(fake_detach)
|
43 |
+
d_real = self.dis(real)
|
44 |
+
retain_graph = False
|
45 |
+
if self.gan_type == 'GAN':
|
46 |
+
loss_d = self.bce(d_real, d_fake)
|
47 |
+
elif self.gan_type.find('WGAN') >= 0:
|
48 |
+
loss_d = (d_fake - d_real).mean()
|
49 |
+
if self.gan_type.find('GP') >= 0:
|
50 |
+
epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)
|
51 |
+
hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
|
52 |
+
hat.requires_grad = True
|
53 |
+
d_hat = self.dis(hat)
|
54 |
+
gradients = torch.autograd.grad(
|
55 |
+
outputs=d_hat.sum(), inputs=hat,
|
56 |
+
retain_graph=True, create_graph=True, only_inputs=True
|
57 |
+
)[0]
|
58 |
+
gradients = gradients.view(gradients.size(0), -1)
|
59 |
+
gradient_norm = gradients.norm(2, dim=1)
|
60 |
+
gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
|
61 |
+
loss_d += gradient_penalty
|
62 |
+
# from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
|
63 |
+
elif self.gan_type == 'RGAN':
|
64 |
+
better_real = d_real - d_fake.mean(dim=0, keepdim=True)
|
65 |
+
better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
|
66 |
+
loss_d = self.bce(better_real, better_fake)
|
67 |
+
retain_graph = True
|
68 |
+
|
69 |
+
# Discriminator update
|
70 |
+
self.loss += loss_d.item()
|
71 |
+
loss_d.backward(retain_graph=retain_graph)
|
72 |
+
self.optimizer.step()
|
73 |
+
|
74 |
+
if self.gan_type == 'WGAN':
|
75 |
+
for p in self.dis.parameters():
|
76 |
+
p.data.clamp_(-1, 1)
|
77 |
+
|
78 |
+
self.loss /= self.gan_k
|
79 |
+
|
80 |
+
# updating generator...
|
81 |
+
d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is
|
82 |
+
if self.gan_type == 'GAN':
|
83 |
+
label_real = torch.ones_like(d_fake_bp)
|
84 |
+
loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
|
85 |
+
elif self.gan_type.find('WGAN') >= 0:
|
86 |
+
loss_g = -d_fake_bp.mean()
|
87 |
+
elif self.gan_type == 'RGAN':
|
88 |
+
better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
|
89 |
+
better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
|
90 |
+
loss_g = self.bce(better_fake, better_real)
|
91 |
+
|
92 |
+
# Generator loss
|
93 |
+
return loss_g
|
94 |
+
|
95 |
+
def state_dict(self, *args, **kwargs):
|
96 |
+
state_discriminator = self.dis.state_dict(*args, **kwargs)
|
97 |
+
state_optimizer = self.optimizer.state_dict()
|
98 |
+
|
99 |
+
return dict(**state_discriminator, **state_optimizer)
|
100 |
+
|
101 |
+
def bce(self, real, fake):
|
102 |
+
label_real = torch.ones_like(real)
|
103 |
+
label_fake = torch.zeros_like(fake)
|
104 |
+
bce_real = F.binary_cross_entropy_with_logits(real, label_real)
|
105 |
+
bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
|
106 |
+
bce_loss = bce_real + bce_fake
|
107 |
+
return bce_loss
|
108 |
+
|
109 |
+
# Some references
|
110 |
+
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
|
111 |
+
# OR
|
112 |
+
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
|
CAR/code/loss/discriminator.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from model import common
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
class Discriminator(nn.Module):
|
6 |
+
'''
|
7 |
+
output is not normalized
|
8 |
+
'''
|
9 |
+
def __init__(self, args):
|
10 |
+
super(Discriminator, self).__init__()
|
11 |
+
|
12 |
+
in_channels = args.n_colors
|
13 |
+
out_channels = 64
|
14 |
+
depth = 7
|
15 |
+
|
16 |
+
def _block(_in_channels, _out_channels, stride=1):
|
17 |
+
return nn.Sequential(
|
18 |
+
nn.Conv2d(
|
19 |
+
_in_channels,
|
20 |
+
_out_channels,
|
21 |
+
3,
|
22 |
+
padding=1,
|
23 |
+
stride=stride,
|
24 |
+
bias=False
|
25 |
+
),
|
26 |
+
nn.BatchNorm2d(_out_channels),
|
27 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
28 |
+
)
|
29 |
+
|
30 |
+
m_features = [_block(in_channels, out_channels)]
|
31 |
+
for i in range(depth):
|
32 |
+
in_channels = out_channels
|
33 |
+
if i % 2 == 1:
|
34 |
+
stride = 1
|
35 |
+
out_channels *= 2
|
36 |
+
else:
|
37 |
+
stride = 2
|
38 |
+
m_features.append(_block(in_channels, out_channels, stride=stride))
|
39 |
+
|
40 |
+
patch_size = args.patch_size // (2**((depth + 1) // 2))
|
41 |
+
m_classifier = [
|
42 |
+
nn.Linear(out_channels * patch_size**2, 1024),
|
43 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
44 |
+
nn.Linear(1024, 1)
|
45 |
+
]
|
46 |
+
|
47 |
+
self.features = nn.Sequential(*m_features)
|
48 |
+
self.classifier = nn.Sequential(*m_classifier)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
features = self.features(x)
|
52 |
+
output = self.classifier(features.view(features.size(0), -1))
|
53 |
+
|
54 |
+
return output
|
55 |
+
|
CAR/code/loss/vgg.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from model import common
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torchvision.models as models
|
7 |
+
|
8 |
+
class VGG(nn.Module):
|
9 |
+
def __init__(self, conv_index, rgb_range=1):
|
10 |
+
super(VGG, self).__init__()
|
11 |
+
vgg_features = models.vgg19(pretrained=True).features
|
12 |
+
modules = [m for m in vgg_features]
|
13 |
+
if conv_index.find('22') >= 0:
|
14 |
+
self.vgg = nn.Sequential(*modules[:8])
|
15 |
+
elif conv_index.find('54') >= 0:
|
16 |
+
self.vgg = nn.Sequential(*modules[:35])
|
17 |
+
|
18 |
+
vgg_mean = (0.485, 0.456, 0.406)
|
19 |
+
vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
|
20 |
+
self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
|
21 |
+
for p in self.parameters():
|
22 |
+
p.requires_grad = False
|
23 |
+
|
24 |
+
def forward(self, sr, hr):
|
25 |
+
def _forward(x):
|
26 |
+
x = self.sub_mean(x)
|
27 |
+
x = self.vgg(x)
|
28 |
+
return x
|
29 |
+
|
30 |
+
vgg_sr = _forward(sr)
|
31 |
+
with torch.no_grad():
|
32 |
+
vgg_hr = _forward(hr.detach())
|
33 |
+
|
34 |
+
loss = F.mse_loss(vgg_sr, vgg_hr)
|
35 |
+
|
36 |
+
return loss
|
CAR/code/main.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import utility
|
4 |
+
import data
|
5 |
+
import model
|
6 |
+
import loss
|
7 |
+
from option import args
|
8 |
+
from trainer import Trainer
|
9 |
+
|
10 |
+
torch.manual_seed(args.seed)
|
11 |
+
checkpoint = utility.checkpoint(args)
|
12 |
+
|
13 |
+
def main():
|
14 |
+
global model
|
15 |
+
if args.data_test == ['video']:
|
16 |
+
from videotester import VideoTester
|
17 |
+
model = model.Model(args,checkpoint)
|
18 |
+
print('total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))
|
19 |
+
t = VideoTester(args, model, checkpoint)
|
20 |
+
t.test()
|
21 |
+
else:
|
22 |
+
if checkpoint.ok:
|
23 |
+
loader = data.Data(args)
|
24 |
+
_model = model.Model(args, checkpoint)
|
25 |
+
print('total params:%.5fM' % (sum(p.numel() for p in _model.parameters())/1000000.0))
|
26 |
+
_loss = loss.Loss(args, checkpoint) if not args.test_only else None
|
27 |
+
t = Trainer(args, loader, _model, _loss, checkpoint)
|
28 |
+
while not t.terminate():
|
29 |
+
t.train()
|
30 |
+
t.test()
|
31 |
+
|
32 |
+
checkpoint.done()
|
33 |
+
|
34 |
+
if __name__ == '__main__':
|
35 |
+
main()
|
CAR/code/model/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2018 Sanghyun Son
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
CAR/code/model/__init__.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from importlib import import_module
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.autograd import Variable
|
7 |
+
|
8 |
+
class Model(nn.Module):
|
9 |
+
def __init__(self, args, ckp):
|
10 |
+
super(Model, self).__init__()
|
11 |
+
print('Making model...')
|
12 |
+
|
13 |
+
self.scale = args.scale
|
14 |
+
self.idx_scale = 0
|
15 |
+
self.self_ensemble = args.self_ensemble
|
16 |
+
self.chop = args.chop
|
17 |
+
self.precision = args.precision
|
18 |
+
self.cpu = args.cpu
|
19 |
+
self.device = torch.device('cpu' if args.cpu else 'cuda')
|
20 |
+
self.n_GPUs = args.n_GPUs
|
21 |
+
self.save_models = args.save_models
|
22 |
+
|
23 |
+
module = import_module('model.' + args.model.lower())
|
24 |
+
self.model = module.make_model(args).to(self.device)
|
25 |
+
if args.precision == 'half': self.model.half()
|
26 |
+
|
27 |
+
if not args.cpu and args.n_GPUs > 1:
|
28 |
+
self.model = nn.DataParallel(self.model, range(args.n_GPUs))
|
29 |
+
|
30 |
+
self.load(
|
31 |
+
ckp.dir,
|
32 |
+
pre_train=args.pre_train,
|
33 |
+
resume=args.resume,
|
34 |
+
cpu=args.cpu
|
35 |
+
)
|
36 |
+
print(self.model, file=ckp.log_file)
|
37 |
+
|
38 |
+
def forward(self, x, idx_scale):
|
39 |
+
self.idx_scale = idx_scale
|
40 |
+
target = self.get_model()
|
41 |
+
if hasattr(target, 'set_scale'):
|
42 |
+
target.set_scale(idx_scale)
|
43 |
+
|
44 |
+
if self.self_ensemble and not self.training:
|
45 |
+
if self.chop:
|
46 |
+
forward_function = self.forward_chop
|
47 |
+
else:
|
48 |
+
forward_function = self.model.forward
|
49 |
+
|
50 |
+
return self.forward_x8(x, forward_function)
|
51 |
+
elif self.chop and not self.training:
|
52 |
+
return self.forward_chop(x)
|
53 |
+
else:
|
54 |
+
return self.model(x)
|
55 |
+
|
56 |
+
def get_model(self):
|
57 |
+
if self.n_GPUs == 1:
|
58 |
+
return self.model
|
59 |
+
else:
|
60 |
+
return self.model.module
|
61 |
+
|
62 |
+
def state_dict(self, **kwargs):
|
63 |
+
target = self.get_model()
|
64 |
+
return target.state_dict(**kwargs)
|
65 |
+
|
66 |
+
def save(self, apath, epoch, is_best=False):
|
67 |
+
target = self.get_model()
|
68 |
+
torch.save(
|
69 |
+
target.state_dict(),
|
70 |
+
os.path.join(apath, 'model_latest.pt')
|
71 |
+
)
|
72 |
+
if is_best:
|
73 |
+
torch.save(
|
74 |
+
target.state_dict(),
|
75 |
+
os.path.join(apath, 'model_best.pt')
|
76 |
+
)
|
77 |
+
|
78 |
+
if self.save_models:
|
79 |
+
torch.save(
|
80 |
+
target.state_dict(),
|
81 |
+
os.path.join(apath, 'model_{}.pt'.format(epoch))
|
82 |
+
)
|
83 |
+
|
84 |
+
def load(self, apath, pre_train='.', resume=-1, cpu=False):
|
85 |
+
if cpu:
|
86 |
+
kwargs = {'map_location': lambda storage, loc: storage}
|
87 |
+
else:
|
88 |
+
kwargs = {}
|
89 |
+
|
90 |
+
if resume == -1:
|
91 |
+
self.get_model().load_state_dict(
|
92 |
+
torch.load(
|
93 |
+
os.path.join(apath,'model', 'model_latest.pt'),
|
94 |
+
**kwargs
|
95 |
+
),
|
96 |
+
strict=False
|
97 |
+
)
|
98 |
+
elif resume == 0:
|
99 |
+
if pre_train != '.':
|
100 |
+
print('Loading model from {}'.format(pre_train))
|
101 |
+
self.get_model().load_state_dict(
|
102 |
+
torch.load(pre_train, **kwargs),
|
103 |
+
strict=False
|
104 |
+
)
|
105 |
+
else:
|
106 |
+
self.get_model().load_state_dict(
|
107 |
+
torch.load(
|
108 |
+
os.path.join(apath, 'model', 'model_{}.pt'.format(resume)),
|
109 |
+
**kwargs
|
110 |
+
),
|
111 |
+
strict=False
|
112 |
+
)
|
113 |
+
|
114 |
+
def forward_chop(self, x, shave=10, min_size=6400):
|
115 |
+
scale = self.scale[self.idx_scale]
|
116 |
+
scale = 1
|
117 |
+
n_GPUs = min(self.n_GPUs, 4)
|
118 |
+
b, c, h, w = x.size()
|
119 |
+
h_half, w_half = h // 2, w // 2
|
120 |
+
h_size, w_size = h_half + shave, w_half + shave
|
121 |
+
lr_list = [
|
122 |
+
x[:, :, 0:h_size, 0:w_size],
|
123 |
+
x[:, :, 0:h_size, (w - w_size):w],
|
124 |
+
x[:, :, (h - h_size):h, 0:w_size],
|
125 |
+
x[:, :, (h - h_size):h, (w - w_size):w]]
|
126 |
+
|
127 |
+
if w_size * h_size < min_size:
|
128 |
+
sr_list = []
|
129 |
+
for i in range(0, 4, n_GPUs):
|
130 |
+
lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
|
131 |
+
sr_batch = self.model(lr_batch)
|
132 |
+
sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
|
133 |
+
else:
|
134 |
+
sr_list = [
|
135 |
+
self.forward_chop(patch, shave=shave, min_size=min_size) \
|
136 |
+
for patch in lr_list
|
137 |
+
]
|
138 |
+
|
139 |
+
h, w = scale * h, scale * w
|
140 |
+
h_half, w_half = scale * h_half, scale * w_half
|
141 |
+
h_size, w_size = scale * h_size, scale * w_size
|
142 |
+
shave *= scale
|
143 |
+
|
144 |
+
output = x.new(b, c, h, w)
|
145 |
+
output[:, :, 0:h_half, 0:w_half] \
|
146 |
+
= sr_list[0][:, :, 0:h_half, 0:w_half]
|
147 |
+
output[:, :, 0:h_half, w_half:w] \
|
148 |
+
= sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
|
149 |
+
output[:, :, h_half:h, 0:w_half] \
|
150 |
+
= sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
|
151 |
+
output[:, :, h_half:h, w_half:w] \
|
152 |
+
= sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
|
153 |
+
|
154 |
+
return output
|
155 |
+
|
156 |
+
def forward_x8(self, x, forward_function):
|
157 |
+
def _transform(v, op):
|
158 |
+
if self.precision != 'single': v = v.float()
|
159 |
+
|
160 |
+
v2np = v.data.cpu().numpy()
|
161 |
+
if op == 'v':
|
162 |
+
tfnp = v2np[:, :, :, ::-1].copy()
|
163 |
+
elif op == 'h':
|
164 |
+
tfnp = v2np[:, :, ::-1, :].copy()
|
165 |
+
elif op == 't':
|
166 |
+
tfnp = v2np.transpose((0, 1, 3, 2)).copy()
|
167 |
+
|
168 |
+
ret = torch.Tensor(tfnp).to(self.device)
|
169 |
+
if self.precision == 'half': ret = ret.half()
|
170 |
+
|
171 |
+
return ret
|
172 |
+
|
173 |
+
lr_list = [x]
|
174 |
+
for tf in 'v', 'h', 't':
|
175 |
+
lr_list.extend([_transform(t, tf) for t in lr_list])
|
176 |
+
|
177 |
+
sr_list = [forward_function(aug) for aug in lr_list]
|
178 |
+
for i in range(len(sr_list)):
|
179 |
+
if i > 3:
|
180 |
+
sr_list[i] = _transform(sr_list[i], 't')
|
181 |
+
if i % 4 > 1:
|
182 |
+
sr_list[i] = _transform(sr_list[i], 'h')
|
183 |
+
if (i % 4) % 2 == 1:
|
184 |
+
sr_list[i] = _transform(sr_list[i], 'v')
|
185 |
+
|
186 |
+
output_cat = torch.cat(sr_list, dim=0)
|
187 |
+
output = output_cat.mean(dim=0, keepdim=True)
|
188 |
+
|
189 |
+
return output
|
190 |
+
|
CAR/code/model/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (5.47 kB). View file
|
|
CAR/code/model/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (5.42 kB). View file
|
|
CAR/code/model/__pycache__/attention.cpython-36.pyc
ADDED
Binary file (12.8 kB). View file
|
|