dblasko committed
Commit
3a273df
1 Parent(s): 16530c8

Add application files

app.py ADDED
@@ -0,0 +1,130 @@
+ import streamlit as st
+ import io
+ import os
+ import base64
+ import torch
+ import torchvision.transforms as T
+ import torchvision.utils as vutils
+ from PIL import Image
+ from huggingface_hub import hf_hub_download
+
+ from model.MIRNet.model import MIRNet
+
+
+ def run_model(input_image):
+     # Pick the best available device: CUDA, then Apple MPS, then CPU.
+     device = (
+         torch.device("cuda")
+         if torch.cuda.is_available()
+         else torch.device("mps")
+         if torch.backends.mps.is_available()
+         else torch.device("cpu")
+     )
+
+     model = MIRNet(num_features=64).to(device)
+     model_path = hf_hub_download(
+         repo_id="dblasko/mirnet-low-light-img-enhancement",
+         filename="mirnet_finetuned.pth",
+     )
+     model.load_state_dict(
+         torch.load(model_path, map_location=device)["model_state_dict"]
+     )
+
+     model.eval()
+     with torch.no_grad():
+         img = input_image
+         img_tensor = T.Compose(
+             [
+                 T.Resize(400),
+                 T.ToTensor(),
+                 T.Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
+             ]
+         )(img).unsqueeze(0)
+         img_tensor = img_tensor.to(device)
+
+         # Crop so both spatial dimensions are divisible by 8, as required by the network's downsampling stages.
+         if img_tensor.shape[2] % 8 != 0:
+             img_tensor = img_tensor[:, :, : -(img_tensor.shape[2] % 8), :]
+         if img_tensor.shape[3] % 8 != 0:
+             img_tensor = img_tensor[:, :, :, : -(img_tensor.shape[3] % 8)]
+
+         output = model(img_tensor)
+
+         vutils.save_image(output, "temp.png")
+         output_image = Image.open("temp.png")
+         os.remove("temp.png")
+         return output_image
+
+
+ def get_base64_font(font_path):
+     with open(font_path, "rb") as font_file:
+         return base64.b64encode(font_file.read()).decode()
+
+
+ st.set_page_config(layout="wide")
+
+ font_name = "Gloock"
+ gloock_b64 = get_base64_font("utils/assets/Gloock-Regular.ttf")
+ font_name_text = "Merriweather Sans"
+ merri_b64 = get_base64_font("utils/assets/MerriweatherSans-Regular.ttf")
+ hide_streamlit_style = f"""
+ <style>
+ #MainMenu {{visibility: hidden;}}
+ footer {{visibility: hidden;}}
+
+ @font-face {{
+     font-family: '{font_name}';
+     src: url(data:font/ttf;base64,{gloock_b64}) format('truetype');
+ }}
+ @font-face {{
+     font-family: '{font_name_text}';
+     src: url(data:font/ttf;base64,{merri_b64}) format('truetype');
+ }}
+ span {{
+     font-family: '{font_name_text}';
+ }}
+ .e1nzilvr1, .st-emotion-cache-10trblm {{
+     font-family: '{font_name}';
+     font-size: 65px;
+ }}
+ </style>
+ """
+ st.markdown(hide_streamlit_style, unsafe_allow_html=True)
+
+ st.title("Low-light event-image enhancement with MIRNet.")
+
+ # File uploader widget
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+ if uploaded_file is not None:
+     # Read the upload as bytes and decode it into an RGB PIL image.
+     bytes_data = uploaded_file.getvalue()
+     image = Image.open(io.BytesIO(bytes_data)).convert("RGB")
+
+     # Two columns: original image on the left, enhanced output on the right.
+     col1, col2 = st.columns(2)
+
+     with col1:
+         st.image(image, caption="Original Image", use_column_width="always")
+
+     # Run the model and display the result on demand.
+     if st.button("Enhance Image"):
+         with col2:
+             enhanced_image = run_model(image)
+             st.image(
+                 enhanced_image, caption="Enhanced Image", use_column_width="always"
+             )
+
+             # Offer the enhanced result as a JPEG download.
+             buf = io.BytesIO()
+             enhanced_image.save(buf, format="JPEG")
+             byte_im = buf.getvalue()
+             st.download_button(
+                 label="Download image",
+                 data=byte_im,
+                 file_name="enhanced_image.jpg",
+                 mime="image/jpeg",
+             )
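The app itself is launched with `streamlit run app.py`; for a quick smoke test outside the browser, `run_model` can also be called directly. An illustrative sketch (not part of the commit): importing app executes its top-level Streamlit code, so the fonts under utils/assets must be present, and dark_photo.jpg stands in for any low-light test image.

from PIL import Image
from app import run_model

enhanced = run_model(Image.open("dark_photo.jpg").convert("RGB"))  # hypothetical test image
enhanced.save("enhanced.jpg")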
model/MIRNet/ChannelAttention.py ADDED
@@ -0,0 +1,40 @@
+ import torch
+ import torch.nn as nn
+
+
+ class ChannelAttention(nn.Module):
+     """
+     Squeezes the input down to 1x1xC, applies the excitation operation and restores the C channels through a 1x1 convolution.
+
+     In: HxWxC
+     Out: HxWxC (original channels are restored by multiplying the output with the original input)
+     """
+
+     def __init__(self, in_channels, reduction_ratio=8, bias=True):
+         super().__init__()
+         self.squeezing = nn.AdaptiveAvgPool2d(1)
+         self.excitation = nn.Sequential(
+             nn.Conv2d(
+                 in_channels,
+                 in_channels // reduction_ratio,
+                 kernel_size=1,
+                 padding=0,
+                 bias=bias,
+             ),
+             nn.PReLU(),
+             nn.Conv2d(
+                 in_channels // reduction_ratio,
+                 in_channels,
+                 kernel_size=1,
+                 padding=0,
+                 bias=bias,
+             ),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, x):
+         squeezed_x = self.squeezing(x)  # 1x1xC
+         excitation = self.excitation(squeezed_x)  # 1x1xC (reduced to C/r internally)
+         return (
+             excitation * x
+         )  # HxWxC restored through the mult. with the original input
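A minimal shape check of the attention contract (illustrative, not part of the commit; assumes the repository root is on PYTHONPATH):

import torch
from model.MIRNet.ChannelAttention import ChannelAttention

x = torch.randn(1, 64, 32, 32)           # NxCxHxW
attn = ChannelAttention(in_channels=64)  # default reduction_ratio=8
assert attn(x).shape == x.shape          # channels are rescaled, the shape is unchanged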
model/MIRNet/ChannelCompression.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ import torch.nn as nn
+
+
+ class ChannelCompression(nn.Module):
+     """
+     Reduces the input to 2 channels by concatenating the channel-wise max and channel-wise mean maps.
+
+     In: HxWxC
+     Out: HxWx2
+     """
+
+     def forward(self, x):
+         return torch.cat(
+             (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
+         )
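Illustratively (not part of the commit), any C-channel input collapses to a 2-channel map:

import torch
from model.MIRNet.ChannelCompression import ChannelCompression

x = torch.randn(1, 64, 32, 32)
out = ChannelCompression()(x)  # channel-wise max and mean, concatenated
assert out.shape == (1, 2, 32, 32)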
model/MIRNet/Downsampling.py ADDED
@@ -0,0 +1,135 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as fun
+ import numpy as np
+
+
+ class DownsamplingBlock(nn.Module):
+     """
+     Downsamples the input to halve the spatial dimensions while doubling the channels, through two parallel conv + antialiased-downsampling branches.
+
+     In: HxWxC
+     Out: H/2 x W/2 x 2C
+     """
+
+     def __init__(self, in_channels, bias=False):
+         super().__init__()
+         # 1x1 conv + PReLU -> 3x3 conv + PReLU -> antialiased downsampling -> 1x1 conv
+         self.branch1 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels, in_channels, kernel_size=1, padding=0, bias=bias
+             ),
+             nn.PReLU(),
+             nn.Conv2d(
+                 in_channels, in_channels, kernel_size=3, padding=1, bias=bias
+             ),
+             nn.PReLU(),
+             DownSample(channels=in_channels, filter_size=3, stride=2),
+             nn.Conv2d(
+                 in_channels, in_channels * 2, kernel_size=1, padding=0, bias=bias
+             ),
+         )
+         self.branch2 = nn.Sequential(
+             DownSample(channels=in_channels, filter_size=3, stride=2),
+             nn.Conv2d(
+                 in_channels, in_channels * 2, kernel_size=1, padding=0, bias=bias
+             ),
+         )
+
+     def forward(self, x):
+         return self.branch1(x) + self.branch2(x)  # H/2 x W/2 x 2C
+
+
+ class DownsamplingModule(nn.Module):
+     """
+     Downsampling module of the network, composed of log2(scaling factor) DownsamplingBlocks.
+
+     In: HxWxC
+     Out: H/(scaling factor) x W/(scaling factor) x (scaling factor)*C
+     """
+
+     def __init__(self, in_channels, scaling_factor, stride=2):
+         super().__init__()
+         self.scaling_factor = int(np.log2(scaling_factor))
+
+         blocks = []
+         for i in range(self.scaling_factor):
+             blocks.append(DownsamplingBlock(in_channels))
+             in_channels = int(in_channels * stride)
+         self.blocks = nn.Sequential(*blocks)
+
+     def forward(self, x):
+         return self.blocks(x)  # H/(scaling factor) x W/(scaling factor) x (scaling factor)*C
+
+
+ class DownSample(nn.Module):
+     """
+     Antialiased downsampling module using the blur-pooling method.
+
+     From Adobe's implementation available here: https://github.com/yilundu/improved_contrastive_divergence/blob/master/downsample.py
+     """
+
+     def __init__(
+         self, pad_type="reflect", filter_size=3, stride=2, channels=None, pad_off=0
+     ):
+         super().__init__()
+         self.filter_size = filter_size
+         self.stride = stride
+         self.pad_off = pad_off
+         self.channels = channels
+         self.pad_sizes = [
+             int(1.0 * (filter_size - 1) / 2),
+             int(np.ceil(1.0 * (filter_size - 1) / 2)),
+             int(1.0 * (filter_size - 1) / 2),
+             int(np.ceil(1.0 * (filter_size - 1) / 2)),
+         ]
+
+         self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
+         self.off = int((self.stride - 1) / 2.0)
+
+         # Binomial filter coefficients (rows of Pascal's triangle) used as the blur kernel.
+         if self.filter_size == 1:
+             a = np.array([1.0])
+         elif self.filter_size == 2:
+             a = np.array([1.0, 1.0])
+         elif self.filter_size == 3:
+             a = np.array([1.0, 2.0, 1.0])
+         elif self.filter_size == 4:
+             a = np.array([1.0, 3.0, 3.0, 1.0])
+         elif self.filter_size == 5:
+             a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
+         elif self.filter_size == 6:
+             a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
+         elif self.filter_size == 7:
+             a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
+
+         filt = torch.Tensor(a[:, None] * a[None, :])
+         filt = filt / torch.sum(filt)
+         self.register_buffer(
+             "filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
+         )
+         self.pad = get_pad_layer(pad_type)(self.pad_sizes)
+
+     def forward(self, x):
+         if self.filter_size == 1:
+             if self.pad_off == 0:
+                 return x[:, :, :: self.stride, :: self.stride]
+             else:
+                 return self.pad(x)[:, :, :: self.stride, :: self.stride]
+         else:
+             # Blur with the depthwise binomial filter, then subsample by the stride.
+             return fun.conv2d(
+                 self.pad(x), self.filt, stride=self.stride, groups=x.shape[1]
+             )
+
+
+ def get_pad_layer(pad_type):
+     if pad_type == "reflect":
+         pad_layer = nn.ReflectionPad2d
+     elif pad_type == "replication":
+         pad_layer = nn.ReplicationPad2d
+     else:
+         raise ValueError(f"Pad type [{pad_type}] not recognized")
+     return pad_layer
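A sanity check of the module's shape contract (illustrative, not part of the commit):

import torch
from model.MIRNet.Downsampling import DownsamplingModule

x = torch.randn(1, 64, 64, 64)
down = DownsamplingModule(in_channels=64, scaling_factor=4)  # log2(4) = 2 blocks
assert down(x).shape == (1, 256, 16, 16)  # H and W divided by 4, channels multiplied by 4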
model/MIRNet/DualAttentionUnit.py ADDED
@@ -0,0 +1,39 @@
+ import torch
+ import torch.nn as nn
+
+ from model.MIRNet.ChannelAttention import ChannelAttention
+ from model.MIRNet.SpatialAttention import SpatialAttention
+
+
+ class DualAttentionUnit(nn.Module):
+     """
+     Combines the ChannelAttention and SpatialAttention modules.
+     (conv, PReLU, conv -> concatenate the SA & CA outputs -> conv -> skip connection from the input)
+
+     In: HxWxC
+     Out: HxWxC (original channels are restored through the additive skip connection from the input)
+     """
+
+     def __init__(self, in_channels, kernel_size=3, reduction_ratio=8, bias=False):
+         super().__init__()
+         self.initial_convs = nn.Sequential(
+             nn.Conv2d(in_channels, in_channels, kernel_size, padding=1, bias=bias),
+             nn.PReLU(),
+             nn.Conv2d(in_channels, in_channels, kernel_size, padding=1, bias=bias),
+         )
+         self.channel_attention = ChannelAttention(in_channels, reduction_ratio, bias)
+         self.spatial_attention = SpatialAttention()
+         self.final_conv = nn.Conv2d(
+             in_channels * 2, in_channels, kernel_size=1, bias=bias
+         )
+         self.in_channels = in_channels
+
+     def forward(self, x):
+         initial_convs = self.initial_convs(x)  # HxWxC
+         channel_attention = self.channel_attention(initial_convs)  # HxWxC
+         spatial_attention = self.spatial_attention(initial_convs)  # HxWxC
+         attention = torch.cat((spatial_attention, channel_attention), dim=1)  # HxWx2C
+         block_output = self.final_conv(
+             attention
+         )  # HxWxC - the 1x1 conv. restores the C channels for the skip connection
+         return x + block_output  # the addition is the skip connection from the input
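Since both attention paths and the skip connection preserve dimensions, the unit is shape-preserving (illustrative check, not part of the commit):

import torch
from model.MIRNet.DualAttentionUnit import DualAttentionUnit

x = torch.randn(1, 64, 32, 32)
dau = DualAttentionUnit(in_channels=64)
assert dau(x).shape == x.shape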
model/MIRNet/MultiScaleResidualBlock.py ADDED
@@ -0,0 +1,124 @@
+ import torch
+ import numpy as np
+ import torch.nn as nn
+
+ from model.MIRNet.Downsampling import DownsamplingModule
+ from model.MIRNet.DualAttentionUnit import DualAttentionUnit
+ from model.MIRNet.SelectiveKernelFeatureFusion import SelectiveKernelFeatureFusion
+ from model.MIRNet.Upsampling import UpsamplingModule
+
+
+ class MultiScaleResidualBlock(nn.Module):
+     """
+     Three parallel convolutional streams at different resolutions. Information is exchanged between the streams through residual connections.
+     """
+
+     def __init__(self, num_features, height, width, stride, bias):
+         super().__init__()
+         self.num_features = num_features
+         self.height = height
+         self.width = width
+         features = [int((stride**i) * num_features) for i in range(height)]
+         scale = [2**i for i in range(1, height)]
+
+         # Note: the list multiplication reuses the same DualAttentionUnit instance
+         # across the width dimension, i.e. its weights are shared within a stream.
+         self.dual_attention_units = nn.ModuleList(
+             [
+                 nn.ModuleList(
+                     [DualAttentionUnit(int(num_features * stride**i))] * width
+                 )
+                 for i in range(height)
+             ]
+         )
+         self.last_up = nn.ModuleDict()
+         for i in range(1, height):
+             self.last_up.update(
+                 {
+                     f"{i}": UpsamplingModule(
+                         in_channels=int(num_features * stride**i),
+                         scaling_factor=2**i,
+                         stride=stride,
+                     )
+                 }
+             )
+
+         self.down = nn.ModuleDict()
+         i = 0
+         scale.reverse()
+         for f in features:
+             for s in scale[i:]:
+                 self.down.update({f"{f}_{s}": DownsamplingModule(f, s, stride)})
+             i += 1
+
+         self.up = nn.ModuleDict()
+         i = 0
+         features.reverse()
+         for f in features:
+             for s in scale[i:]:
+                 self.up.update({f"{f}_{s}": UpsamplingModule(f, s, stride)})
+             i += 1
+
+         self.out_conv = nn.Conv2d(
+             num_features, num_features, kernel_size=3, padding=1, bias=bias
+         )
+         self.skff_blocks = nn.ModuleList(
+             [
+                 SelectiveKernelFeatureFusion(num_features * stride**i, height)
+                 for i in range(height)
+             ]
+         )
+
+     def forward(self, x):
+         inp = x.clone()
+         out = []
+
+         # First column: each stream applies its first dual attention unit,
+         # downsampling progressively to reach its resolution.
+         for j in range(self.height):
+             if j == 0:
+                 inp = self.dual_attention_units[j][0](inp)
+             else:
+                 inp = self.dual_attention_units[j][0](
+                     self.down[f"{inp.size(1)}_{2}"](inp)
+                 )
+             out.append(inp)
+
+         # Remaining columns: fuse all streams at each resolution with SKFF,
+         # then apply the next dual attention unit of each stream.
+         for i in range(1, self.width):
+             temp = []
+             for j in range(self.height):
+                 tensors = []
+                 for k in range(self.height):
+                     tensors.append(self.select_up_down(out[k], j, k))
+                 temp.append(self.skff_blocks[j](tensors))
+
+             for j in range(self.height):
+                 out[j] = self.dual_attention_units[j][i](temp[j])
+
+         # Upsample every stream back to full resolution, fuse, and add the block-level skip connection.
+         output = []
+         for k in range(self.height):
+             output.append(self.select_last_up(out[k], k))
+
+         output = self.skff_blocks[0](output)
+         output = self.out_conv(output)
+         return output + x
+
+     def select_up_down(self, tensor, j, k):
+         # Brings the stream-k tensor to stream j's resolution (identity, upsample or downsample).
+         if j == k:
+             return tensor
+         diff = 2 ** np.abs(j - k)
+         if j < k:
+             return self.up[f"{tensor.size(1)}_{diff}"](tensor)
+         else:
+             return self.down[f"{tensor.size(1)}_{diff}"](tensor)
+
+     def select_last_up(self, tensor, k):
+         # Stream 0 is already at full resolution; the others are upsampled by 2^k.
+         if k == 0:
+             return tensor
+         else:
+             return self.last_up[f"{k}"](tensor)
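With height=3 the block runs streams at full, half and quarter resolution, so H and W must be divisible by 4; the block itself is residual (illustrative check, not part of the commit):

import torch
from model.MIRNet.MultiScaleResidualBlock import MultiScaleResidualBlock

x = torch.randn(1, 64, 32, 32)  # 32 is divisible by 2^(height-1) = 4
msrb = MultiScaleResidualBlock(num_features=64, height=3, width=2, stride=2, bias=False)
assert msrb(x).shape == x.shape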
model/MIRNet/ResidualRecurrentGroup.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+ import torch.nn as nn
+
+ from model.MIRNet.MultiScaleResidualBlock import MultiScaleResidualBlock
+
+
+ class ResidualRecurrentGroup(nn.Module):
+     """
+     Group of multi-scale residual blocks followed by a convolutional layer. The output is added to the input for restoration.
+     """
+
+     def __init__(
+         self, num_features, number_msrb_blocks, height, width, stride, bias=False
+     ):
+         super().__init__()
+         blocks = [
+             MultiScaleResidualBlock(num_features, height, width, stride, bias)
+             for _ in range(number_msrb_blocks)
+         ]
+         blocks.append(
+             nn.Conv2d(
+                 num_features,
+                 num_features,
+                 kernel_size=3,
+                 padding=1,
+                 stride=1,
+                 bias=bias,
+             )
+         )
+         self.blocks = nn.Sequential(*blocks)
+
+     def forward(self, x):
+         output = self.blocks(x)
+         return x + output  # restored features, HxWxC
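The group stacks MSRBs and a closing convolution, then adds the result back to its input (illustrative check, not part of the commit):

import torch
from model.MIRNet.ResidualRecurrentGroup import ResidualRecurrentGroup

x = torch.randn(1, 64, 32, 32)
rrg = ResidualRecurrentGroup(
    num_features=64, number_msrb_blocks=2, height=3, width=2, stride=2
)
assert rrg(x).shape == x.shape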
model/MIRNet/SelectiveKernelFeatureFusion.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ import torch.nn as nn
+
+
+ class SelectiveKernelFeatureFusion(nn.Module):
+     """
+     Merges the outputs of the three different resolutions through self-attention.
+
+     All three inputs are summed -> global average pooling -> a 1x1 convolution downscales the channels -> three separate convs produce three descriptors,
+     and a softmax across the descriptors yields 3 attention activations used to recalibrate the three input feature maps.
+     """
+
+     def __init__(self, in_channels, reduction_ratio, bias=False):
+         super().__init__()
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         conv_out_channels = max(int(in_channels / reduction_ratio), 4)
+         self.convolution = nn.Sequential(
+             nn.Conv2d(
+                 in_channels, conv_out_channels, kernel_size=1, padding=0, bias=bias
+             ),
+             nn.PReLU(),
+         )
+
+         self.attention_convs = nn.ModuleList([])
+         for i in range(3):
+             self.attention_convs.append(
+                 nn.Conv2d(
+                     conv_out_channels, in_channels, kernel_size=1, stride=1, bias=bias
+                 )
+             )
+
+         self.softmax = nn.Softmax(dim=1)
+
+     def forward(self, x):
+         batch_size = x[0].shape[0]
+         n_features = x[0].shape[1]
+
+         x = torch.cat(
+             x, dim=1
+         )  # the three outputs of diff. resolutions are concatenated along the channel dimension
+         x = x.view(
+             batch_size, 3, n_features, x.shape[2], x.shape[3]
+         )  # batch_size x 3 x n_features x H x W
+
+         z = torch.sum(x, dim=1)  # batch_size x n_features x H x W
+         z = self.avg_pool(z)  # batch_size x n_features x 1 x 1
+         z = self.convolution(z)  # batch_size x conv_out_channels x 1 x 1
+
+         attention_activations = [
+             atn(z) for atn in self.attention_convs
+         ]  # 3 x (batch_size x n_features x 1 x 1)
+         attention_activations = torch.cat(
+             attention_activations, dim=1
+         )  # batch_size x 3*n_features x 1 x 1
+         attention_activations = attention_activations.view(
+             batch_size, 3, n_features, 1, 1
+         )  # batch_size x 3 x n_features x 1 x 1
+
+         attention_activations = self.softmax(
+             attention_activations
+         )  # batch_size x 3 x n_features x 1 x 1
+
+         return torch.sum(
+             x * attention_activations, dim=1
+         )  # batch_size x n_features x H x W (the three feature maps are recalibrated and summed)
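SKFF takes a list of three same-shaped feature maps and returns a single recalibrated map (illustrative check, not part of the commit):

import torch
from model.MIRNet.SelectiveKernelFeatureFusion import SelectiveKernelFeatureFusion

skff = SelectiveKernelFeatureFusion(in_channels=64, reduction_ratio=8)
streams = [torch.randn(1, 64, 32, 32) for _ in range(3)]
assert skff(streams).shape == (1, 64, 32, 32)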
model/MIRNet/SpatialAttention.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ import torch.nn as nn
+
+ from model.MIRNet.ChannelCompression import ChannelCompression
+
+
+ class SpatialAttention(nn.Module):
+     """
+     Reduces the input to 2 channels with the ChannelCompression module and applies a 2D convolution with 1 output channel.
+
+     In: HxWxC
+     Out: HxWxC (original channels are restored by multiplying the output with the original input)
+     """
+
+     def __init__(self):
+         super().__init__()
+         self.channel_compression = ChannelCompression()
+         self.conv = nn.Conv2d(2, 1, kernel_size=5, stride=1, padding=2)
+
+     def forward(self, x):
+         x_compressed = self.channel_compression(x)  # HxWx2
+         x_conv = self.conv(x_compressed)  # HxWx1
+         scaling_factor = torch.sigmoid(x_conv)
+         return x * scaling_factor  # HxWxC
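The single-channel sigmoid map broadcasts over all input channels, so shapes are preserved (illustrative check, not part of the commit):

import torch
from model.MIRNet.SpatialAttention import SpatialAttention

x = torch.randn(1, 64, 32, 32)
assert SpatialAttention()(x).shape == x.shape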
model/MIRNet/Upsampling.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+
+ class UpsamplingBlock(nn.Module):
+     """
+     Upsamples the input to double the spatial dimensions while halving the channels, through two parallel conv + bilinear-upsampling branches.
+
+     In: HxWxC
+     Out: 2Hx2WxC/2
+     """
+
+     def __init__(self, in_channels, bias=False):
+         super().__init__()
+         # 1x1 conv + PReLU -> 3x3 conv + PReLU -> bilinear upsampling -> 1x1 conv
+         self.branch1 = nn.Sequential(
+             nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=bias),
+             nn.PReLU(),
+             nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=bias),
+             nn.PReLU(),
+             nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+             nn.Conv2d(in_channels, in_channels // 2, kernel_size=1, padding=0, bias=bias),
+         )
+         self.branch2 = nn.Sequential(
+             nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+             nn.Conv2d(in_channels, in_channels // 2, kernel_size=1, padding=0, bias=bias),
+         )
+
+     def forward(self, x):
+         return self.branch1(x) + self.branch2(x)  # 2Hx2WxC/2
+
+
+ class UpsamplingModule(nn.Module):
+     """
+     Upsampling module of the network, composed of log2(scaling factor) UpsamplingBlocks.
+
+     In: HxWxC
+     Out: (scaling factor)H x (scaling factor)W x C/(scaling factor)
+     """
+
+     def __init__(self, in_channels, scaling_factor, stride=2):
+         super().__init__()
+         self.scaling_factor = int(np.log2(scaling_factor))
+
+         blocks = []
+         for i in range(self.scaling_factor):
+             blocks.append(UpsamplingBlock(in_channels))
+             in_channels = int(in_channels // 2)
+         self.blocks = nn.Sequential(*blocks)
+
+     def forward(self, x):
+         return self.blocks(x)  # (scaling factor)H x (scaling factor)W x C/(scaling factor)
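The mirror-image check of the downsampling module (illustrative, not part of the commit):

import torch
from model.MIRNet.Upsampling import UpsamplingModule

x = torch.randn(1, 256, 16, 16)
up = UpsamplingModule(in_channels=256, scaling_factor=4)  # log2(4) = 2 blocks
assert up(x).shape == (1, 64, 64, 64)  # H and W multiplied by 4, channels divided by 4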
model/MIRNet/__init__.py ADDED
File without changes
model/MIRNet/__pycache__/ChannelAttention.cpython-310.pyc ADDED
Binary file (1.28 kB)
model/MIRNet/__pycache__/ChannelCompression.cpython-310.pyc ADDED
Binary file (764 Bytes)
model/MIRNet/__pycache__/Downsampling.cpython-310.pyc ADDED
Binary file (4.24 kB)
model/MIRNet/__pycache__/DualAttentionUnit.cpython-310.pyc ADDED
Binary file (1.61 kB)
model/MIRNet/__pycache__/MultiScaleResidualBlock.cpython-310.pyc ADDED
Binary file (3.6 kB)
model/MIRNet/__pycache__/ResidualRecurrentGroup.cpython-310.pyc ADDED
Binary file (1.43 kB)
model/MIRNet/__pycache__/SelectiveKernelFeatureFusion.cpython-310.pyc ADDED
Binary file (2.07 kB)
model/MIRNet/__pycache__/SpatialAttention.cpython-310.pyc ADDED
Binary file (1.23 kB)
model/MIRNet/__pycache__/Upsampling.cpython-310.pyc ADDED
Binary file (2 kB)
model/MIRNet/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (141 Bytes)
model/MIRNet/__pycache__/model.cpython-310.pyc ADDED
Binary file (1.72 kB)
model/MIRNet/model.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ import torch.nn as nn
+
+ from model.MIRNet.ResidualRecurrentGroup import ResidualRecurrentGroup
+
+
+ class MIRNet(nn.Module):
+     """
+     Low-level features are extracted through convolution and passed to n residual recurrent groups that operate at different resolutions.
+     Their output is added to the input image for restoration.
+
+     Please refer to the documentation of the different blocks of the model in this folder for detailed explanations.
+     """
+
+     def __init__(
+         self,
+         in_channels=3,
+         out_channels=3,
+         num_features=64,
+         kernel_size=3,
+         stride=2,
+         number_msrb=2,
+         number_rrg=3,
+         height=3,
+         width=2,
+         bias=False,
+     ):
+         super().__init__()
+         self.conv_start = nn.Conv2d(
+             in_channels, num_features, kernel_size, padding=1, bias=bias
+         )
+         msrb_blocks = [
+             ResidualRecurrentGroup(
+                 num_features, number_msrb, height, width, stride, bias
+             )
+             for _ in range(number_rrg)
+         ]
+         self.msrb_blocks = nn.Sequential(*msrb_blocks)
+         self.conv_end = nn.Conv2d(
+             num_features, out_channels, kernel_size, padding=1, bias=bias
+         )
+
+     def forward(self, x):
+         output = self.conv_start(x)
+         output = self.msrb_blocks(output)
+         output = self.conv_end(output)
+         return x + output  # restored image, HxWxC
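End to end, the network returns a restored image with the input's shape. A minimal loading sketch (illustrative, not part of the commit), mirroring app.py's checkpoint handling; note that it downloads the weights from the Hub:

import torch
from huggingface_hub import hf_hub_download
from model.MIRNet.model import MIRNet

model = MIRNet(num_features=64)
path = hf_hub_download(
    repo_id="dblasko/mirnet-low-light-img-enhancement",
    filename="mirnet_finetuned.pth",
)
model.load_state_dict(torch.load(path, map_location="cpu")["model_state_dict"])
model.eval()

x = torch.randn(1, 3, 400, 600)  # H and W divisible by 8, as enforced in app.py
with torch.no_grad():
    assert model(x).shape == x.shape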
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ huggingface_hub
+ datasets
+ torch
+ torchvision
+ streamlit