Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- SR/README.md +78 -0
- SR/code/.vscode/settings.json +3 -0
- SR/code/__init__.py +0 -0
- SR/code/__pycache__/option.cpython-35.pyc +0 -0
- SR/code/__pycache__/option.cpython-36.pyc +0 -0
- SR/code/__pycache__/option.cpython-37.pyc +0 -0
- SR/code/__pycache__/template.cpython-35.pyc +0 -0
- SR/code/__pycache__/template.cpython-36.pyc +0 -0
- SR/code/__pycache__/template.cpython-37.pyc +0 -0
- SR/code/__pycache__/trainer.cpython-35.pyc +0 -0
- SR/code/__pycache__/trainer.cpython-36.pyc +0 -0
- SR/code/__pycache__/trainer.cpython-37.pyc +0 -0
- SR/code/__pycache__/utility.cpython-35.pyc +0 -0
- SR/code/__pycache__/utility.cpython-36.pyc +0 -0
- SR/code/__pycache__/utility.cpython-37.pyc +0 -0
- SR/code/bootstrap.sh +5 -0
- SR/code/data/__init__.py +52 -0
- SR/code/data/__pycache__/__init__.cpython-35.pyc +0 -0
- SR/code/data/__pycache__/__init__.cpython-36.pyc +0 -0
- SR/code/data/__pycache__/__init__.cpython-37.pyc +0 -0
- SR/code/data/__pycache__/benchmark.cpython-35.pyc +0 -0
- SR/code/data/__pycache__/benchmark.cpython-36.pyc +0 -0
- SR/code/data/__pycache__/common.cpython-35.pyc +0 -0
- SR/code/data/__pycache__/common.cpython-36.pyc +0 -0
- SR/code/data/__pycache__/common.cpython-37.pyc +0 -0
- SR/code/data/__pycache__/demo.cpython-36.pyc +0 -0
- SR/code/data/__pycache__/div2k.cpython-35.pyc +0 -0
- SR/code/data/__pycache__/div2k.cpython-36.pyc +0 -0
- SR/code/data/__pycache__/div2k.cpython-37.pyc +0 -0
- SR/code/data/__pycache__/srdata.cpython-35.pyc +0 -0
- SR/code/data/__pycache__/srdata.cpython-36.pyc +0 -0
- SR/code/data/__pycache__/srdata.cpython-37.pyc +0 -0
- SR/code/data/benchmark.py +25 -0
- SR/code/data/common.py +72 -0
- SR/code/data/demo.py +39 -0
- SR/code/data/div2k.py +32 -0
- SR/code/data/div2kjpeg.py +20 -0
- SR/code/data/sr291.py +6 -0
- SR/code/data/srdata.py +157 -0
- SR/code/data/video.py +44 -0
- SR/code/dataloader.py +158 -0
- SR/code/demo.sb +5 -0
- SR/code/lambda_networks/__init__.py +4 -0
- SR/code/lambda_networks/__pycache__/__init__.cpython-37.pyc +0 -0
- SR/code/lambda_networks/__pycache__/lambda_networks.cpython-37.pyc +0 -0
- SR/code/lambda_networks/__pycache__/rlambda_networks.cpython-37.pyc +0 -0
- SR/code/lambda_networks/lambda_networks.py +137 -0
- SR/code/lambda_networks/rlambda_networks.py +93 -0
- SR/code/loss/__init__.py +173 -0
- SR/code/loss/__loss__.py +0 -0
SR/README.md
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Pyramid Attention for Image Restoration
|
2 |
+
This repository is for PANet and PA-EDSR introduced in the following paper
|
3 |
+
|
4 |
+
[Yiqun Mei](http://yiqunm2.web.illinois.edu/), [Yuchen Fan](https://scholar.google.com/citations?user=BlfdYL0AAAAJ&hl=en), [Yulun Zhang](http://yulunzhang.com/), [Jiahui Yu](https://jiahuiyu.com/), [Yuqian Zhou](https://yzhouas.github.io/), [Ding Liu](https://scholar.google.com/citations?user=PGtHUI0AAAAJ&hl=en), [Yun Fu](http://www1.ece.neu.edu/~yunfu/), [Thomas S. Huang](http://ifp-uiuc.github.io/) and [Honghui Shi](https://www.humphreyshi.com/) "Pyramid Attention for Image Restoration", [[Arxiv]](https://arxiv.org/abs/2004.13824)
|
5 |
+
|
6 |
+
The code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) & [RNAN](https://github.com/yulunzhang/RNAN) and tested on Ubuntu 18.04 environment (Python3.6, PyTorch_1.1) with Titan X/1080Ti/V100 GPUs.
|
7 |
+
|
8 |
+
## Contents
|
9 |
+
1. [Train](#train)
|
10 |
+
2. [Test](#test)
|
11 |
+
3. [Results](#results)
|
12 |
+
4. [Citation](#citation)
|
13 |
+
5. [Acknowledgements](#acknowledgements)
|
14 |
+
|
15 |
+
## Train
|
16 |
+
### Prepare training data
|
17 |
+
|
18 |
+
1. Download DIV2K training data (800 training + 100 validtion images) from [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar).
|
19 |
+
|
20 |
+
2. Specify '--dir_data' based on the HR and LR images path.
|
21 |
+
|
22 |
+
For more informaiton, please refer to [EDSR(PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch).
|
23 |
+
|
24 |
+
### Begin to train
|
25 |
+
|
26 |
+
1. (optional) All the pretrained models and visual results can be downloaded from [Google Drive](https://drive.google.com/open?id=1q9iUzqYX0fVRzDu4J6fvSPRosgOZoJJE).
|
27 |
+
|
28 |
+
2. Cd to 'PANet-PyTorch/[Task]/code', run the following scripts to train models.
|
29 |
+
|
30 |
+
**You can use scripts in file 'demo.sb' to train and test models for our paper.**
|
31 |
+
|
32 |
+
```bash
|
33 |
+
# Example Usage:
|
34 |
+
python main.py --n_GPUs 4 --rgb_range 1 --reset --save_models --lr 1e-4 --decay 200-400-600-800 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model PAEDSR --scale 2 --patch_size 96 --save EDSR_PA_x2 --data_train DIV2K
|
35 |
+
```
|
36 |
+
## Test
|
37 |
+
### Quick start
|
38 |
+
|
39 |
+
1. Cd to 'PANet-PyTorch/[Task]/code', run the following scripts.
|
40 |
+
|
41 |
+
**You can use scripts in file 'demo.sb' to produce results for our paper.**
|
42 |
+
|
43 |
+
```bash
|
44 |
+
# No self-ensemble, use different testsets to reproduce the results in the paper.
|
45 |
+
# Example Usage:
|
46 |
+
python main.py --model PAEDSR --data_test Set5+Set14+B100+Urban100 --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1 --pre_train ../model_x2.pt --test_only --chop
|
47 |
+
|
48 |
+
```
|
49 |
+
|
50 |
+
### The whole test pipeline
|
51 |
+
1. Prepare benchmark datasets [from SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/benchmark.tar)
|
52 |
+
|
53 |
+
2. Conduct image SR.
|
54 |
+
|
55 |
+
See **Quick start**
|
56 |
+
3. Evaluate the results.
|
57 |
+
|
58 |
+
Run 'Evaluate_PSNR_SSIM.m' to obtain PSNR/SSIM values for paper.
|
59 |
+
|
60 |
+
## Citation
|
61 |
+
If you find the code helpful in your resarch or work, please cite the following papers.
|
62 |
+
```
|
63 |
+
@article{mei2020pyramid,
|
64 |
+
title={Pyramid Attention Networks for Image Restoration},
|
65 |
+
author={Mei, Yiqun and Fan, Yuchen and Zhang, Yulun and Yu, Jiahui and Zhou, Yuqian and Liu, Ding and Fu, Yun and Huang, Thomas S and Shi, Honghui},
|
66 |
+
journal={arXiv preprint arXiv:2004.13824},
|
67 |
+
year={2020}
|
68 |
+
}
|
69 |
+
@InProceedings{Lim_2017_CVPR_Workshops,
|
70 |
+
author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
|
71 |
+
title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
|
72 |
+
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
|
73 |
+
month = {July},
|
74 |
+
year = {2017}
|
75 |
+
}
|
76 |
+
```
|
77 |
+
## Acknowledgements
|
78 |
+
This code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch), [RNAN](https://github.com/yulunzhang/RNAN) and [generative-inpainting-pytorch](https://github.com/daa233/generative-inpainting-pytorch). We thank the authors for sharing their codes.
|
SR/code/.vscode/settings.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"python.pythonPath": "D:\\Environments\\Anaconda3\\envs\\torch1.6\\python.exe"
|
3 |
+
}
|
SR/code/__init__.py
ADDED
File without changes
|
SR/code/__pycache__/option.cpython-35.pyc
ADDED
Binary file (5.42 kB). View file
|
|
SR/code/__pycache__/option.cpython-36.pyc
ADDED
Binary file (4.76 kB). View file
|
|
SR/code/__pycache__/option.cpython-37.pyc
ADDED
Binary file (4.9 kB). View file
|
|
SR/code/__pycache__/template.cpython-35.pyc
ADDED
Binary file (1.13 kB). View file
|
|
SR/code/__pycache__/template.cpython-36.pyc
ADDED
Binary file (1 kB). View file
|
|
SR/code/__pycache__/template.cpython-37.pyc
ADDED
Binary file (993 Bytes). View file
|
|
SR/code/__pycache__/trainer.cpython-35.pyc
ADDED
Binary file (4.46 kB). View file
|
|
SR/code/__pycache__/trainer.cpython-36.pyc
ADDED
Binary file (3.98 kB). View file
|
|
SR/code/__pycache__/trainer.cpython-37.pyc
ADDED
Binary file (4.84 kB). View file
|
|
SR/code/__pycache__/utility.cpython-35.pyc
ADDED
Binary file (9.88 kB). View file
|
|
SR/code/__pycache__/utility.cpython-36.pyc
ADDED
Binary file (9.04 kB). View file
|
|
SR/code/__pycache__/utility.cpython-37.pyc
ADDED
Binary file (9.1 kB). View file
|
|
SR/code/bootstrap.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
|
4 |
+
|
5 |
+
bash ~/miniconda.sh -b -p $HOME/miniconda
|
SR/code/data/__init__.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from importlib import import_module
|
2 |
+
#from dataloader import MSDataLoader
|
3 |
+
from torch.utils.data import dataloader
|
4 |
+
from torch.utils.data import ConcatDataset
|
5 |
+
|
6 |
+
# This is a simple wrapper function for ConcatDataset
|
7 |
+
class MyConcatDataset(ConcatDataset):
|
8 |
+
def __init__(self, datasets):
|
9 |
+
super(MyConcatDataset, self).__init__(datasets)
|
10 |
+
self.train = datasets[0].train
|
11 |
+
|
12 |
+
def set_scale(self, idx_scale):
|
13 |
+
for d in self.datasets:
|
14 |
+
if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
|
15 |
+
|
16 |
+
class Data:
|
17 |
+
def __init__(self, args):
|
18 |
+
self.loader_train = None
|
19 |
+
if not args.test_only:
|
20 |
+
datasets = []
|
21 |
+
for d in args.data_train:
|
22 |
+
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
|
23 |
+
m = import_module('data.' + module_name.lower())
|
24 |
+
datasets.append(getattr(m, module_name)(args, name=d))
|
25 |
+
|
26 |
+
self.loader_train = dataloader.DataLoader(
|
27 |
+
MyConcatDataset(datasets),
|
28 |
+
batch_size=args.batch_size,
|
29 |
+
shuffle=True,
|
30 |
+
pin_memory=not args.cpu,
|
31 |
+
num_workers=args.n_threads,
|
32 |
+
)
|
33 |
+
|
34 |
+
self.loader_test = []
|
35 |
+
for d in args.data_test:
|
36 |
+
if d in ['Set5', 'Set14', 'B100', 'Urban100']:
|
37 |
+
m = import_module('data.benchmark')
|
38 |
+
testset = getattr(m, 'Benchmark')(args, train=False, name=d)
|
39 |
+
else:
|
40 |
+
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
|
41 |
+
m = import_module('data.' + module_name.lower())
|
42 |
+
testset = getattr(m, module_name)(args, train=False, name=d)
|
43 |
+
|
44 |
+
self.loader_test.append(
|
45 |
+
dataloader.DataLoader(
|
46 |
+
testset,
|
47 |
+
batch_size=1,
|
48 |
+
shuffle=False,
|
49 |
+
pin_memory=not args.cpu,
|
50 |
+
num_workers=args.n_threads,
|
51 |
+
)
|
52 |
+
)
|
SR/code/data/__pycache__/__init__.cpython-35.pyc
ADDED
Binary file (1.94 kB). View file
|
|
SR/code/data/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (1.79 kB). View file
|
|
SR/code/data/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (1.76 kB). View file
|
|
SR/code/data/__pycache__/benchmark.cpython-35.pyc
ADDED
Binary file (1.13 kB). View file
|
|
SR/code/data/__pycache__/benchmark.cpython-36.pyc
ADDED
Binary file (1.07 kB). View file
|
|
SR/code/data/__pycache__/common.cpython-35.pyc
ADDED
Binary file (3.05 kB). View file
|
|
SR/code/data/__pycache__/common.cpython-36.pyc
ADDED
Binary file (2.76 kB). View file
|
|
SR/code/data/__pycache__/common.cpython-37.pyc
ADDED
Binary file (2.73 kB). View file
|
|
SR/code/data/__pycache__/demo.cpython-36.pyc
ADDED
Binary file (1.5 kB). View file
|
|
SR/code/data/__pycache__/div2k.cpython-35.pyc
ADDED
Binary file (1.87 kB). View file
|
|
SR/code/data/__pycache__/div2k.cpython-36.pyc
ADDED
Binary file (1.74 kB). View file
|
|
SR/code/data/__pycache__/div2k.cpython-37.pyc
ADDED
Binary file (1.75 kB). View file
|
|
SR/code/data/__pycache__/srdata.cpython-35.pyc
ADDED
Binary file (5.43 kB). View file
|
|
SR/code/data/__pycache__/srdata.cpython-36.pyc
ADDED
Binary file (4.83 kB). View file
|
|
SR/code/data/__pycache__/srdata.cpython-37.pyc
ADDED
Binary file (4.82 kB). View file
|
|
SR/code/data/benchmark.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from data import common
|
4 |
+
from data import srdata
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.utils.data as data
|
10 |
+
|
11 |
+
class Benchmark(srdata.SRData):
|
12 |
+
def __init__(self, args, name='', train=True, benchmark=True):
|
13 |
+
super(Benchmark, self).__init__(
|
14 |
+
args, name=name, train=train, benchmark=True
|
15 |
+
)
|
16 |
+
|
17 |
+
def _set_filesystem(self, dir_data):
|
18 |
+
self.apath = os.path.join(dir_data, 'benchmark', self.name)
|
19 |
+
self.dir_hr = os.path.join(self.apath, 'HR')
|
20 |
+
if self.input_large:
|
21 |
+
self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
|
22 |
+
else:
|
23 |
+
self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
|
24 |
+
self.ext = ('', '.png')
|
25 |
+
|
SR/code/data/common.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import skimage.color as sc
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
|
9 |
+
ih, iw = args[0].shape[:2]
|
10 |
+
|
11 |
+
if not input_large:
|
12 |
+
p = scale if multi else 1
|
13 |
+
tp = p * patch_size
|
14 |
+
ip = tp // scale
|
15 |
+
else:
|
16 |
+
tp = patch_size
|
17 |
+
ip = patch_size
|
18 |
+
|
19 |
+
ix = random.randrange(0, iw - ip + 1)
|
20 |
+
iy = random.randrange(0, ih - ip + 1)
|
21 |
+
|
22 |
+
if not input_large:
|
23 |
+
tx, ty = scale * ix, scale * iy
|
24 |
+
else:
|
25 |
+
tx, ty = ix, iy
|
26 |
+
|
27 |
+
ret = [
|
28 |
+
args[0][iy:iy + ip, ix:ix + ip, :],
|
29 |
+
*[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
|
30 |
+
]
|
31 |
+
|
32 |
+
return ret
|
33 |
+
|
34 |
+
def set_channel(*args, n_channels=3):
|
35 |
+
def _set_channel(img):
|
36 |
+
if img.ndim == 2:
|
37 |
+
img = np.expand_dims(img, axis=2)
|
38 |
+
|
39 |
+
c = img.shape[2]
|
40 |
+
if n_channels == 1 and c == 3:
|
41 |
+
img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
|
42 |
+
elif n_channels == 3 and c == 1:
|
43 |
+
img = np.concatenate([img] * n_channels, 2)
|
44 |
+
|
45 |
+
return img
|
46 |
+
|
47 |
+
return [_set_channel(a) for a in args]
|
48 |
+
|
49 |
+
def np2Tensor(*args, rgb_range=255):
|
50 |
+
def _np2Tensor(img):
|
51 |
+
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
|
52 |
+
tensor = torch.from_numpy(np_transpose).float()
|
53 |
+
tensor.mul_(rgb_range / 255)
|
54 |
+
|
55 |
+
return tensor
|
56 |
+
|
57 |
+
return [_np2Tensor(a) for a in args]
|
58 |
+
|
59 |
+
def augment(*args, hflip=True, rot=True):
|
60 |
+
hflip = hflip and random.random() < 0.5
|
61 |
+
vflip = rot and random.random() < 0.5
|
62 |
+
rot90 = rot and random.random() < 0.5
|
63 |
+
|
64 |
+
def _augment(img):
|
65 |
+
if hflip: img = img[:, ::-1, :]
|
66 |
+
if vflip: img = img[::-1, :, :]
|
67 |
+
if rot90: img = img.transpose(1, 0, 2)
|
68 |
+
|
69 |
+
return img
|
70 |
+
|
71 |
+
return [_augment(a) for a in args]
|
72 |
+
|
SR/code/data/demo.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from data import common
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import imageio
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.utils.data as data
|
10 |
+
|
11 |
+
class Demo(data.Dataset):
|
12 |
+
def __init__(self, args, name='Demo', train=False, benchmark=False):
|
13 |
+
self.args = args
|
14 |
+
self.name = name
|
15 |
+
self.scale = args.scale
|
16 |
+
self.idx_scale = 0
|
17 |
+
self.train = False
|
18 |
+
self.benchmark = benchmark
|
19 |
+
|
20 |
+
self.filelist = []
|
21 |
+
for f in os.listdir(args.dir_demo):
|
22 |
+
if f.find('.png') >= 0 or f.find('.jp') >= 0:
|
23 |
+
self.filelist.append(os.path.join(args.dir_demo, f))
|
24 |
+
self.filelist.sort()
|
25 |
+
|
26 |
+
def __getitem__(self, idx):
|
27 |
+
filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
|
28 |
+
lr = imageio.imread(self.filelist[idx])
|
29 |
+
lr, = common.set_channel(lr, n_channels=self.args.n_colors)
|
30 |
+
lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
|
31 |
+
|
32 |
+
return lr_t, -1, filename
|
33 |
+
|
34 |
+
def __len__(self):
|
35 |
+
return len(self.filelist)
|
36 |
+
|
37 |
+
def set_scale(self, idx_scale):
|
38 |
+
self.idx_scale = idx_scale
|
39 |
+
|
SR/code/data/div2k.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from data import srdata
|
3 |
+
|
4 |
+
class DIV2K(srdata.SRData):
|
5 |
+
def __init__(self, args, name='DIV2K', train=True, benchmark=False):
|
6 |
+
data_range = [r.split('-') for r in args.data_range.split('/')]
|
7 |
+
if train:
|
8 |
+
data_range = data_range[0]
|
9 |
+
else:
|
10 |
+
if args.test_only and len(data_range) == 1:
|
11 |
+
data_range = data_range[0]
|
12 |
+
else:
|
13 |
+
data_range = data_range[1]
|
14 |
+
|
15 |
+
self.begin, self.end = list(map(lambda x: int(x), data_range))
|
16 |
+
super(DIV2K, self).__init__(
|
17 |
+
args, name=name, train=train, benchmark=benchmark
|
18 |
+
)
|
19 |
+
|
20 |
+
def _scan(self):
|
21 |
+
names_hr, names_lr = super(DIV2K, self)._scan()
|
22 |
+
names_hr = names_hr[self.begin - 1:self.end]
|
23 |
+
names_lr = [n[self.begin - 1:self.end] for n in names_lr]
|
24 |
+
|
25 |
+
return names_hr, names_lr
|
26 |
+
|
27 |
+
def _set_filesystem(self, dir_data):
|
28 |
+
super(DIV2K, self)._set_filesystem(dir_data)
|
29 |
+
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
|
30 |
+
self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
|
31 |
+
if self.input_large: self.dir_lr += 'L'
|
32 |
+
|
SR/code/data/div2kjpeg.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from data import srdata
|
3 |
+
from data import div2k
|
4 |
+
|
5 |
+
class DIV2KJPEG(div2k.DIV2K):
|
6 |
+
def __init__(self, args, name='', train=True, benchmark=False):
|
7 |
+
self.q_factor = int(name.replace('DIV2K-Q', ''))
|
8 |
+
super(DIV2KJPEG, self).__init__(
|
9 |
+
args, name=name, train=train, benchmark=benchmark
|
10 |
+
)
|
11 |
+
|
12 |
+
def _set_filesystem(self, dir_data):
|
13 |
+
self.apath = os.path.join(dir_data, 'DIV2K')
|
14 |
+
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
|
15 |
+
self.dir_lr = os.path.join(
|
16 |
+
self.apath, 'DIV2K_Q{}'.format(self.q_factor)
|
17 |
+
)
|
18 |
+
if self.input_large: self.dir_lr += 'L'
|
19 |
+
self.ext = ('.png', '.jpg')
|
20 |
+
|
SR/code/data/sr291.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from data import srdata
|
2 |
+
|
3 |
+
class SR291(srdata.SRData):
|
4 |
+
def __init__(self, args, name='SR291', train=True, benchmark=False):
|
5 |
+
super(SR291, self).__init__(args, name=name)
|
6 |
+
|
SR/code/data/srdata.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import random
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
from data import common
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import imageio
|
10 |
+
import torch
|
11 |
+
import torch.utils.data as data
|
12 |
+
|
13 |
+
class SRData(data.Dataset):
|
14 |
+
def __init__(self, args, name='', train=True, benchmark=False):
|
15 |
+
self.args = args
|
16 |
+
self.name = name
|
17 |
+
self.train = train
|
18 |
+
self.split = 'train' if train else 'test'
|
19 |
+
self.do_eval = True
|
20 |
+
self.benchmark = benchmark
|
21 |
+
self.input_large = (args.model == 'VDSR')
|
22 |
+
self.scale = args.scale
|
23 |
+
self.idx_scale = 0
|
24 |
+
|
25 |
+
self._set_filesystem(args.dir_data)
|
26 |
+
if args.ext.find('img') < 0:
|
27 |
+
path_bin = os.path.join(self.apath, 'bin')
|
28 |
+
os.makedirs(path_bin, exist_ok=True)
|
29 |
+
|
30 |
+
list_hr, list_lr = self._scan()
|
31 |
+
if args.ext.find('img') >= 0 or benchmark:
|
32 |
+
self.images_hr, self.images_lr = list_hr, list_lr
|
33 |
+
elif args.ext.find('sep') >= 0:
|
34 |
+
os.makedirs(
|
35 |
+
self.dir_hr.replace(self.apath, path_bin),
|
36 |
+
exist_ok=True
|
37 |
+
)
|
38 |
+
for s in self.scale:
|
39 |
+
os.makedirs(
|
40 |
+
os.path.join(
|
41 |
+
self.dir_lr.replace(self.apath, path_bin),
|
42 |
+
'X{}'.format(s)
|
43 |
+
),
|
44 |
+
exist_ok=True
|
45 |
+
)
|
46 |
+
|
47 |
+
self.images_hr, self.images_lr = [], [[] for _ in self.scale]
|
48 |
+
for h in list_hr:
|
49 |
+
b = h.replace(self.apath, path_bin)
|
50 |
+
b = b.replace(self.ext[0], '.pt')
|
51 |
+
self.images_hr.append(b)
|
52 |
+
self._check_and_load(args.ext, h, b, verbose=True)
|
53 |
+
for i, ll in enumerate(list_lr):
|
54 |
+
for l in ll:
|
55 |
+
b = l.replace(self.apath, path_bin)
|
56 |
+
b = b.replace(self.ext[1], '.pt')
|
57 |
+
self.images_lr[i].append(b)
|
58 |
+
self._check_and_load(args.ext, l, b, verbose=True)
|
59 |
+
if train:
|
60 |
+
n_patches = args.batch_size * args.test_every
|
61 |
+
n_images = len(args.data_train) * len(self.images_hr)
|
62 |
+
if n_images == 0:
|
63 |
+
self.repeat = 0
|
64 |
+
else:
|
65 |
+
self.repeat = max(n_patches // n_images, 1)
|
66 |
+
|
67 |
+
# Below functions as used to prepare images
|
68 |
+
def _scan(self):
|
69 |
+
names_hr = sorted(
|
70 |
+
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
|
71 |
+
)
|
72 |
+
names_lr = [[] for _ in self.scale]
|
73 |
+
for f in names_hr:
|
74 |
+
filename, _ = os.path.splitext(os.path.basename(f))
|
75 |
+
for si, s in enumerate(self.scale):
|
76 |
+
names_lr[si].append(os.path.join(
|
77 |
+
self.dir_lr, 'X{}/{}x{}{}'.format(
|
78 |
+
s, filename, s, self.ext[1]
|
79 |
+
)
|
80 |
+
))
|
81 |
+
|
82 |
+
return names_hr, names_lr
|
83 |
+
|
84 |
+
def _set_filesystem(self, dir_data):
|
85 |
+
self.apath = os.path.join(dir_data, self.name)
|
86 |
+
self.dir_hr = os.path.join(self.apath, 'HR')
|
87 |
+
self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
|
88 |
+
if self.input_large: self.dir_lr += 'L'
|
89 |
+
self.ext = ('.png', '.png')
|
90 |
+
|
91 |
+
def _check_and_load(self, ext, img, f, verbose=True):
|
92 |
+
if not os.path.isfile(f) or ext.find('reset') >= 0:
|
93 |
+
if verbose:
|
94 |
+
print('Making a binary: {}'.format(f))
|
95 |
+
with open(f, 'wb') as _f:
|
96 |
+
pickle.dump(imageio.imread(img), _f)
|
97 |
+
|
98 |
+
def __getitem__(self, idx):
|
99 |
+
lr, hr, filename = self._load_file(idx)
|
100 |
+
pair = self.get_patch(lr, hr)
|
101 |
+
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
|
102 |
+
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
|
103 |
+
|
104 |
+
return pair_t[0], pair_t[1], filename
|
105 |
+
|
106 |
+
def __len__(self):
|
107 |
+
if self.train:
|
108 |
+
return len(self.images_hr) * self.repeat
|
109 |
+
else:
|
110 |
+
return len(self.images_hr)
|
111 |
+
|
112 |
+
def _get_index(self, idx):
|
113 |
+
if self.train:
|
114 |
+
return idx % len(self.images_hr)
|
115 |
+
else:
|
116 |
+
return idx
|
117 |
+
|
118 |
+
def _load_file(self, idx):
|
119 |
+
idx = self._get_index(idx)
|
120 |
+
f_hr = self.images_hr[idx]
|
121 |
+
f_lr = self.images_lr[self.idx_scale][idx]
|
122 |
+
|
123 |
+
filename, _ = os.path.splitext(os.path.basename(f_hr))
|
124 |
+
if self.args.ext == 'img' or self.benchmark:
|
125 |
+
hr = imageio.imread(f_hr)
|
126 |
+
lr = imageio.imread(f_lr)
|
127 |
+
elif self.args.ext.find('sep') >= 0:
|
128 |
+
with open(f_hr, 'rb') as _f:
|
129 |
+
hr = pickle.load(_f)
|
130 |
+
with open(f_lr, 'rb') as _f:
|
131 |
+
lr = pickle.load(_f)
|
132 |
+
|
133 |
+
return lr, hr, filename
|
134 |
+
|
135 |
+
def get_patch(self, lr, hr):
|
136 |
+
scale = self.scale[self.idx_scale]
|
137 |
+
if self.train:
|
138 |
+
lr, hr = common.get_patch(
|
139 |
+
lr, hr,
|
140 |
+
patch_size=self.args.patch_size,
|
141 |
+
scale=scale,
|
142 |
+
multi=(len(self.scale) > 1),
|
143 |
+
input_large=self.input_large
|
144 |
+
)
|
145 |
+
if not self.args.no_augment: lr, hr = common.augment(lr, hr)
|
146 |
+
else:
|
147 |
+
ih, iw = lr.shape[:2]
|
148 |
+
hr = hr[0:ih * scale, 0:iw * scale]
|
149 |
+
|
150 |
+
return lr, hr
|
151 |
+
|
152 |
+
def set_scale(self, idx_scale):
|
153 |
+
if not self.input_large:
|
154 |
+
self.idx_scale = idx_scale
|
155 |
+
else:
|
156 |
+
self.idx_scale = random.randint(0, len(self.scale) - 1)
|
157 |
+
|
SR/code/data/video.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from data import common
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import imageio
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.utils.data as data
|
11 |
+
|
12 |
+
class Video(data.Dataset):
|
13 |
+
def __init__(self, args, name='Video', train=False, benchmark=False):
|
14 |
+
self.args = args
|
15 |
+
self.name = name
|
16 |
+
self.scale = args.scale
|
17 |
+
self.idx_scale = 0
|
18 |
+
self.train = False
|
19 |
+
self.do_eval = False
|
20 |
+
self.benchmark = benchmark
|
21 |
+
|
22 |
+
self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
|
23 |
+
self.vidcap = cv2.VideoCapture(args.dir_demo)
|
24 |
+
self.n_frames = 0
|
25 |
+
self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
26 |
+
|
27 |
+
def __getitem__(self, idx):
|
28 |
+
success, lr = self.vidcap.read()
|
29 |
+
if success:
|
30 |
+
self.n_frames += 1
|
31 |
+
lr, = common.set_channel(lr, n_channels=self.args.n_colors)
|
32 |
+
lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
|
33 |
+
|
34 |
+
return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
|
35 |
+
else:
|
36 |
+
vidcap.release()
|
37 |
+
return None
|
38 |
+
|
39 |
+
def __len__(self):
|
40 |
+
return self.total_frames
|
41 |
+
|
42 |
+
def set_scale(self, idx_scale):
|
43 |
+
self.idx_scale = idx_scale
|
44 |
+
|
SR/code/dataloader.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.multiprocessing as multiprocessing
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from torch.utils.data import SequentialSampler
|
8 |
+
from torch.utils.data import RandomSampler
|
9 |
+
from torch.utils.data import BatchSampler
|
10 |
+
from torch.utils.data import _utils
|
11 |
+
from torch.utils.data.dataloader import _DataLoaderIter
|
12 |
+
|
13 |
+
from torch.utils.data._utils import collate
|
14 |
+
from torch.utils.data._utils import signal_handling
|
15 |
+
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
|
16 |
+
from torch.utils.data._utils import ExceptionWrapper
|
17 |
+
from torch.utils.data._utils import IS_WINDOWS
|
18 |
+
from torch.utils.data._utils.worker import ManagerWatchdog
|
19 |
+
|
20 |
+
from torch._six import queue
|
21 |
+
|
22 |
+
def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
|
23 |
+
try:
|
24 |
+
collate._use_shared_memory = True
|
25 |
+
signal_handling._set_worker_signal_handlers()
|
26 |
+
|
27 |
+
torch.set_num_threads(1)
|
28 |
+
random.seed(seed)
|
29 |
+
torch.manual_seed(seed)
|
30 |
+
|
31 |
+
data_queue.cancel_join_thread()
|
32 |
+
|
33 |
+
if init_fn is not None:
|
34 |
+
init_fn(worker_id)
|
35 |
+
|
36 |
+
watchdog = ManagerWatchdog()
|
37 |
+
|
38 |
+
while watchdog.is_alive():
|
39 |
+
try:
|
40 |
+
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
|
41 |
+
except queue.Empty:
|
42 |
+
continue
|
43 |
+
|
44 |
+
if r is None:
|
45 |
+
assert done_event.is_set()
|
46 |
+
return
|
47 |
+
elif done_event.is_set():
|
48 |
+
continue
|
49 |
+
|
50 |
+
idx, batch_indices = r
|
51 |
+
try:
|
52 |
+
idx_scale = 0
|
53 |
+
if len(scale) > 1 and dataset.train:
|
54 |
+
idx_scale = random.randrange(0, len(scale))
|
55 |
+
dataset.set_scale(idx_scale)
|
56 |
+
|
57 |
+
samples = collate_fn([dataset[i] for i in batch_indices])
|
58 |
+
samples.append(idx_scale)
|
59 |
+
except Exception:
|
60 |
+
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
|
61 |
+
else:
|
62 |
+
data_queue.put((idx, samples))
|
63 |
+
del samples
|
64 |
+
|
65 |
+
except KeyboardInterrupt:
|
66 |
+
pass
|
67 |
+
|
68 |
+
class _MSDataLoaderIter(_DataLoaderIter):
|
69 |
+
|
70 |
+
def __init__(self, loader):
|
71 |
+
self.dataset = loader.dataset
|
72 |
+
self.scale = loader.scale
|
73 |
+
self.collate_fn = loader.collate_fn
|
74 |
+
self.batch_sampler = loader.batch_sampler
|
75 |
+
self.num_workers = loader.num_workers
|
76 |
+
self.pin_memory = loader.pin_memory and torch.cuda.is_available()
|
77 |
+
self.timeout = loader.timeout
|
78 |
+
|
79 |
+
self.sample_iter = iter(self.batch_sampler)
|
80 |
+
|
81 |
+
base_seed = torch.LongTensor(1).random_().item()
|
82 |
+
|
83 |
+
if self.num_workers > 0:
|
84 |
+
self.worker_init_fn = loader.worker_init_fn
|
85 |
+
self.worker_queue_idx = 0
|
86 |
+
self.worker_result_queue = multiprocessing.Queue()
|
87 |
+
self.batches_outstanding = 0
|
88 |
+
self.worker_pids_set = False
|
89 |
+
self.shutdown = False
|
90 |
+
self.send_idx = 0
|
91 |
+
self.rcvd_idx = 0
|
92 |
+
self.reorder_dict = {}
|
93 |
+
self.done_event = multiprocessing.Event()
|
94 |
+
|
95 |
+
base_seed = torch.LongTensor(1).random_()[0]
|
96 |
+
|
97 |
+
self.index_queues = []
|
98 |
+
self.workers = []
|
99 |
+
for i in range(self.num_workers):
|
100 |
+
index_queue = multiprocessing.Queue()
|
101 |
+
index_queue.cancel_join_thread()
|
102 |
+
w = multiprocessing.Process(
|
103 |
+
target=_ms_loop,
|
104 |
+
args=(
|
105 |
+
self.dataset,
|
106 |
+
index_queue,
|
107 |
+
self.worker_result_queue,
|
108 |
+
self.done_event,
|
109 |
+
self.collate_fn,
|
110 |
+
self.scale,
|
111 |
+
base_seed + i,
|
112 |
+
self.worker_init_fn,
|
113 |
+
i
|
114 |
+
)
|
115 |
+
)
|
116 |
+
w.daemon = True
|
117 |
+
w.start()
|
118 |
+
self.index_queues.append(index_queue)
|
119 |
+
self.workers.append(w)
|
120 |
+
|
121 |
+
if self.pin_memory:
|
122 |
+
self.data_queue = queue.Queue()
|
123 |
+
pin_memory_thread = threading.Thread(
|
124 |
+
target=_utils.pin_memory._pin_memory_loop,
|
125 |
+
args=(
|
126 |
+
self.worker_result_queue,
|
127 |
+
self.data_queue,
|
128 |
+
torch.cuda.current_device(),
|
129 |
+
self.done_event
|
130 |
+
)
|
131 |
+
)
|
132 |
+
pin_memory_thread.daemon = True
|
133 |
+
pin_memory_thread.start()
|
134 |
+
self.pin_memory_thread = pin_memory_thread
|
135 |
+
else:
|
136 |
+
self.data_queue = self.worker_result_queue
|
137 |
+
|
138 |
+
_utils.signal_handling._set_worker_pids(
|
139 |
+
id(self), tuple(w.pid for w in self.workers)
|
140 |
+
)
|
141 |
+
_utils.signal_handling._set_SIGCHLD_handler()
|
142 |
+
self.worker_pids_set = True
|
143 |
+
|
144 |
+
for _ in range(2 * self.num_workers):
|
145 |
+
self._put_indices()
|
146 |
+
|
147 |
+
|
148 |
+
class MSDataLoader(DataLoader):
|
149 |
+
|
150 |
+
def __init__(self, cfg, *args, **kwargs):
|
151 |
+
super(MSDataLoader, self).__init__(
|
152 |
+
*args, **kwargs, num_workers=cfg.n_threads
|
153 |
+
)
|
154 |
+
self.scale = cfg.scale
|
155 |
+
|
156 |
+
def __iter__(self):
|
157 |
+
return _MSDataLoaderIter(self)
|
158 |
+
|
SR/code/demo.sb
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#Train x2
|
3 |
+
python main.py --n_GPUs 4 --rgb_range 1 --reset --save_models --lr 1e-4 --decay 200-400-600-800 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model PAEDSR --scale 2 --patch_size 96 --save EDSR_PA_x2 --data_train DIV2K
|
4 |
+
#Test
|
5 |
+
python main.py --model PAEDSR --data_test Set5+Set14+B100+Urban100 --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1 --pre_train ../model_x2.pt --test_only --chop
|
SR/code/lambda_networks/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from lambda_networks.lambda_networks import LambdaLayer
|
2 |
+
from lambda_networks.lambda_networks import Recursion
|
3 |
+
from lambda_networks.rlambda_networks import RLambdaLayer
|
4 |
+
λLayer = LambdaLayer
|
SR/code/lambda_networks/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (337 Bytes). View file
|
|
SR/code/lambda_networks/__pycache__/lambda_networks.cpython-37.pyc
ADDED
Binary file (4.78 kB). View file
|
|
SR/code/lambda_networks/__pycache__/rlambda_networks.cpython-37.pyc
ADDED
Binary file (2.94 kB). View file
|
|
SR/code/lambda_networks/lambda_networks.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, einsum
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
# helpers functions
|
7 |
+
|
8 |
+
def exists(val):
|
9 |
+
return val is not None
|
10 |
+
|
11 |
+
def default(val, d):
|
12 |
+
return val if exists(val) else d
|
13 |
+
|
14 |
+
# lambda layer
|
15 |
+
|
16 |
+
class LambdaLayer(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
dim,
|
20 |
+
*,
|
21 |
+
dim_k,
|
22 |
+
n = None,
|
23 |
+
r = None,
|
24 |
+
heads = 4,
|
25 |
+
dim_out = None,
|
26 |
+
dim_u = 1):
|
27 |
+
super().__init__()
|
28 |
+
dim_out = default(dim_out, dim)
|
29 |
+
self.u = dim_u # intra-depth dimension
|
30 |
+
self.heads = heads
|
31 |
+
|
32 |
+
assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
|
33 |
+
dim_v = dim_out // heads
|
34 |
+
|
35 |
+
self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
|
36 |
+
self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
|
37 |
+
self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)
|
38 |
+
|
39 |
+
self.norm_q = nn.BatchNorm2d(dim_k * heads)
|
40 |
+
self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
|
41 |
+
|
42 |
+
self.local_contexts = exists(r)
|
43 |
+
if exists(r):
|
44 |
+
assert (r % 2) == 1, 'Receptive kernel size should be odd'
|
45 |
+
self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
|
46 |
+
else:
|
47 |
+
assert exists(n), 'You must specify the total sequence length (h x w)'
|
48 |
+
self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
|
49 |
+
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
b, c, hh, ww, u, h = *x.shape, self.u, self.heads
|
53 |
+
|
54 |
+
q = self.to_q(x)
|
55 |
+
k = self.to_k(x)
|
56 |
+
v = self.to_v(x)
|
57 |
+
|
58 |
+
q = self.norm_q(q)
|
59 |
+
v = self.norm_v(v)
|
60 |
+
|
61 |
+
q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
|
62 |
+
k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
|
63 |
+
v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)
|
64 |
+
|
65 |
+
k = k.softmax(dim=-1)
|
66 |
+
|
67 |
+
λc = einsum('b u k m, b u v m -> b k v', k, v)
|
68 |
+
Yc = einsum('b h k n, b k v -> b h v n', q, λc)
|
69 |
+
|
70 |
+
if self.local_contexts:
|
71 |
+
v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
|
72 |
+
λp = self.pos_conv(v)
|
73 |
+
Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
|
74 |
+
else:
|
75 |
+
λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
|
76 |
+
Yp = einsum('b h k n, b n k v -> b h v n', q, λp)
|
77 |
+
|
78 |
+
Y = Yc + Yp
|
79 |
+
out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
|
80 |
+
return out
|
81 |
+
|
82 |
+
|
83 |
+
# i'm not sure whether this will work or not
|
84 |
+
class Recursion(nn.Module):
|
85 |
+
def __init__(self, N: int, hidden_dim:int=64):
|
86 |
+
super(Recursion,self).__init__()
|
87 |
+
self.N = N
|
88 |
+
self.lambdaNxN_identity = LambdaLayer(dim=hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
|
89 |
+
# merge upstream information here
|
90 |
+
self.lambdaNxN_merge = LambdaLayer(dim=2*hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
|
91 |
+
self.downscale_conv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=N, stride=N)
|
92 |
+
self.upscale_conv = nn.Conv2d(hidden_dim, hidden_dim * N * N, kernel_size=3,padding=1)
|
93 |
+
self.pixel_shuffle = nn.PixelShuffle(N)
|
94 |
+
|
95 |
+
def forward(self, x: torch.Tensor):
|
96 |
+
N = self.N
|
97 |
+
|
98 |
+
def to_patch(blocks:torch.Tensor)->torch.Tensor:
|
99 |
+
shape = blocks.shape
|
100 |
+
blocks_patch = F.unfold(blocks, kernel_size=N, stride=N)
|
101 |
+
blocks_patch = blocks_patch.view(shape[0], shape[1], N, N, -1)
|
102 |
+
num_patch = blocks_patch.shape[-1]
|
103 |
+
blocks_patch = blocks_patch.permute(0, 4, 1, 2, 3).reshape(-1, shape[1], N, N).contiguous()
|
104 |
+
return blocks_patch, num_patch
|
105 |
+
|
106 |
+
def combine_patch(processed_patch,shape,num_patch):
|
107 |
+
processed_patch = processed_patch.reshape(shape[0], num_patch, shape[1], N, N)
|
108 |
+
processed_patch=processed_patch.permute(0, 2, 3, 4, 1).reshape(shape[0],shape[1] * N * N,num_patch).contiguous()
|
109 |
+
processed=F.fold(processed_patch,output_size=(shape[-2],shape[-1]),kernel_size=N,stride=N)
|
110 |
+
return processed
|
111 |
+
|
112 |
+
def process(blocks:torch.Tensor)->torch.Tensor:
|
113 |
+
shape = blocks.shape
|
114 |
+
if blocks.shape[-1] == N:
|
115 |
+
processed = self.lambdaNxN_identity(blocks)
|
116 |
+
return processed
|
117 |
+
# to NxN patchs
|
118 |
+
blocks_patch,num_patch=to_patch(blocks)
|
119 |
+
# pass through identity
|
120 |
+
processed_patch = self.lambdaNxN_identity(blocks_patch)
|
121 |
+
# back to HxW
|
122 |
+
processed=combine_patch(processed_patch,shape,num_patch)
|
123 |
+
# get feedback
|
124 |
+
feedback = process(self.downscale_conv(processed))
|
125 |
+
# upscale feedback
|
126 |
+
upscale_feedback = self.upscale_conv(feedback)
|
127 |
+
upscale_feedback=self.pixel_shuffle(upscale_feedback)
|
128 |
+
# combine results
|
129 |
+
combined = torch.cat([processed, upscale_feedback], dim=1)
|
130 |
+
combined_shape=combined.shape
|
131 |
+
combined_patch,num_patch=to_patch(combined)
|
132 |
+
combined_patch_reduced = self.lambdaNxN_merge(combined_patch)
|
133 |
+
ret_shape=(combined_shape[0],combined_shape[1]//2,combined_shape[2],combined_shape[3])
|
134 |
+
ret=combine_patch(combined_patch_reduced,ret_shape,num_patch)
|
135 |
+
return ret
|
136 |
+
|
137 |
+
return process(x)
|
SR/code/lambda_networks/rlambda_networks.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, einsum
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
|
7 |
+
# helpers functions
|
8 |
+
|
9 |
+
def exists(val):
|
10 |
+
return val is not None
|
11 |
+
|
12 |
+
|
13 |
+
def default(val, d):
|
14 |
+
return val if exists(val) else d
|
15 |
+
|
16 |
+
|
17 |
+
# lambda layer
|
18 |
+
|
19 |
+
class RLambdaLayer(nn.Module):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
dim,
|
23 |
+
*,
|
24 |
+
dim_k,
|
25 |
+
n=None,
|
26 |
+
r=None,
|
27 |
+
heads=4,
|
28 |
+
dim_out=None,
|
29 |
+
dim_u=1,
|
30 |
+
recurrence=None
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
dim_out = default(dim_out, dim)
|
34 |
+
self.u = dim_u # intra-depth dimension
|
35 |
+
self.heads = heads
|
36 |
+
|
37 |
+
assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
|
38 |
+
dim_v = dim_out // heads
|
39 |
+
|
40 |
+
self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias=False)
|
41 |
+
self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias=False)
|
42 |
+
self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias=False)
|
43 |
+
|
44 |
+
self.norm_q = nn.BatchNorm2d(dim_k * heads)
|
45 |
+
self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
|
46 |
+
|
47 |
+
self.local_contexts = exists(r)
|
48 |
+
self.recurrence = recurrence
|
49 |
+
if exists(r):
|
50 |
+
assert (r % 2) == 1, 'Receptive kernel size should be odd'
|
51 |
+
self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding=(0, r // 2, r // 2))
|
52 |
+
else:
|
53 |
+
assert exists(n), 'You must specify the total sequence length (h x w)'
|
54 |
+
self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
|
55 |
+
|
56 |
+
def apply_lambda(self, lambda_c, lambda_p, x):
|
57 |
+
b, c, hh, ww, u, h = *x.shape, self.u, self.heads
|
58 |
+
q = self.to_q(x)
|
59 |
+
q = self.norm_q(q)
|
60 |
+
q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h=h)
|
61 |
+
Yc = einsum('b h k n, b k v -> b h v n', q, lambda_c)
|
62 |
+
if self.local_contexts:
|
63 |
+
Yp = einsum('b h k n, b k v n -> b h v n', q, lambda_p.flatten(3))
|
64 |
+
else:
|
65 |
+
Yp = einsum('b h k n, b n k v -> b h v n', q, lambda_p)
|
66 |
+
Y = Yc + Yp
|
67 |
+
out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh=hh, ww=ww)
|
68 |
+
return out
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
b, c, hh, ww, u, h = *x.shape, self.u, self.heads
|
72 |
+
|
73 |
+
k = self.to_k(x)
|
74 |
+
v = self.to_v(x)
|
75 |
+
|
76 |
+
v = self.norm_v(v)
|
77 |
+
|
78 |
+
k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u=u)
|
79 |
+
v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u=u)
|
80 |
+
|
81 |
+
k = k.softmax(dim=-1)
|
82 |
+
|
83 |
+
λc = einsum('b u k m, b u v m -> b k v', k, v)
|
84 |
+
|
85 |
+
if self.local_contexts:
|
86 |
+
v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)
|
87 |
+
λp = self.pos_conv(v)
|
88 |
+
else:
|
89 |
+
λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
|
90 |
+
out = x
|
91 |
+
for i in range(self.recurrence):
|
92 |
+
out = self.apply_lambda(λc, λp, out)
|
93 |
+
return out
|
SR/code/loss/__init__.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from importlib import import_module
|
3 |
+
|
4 |
+
import matplotlib
|
5 |
+
matplotlib.use('Agg')
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
def sequence_loss(sr, hr, loss_func, gamma=0.8, max_val=None):
|
15 |
+
""" Loss function defined over sequence of flow predictions """
|
16 |
+
|
17 |
+
n_recurrence = len(sr)
|
18 |
+
total_loss = 0.0
|
19 |
+
buffer=[0.0]*n_recurrence
|
20 |
+
# exlude invalid pixels and extremely large diplacements
|
21 |
+
for i in range(n_recurrence):
|
22 |
+
i_weight = gamma**(n_recurrence - i - 1)
|
23 |
+
i_loss = loss_func(sr[i],hr)
|
24 |
+
buffer[i]=i_loss.item()
|
25 |
+
# total_loss += i_weight * (valid[:, None] * i_loss).mean()
|
26 |
+
total_loss += i_weight * (i_loss)
|
27 |
+
return total_loss,buffer
|
28 |
+
|
29 |
+
class Loss(nn.modules.loss._Loss):
|
30 |
+
def __init__(self, args, ckp):
|
31 |
+
super(Loss, self).__init__()
|
32 |
+
print('Preparing loss function:')
|
33 |
+
self.buffer=[0.0]*args.recurrence
|
34 |
+
self.n_GPUs = args.n_GPUs
|
35 |
+
self.loss = []
|
36 |
+
self.loss_module = nn.ModuleList()
|
37 |
+
for loss in args.loss.split('+'):
|
38 |
+
weight, loss_type = loss.split('*')
|
39 |
+
if loss_type == 'MSE':
|
40 |
+
loss_function = nn.MSELoss()
|
41 |
+
elif loss_type == 'L1':
|
42 |
+
loss_function = nn.L1Loss()
|
43 |
+
elif loss_type.find('VGG') >= 0:
|
44 |
+
module = import_module('loss.vgg')
|
45 |
+
loss_function = getattr(module, 'VGG')(
|
46 |
+
loss_type[3:],
|
47 |
+
rgb_range=args.rgb_range
|
48 |
+
)
|
49 |
+
elif loss_type.find('GAN') >= 0:
|
50 |
+
module = import_module('loss.adversarial')
|
51 |
+
loss_function = getattr(module, 'Adversarial')(
|
52 |
+
args,
|
53 |
+
loss_type
|
54 |
+
)
|
55 |
+
|
56 |
+
self.loss.append({
|
57 |
+
'type': loss_type,
|
58 |
+
'weight': float(weight),
|
59 |
+
'function': loss_function}
|
60 |
+
)
|
61 |
+
if loss_type.find('GAN') >= 0:
|
62 |
+
self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
|
63 |
+
|
64 |
+
if len(self.loss) > 1:
|
65 |
+
self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
|
66 |
+
|
67 |
+
for l in self.loss:
|
68 |
+
if l['function'] is not None:
|
69 |
+
print('{:.3f} * {}'.format(l['weight'], l['type']))
|
70 |
+
# self.loss_module.append(l['function'])
|
71 |
+
|
72 |
+
self.log = torch.Tensor()
|
73 |
+
|
74 |
+
device = torch.device('cpu' if args.cpu else 'cuda')
|
75 |
+
self.loss_module.to(device)
|
76 |
+
if args.precision == 'half': self.loss_module.half()
|
77 |
+
if not args.cpu and args.n_GPUs > 1:
|
78 |
+
self.loss_module = nn.DataParallel(
|
79 |
+
self.loss_module, range(args.n_GPUs)
|
80 |
+
)
|
81 |
+
|
82 |
+
if args.load != '': self.load(ckp.dir, cpu=args.cpu)
|
83 |
+
|
84 |
+
def forward(self, sr, hr):
|
85 |
+
losses = []
|
86 |
+
for i, l in enumerate(self.loss):
|
87 |
+
if l['function'] is not None:
|
88 |
+
if isinstance(sr,list):
|
89 |
+
# weights=[0.32,0.08,0.02,0.01,0.005]
|
90 |
+
# weights=weights[::-1]
|
91 |
+
# weights=[0.01,0.02,0.08,0.32]
|
92 |
+
# self.buffer=[]
|
93 |
+
effective_loss,buffer_lst=sequence_loss(sr,hr,l['function'])
|
94 |
+
# for k in range(len(sr)):
|
95 |
+
# loss=l['function'](sr[k], hr)
|
96 |
+
# self.buffer.append(loss.item())
|
97 |
+
# effective_loss=loss*weights[k]*l['weight']
|
98 |
+
losses.append(effective_loss)
|
99 |
+
self.buffer=buffer_lst
|
100 |
+
self.log[-1, i] += effective_loss.item()
|
101 |
+
else:
|
102 |
+
loss = l['function'](sr, hr)
|
103 |
+
effective_loss = l['weight'] * loss
|
104 |
+
losses.append(effective_loss)
|
105 |
+
self.buffer[0]=effective_loss.item()
|
106 |
+
self.log[-1, i] += effective_loss.item()
|
107 |
+
elif l['type'] == 'DIS':
|
108 |
+
self.log[-1, i] += self.loss[i - 1]['function'].loss
|
109 |
+
|
110 |
+
loss_sum = sum(losses)
|
111 |
+
if len(self.loss) > 1:
|
112 |
+
self.log[-1, -1] += loss_sum.item()
|
113 |
+
|
114 |
+
return loss_sum
|
115 |
+
|
116 |
+
def step(self):
|
117 |
+
for l in self.get_loss_module():
|
118 |
+
if hasattr(l, 'scheduler'):
|
119 |
+
l.scheduler.step()
|
120 |
+
|
121 |
+
def start_log(self):
|
122 |
+
self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
|
123 |
+
|
124 |
+
def end_log(self, n_batches):
|
125 |
+
self.log[-1].div_(n_batches)
|
126 |
+
|
127 |
+
def display_loss(self, batch):
|
128 |
+
n_samples = batch + 1
|
129 |
+
log = []
|
130 |
+
for l, c in zip(self.loss, self.log[-1]):
|
131 |
+
log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
|
132 |
+
|
133 |
+
return ''.join(log)
|
134 |
+
|
135 |
+
def plot_loss(self, apath, epoch):
|
136 |
+
axis = np.linspace(1, epoch, epoch)
|
137 |
+
for i, l in enumerate(self.loss):
|
138 |
+
label = '{} Loss'.format(l['type'])
|
139 |
+
fig = plt.figure()
|
140 |
+
plt.title(label)
|
141 |
+
plt.plot(axis, self.log[:, i].numpy(), label=label)
|
142 |
+
plt.legend()
|
143 |
+
plt.xlabel('Epochs')
|
144 |
+
plt.ylabel('Loss')
|
145 |
+
plt.grid(True)
|
146 |
+
plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
|
147 |
+
plt.close(fig)
|
148 |
+
|
149 |
+
def get_loss_module(self):
|
150 |
+
if self.n_GPUs == 1:
|
151 |
+
return self.loss_module
|
152 |
+
else:
|
153 |
+
return self.loss_module.module
|
154 |
+
|
155 |
+
def save(self, apath):
|
156 |
+
torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
|
157 |
+
torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
|
158 |
+
|
159 |
+
def load(self, apath, cpu=False):
|
160 |
+
if cpu:
|
161 |
+
kwargs = {'map_location': lambda storage, loc: storage}
|
162 |
+
else:
|
163 |
+
kwargs = {}
|
164 |
+
|
165 |
+
self.load_state_dict(torch.load(
|
166 |
+
os.path.join(apath, 'loss.pt'),
|
167 |
+
**kwargs
|
168 |
+
))
|
169 |
+
self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
|
170 |
+
for l in self.get_loss_module():
|
171 |
+
if hasattr(l, 'scheduler'):
|
172 |
+
for _ in range(len(self.log)): l.scheduler.step()
|
173 |
+
|
SR/code/loss/__loss__.py
ADDED
File without changes
|