Delete archs
Browse files- archs/NAFBlock.py +0 -176
- archs/arch_util.py +0 -111
- archs/model.py +0 -181
archs/NAFBlock.py
DELETED
@@ -1,176 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
import torch.nn.functional as F
|
4 |
-
|
5 |
-
# Modules from model
|
6 |
-
try:
|
7 |
-
from archs.arch_util import LayerNorm2d
|
8 |
-
import archs.arch_util as arch_util
|
9 |
-
except:
|
10 |
-
from arch_util import LayerNorm2d
|
11 |
-
import arch_util as arch_util
|
12 |
-
|
13 |
-
# Process Block 4 en SFNet y 5 bloques en AmpNet, con el spatial block aplicado en AmpNet (frequency stage)
|
14 |
-
# tal y como lo tienen ellos en su github (aunque en el paper es al revés) y no lo aplican el space stage
|
15 |
-
|
16 |
-
|
17 |
-
class SimpleGate(nn.Module):
|
18 |
-
def forward(self, x):
|
19 |
-
x1, x2 = x.chunk(2, dim=1)
|
20 |
-
return x1 * x2
|
21 |
-
|
22 |
-
class SpaBlock(nn.Module):
|
23 |
-
def __init__(self, nc, DW_Expand = 2, FFN_Expand=2, drop_out_rate=0.):
|
24 |
-
super(SpaBlock, self).__init__()
|
25 |
-
dw_channel = nc * DW_Expand
|
26 |
-
self.conv1 = nn.Conv2d(in_channels=nc, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
27 |
-
self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
|
28 |
-
bias=True) # the dconv
|
29 |
-
self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=nc, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
30 |
-
|
31 |
-
# Simplified Channel Attention
|
32 |
-
self.sca = nn.Sequential(
|
33 |
-
nn.AdaptiveAvgPool2d(1),
|
34 |
-
nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
|
35 |
-
groups=1, bias=True),
|
36 |
-
)
|
37 |
-
|
38 |
-
# SimpleGate
|
39 |
-
self.sg = SimpleGate()
|
40 |
-
|
41 |
-
ffn_channel = FFN_Expand * nc
|
42 |
-
self.conv4 = nn.Conv2d(in_channels=nc, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
43 |
-
self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=nc, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
|
44 |
-
|
45 |
-
self.norm1 = LayerNorm2d(nc)
|
46 |
-
self.norm2 = LayerNorm2d(nc)
|
47 |
-
|
48 |
-
self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
|
49 |
-
self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
|
50 |
-
|
51 |
-
self.beta = nn.Parameter(torch.zeros((1, nc, 1, 1)), requires_grad=True)
|
52 |
-
self.gamma = nn.Parameter(torch.zeros((1, nc, 1, 1)), requires_grad=True)
|
53 |
-
|
54 |
-
def forward(self, x):
|
55 |
-
|
56 |
-
x = self.norm1(x) # size [B, C, H, W]
|
57 |
-
|
58 |
-
x = self.conv1(x) # size [B, 2*C, H, W]
|
59 |
-
x = self.conv2(x) # size [B, 2*C, H, W]
|
60 |
-
x = self.sg(x) # size [B, C, H, W]
|
61 |
-
x = x * self.sca(x) # size [B, C, H, W]
|
62 |
-
x = self.conv3(x) # size [B, C, H, W]
|
63 |
-
|
64 |
-
x = self.dropout1(x)
|
65 |
-
|
66 |
-
y = x + x * self.beta # size [B, C, H, W]
|
67 |
-
|
68 |
-
x = self.conv4(self.norm2(y)) # size [B, 2*C, H, W]
|
69 |
-
x = self.sg(x) # size [B, C, H, W]
|
70 |
-
x = self.conv5(x) # size [B, C, H, W]
|
71 |
-
|
72 |
-
x = self.dropout2(x)
|
73 |
-
|
74 |
-
return y + x * self.gamma
|
75 |
-
|
76 |
-
class FreBlock(nn.Module):
|
77 |
-
def __init__(self, nc):
|
78 |
-
super(FreBlock, self).__init__()
|
79 |
-
self.fpre = nn.Conv2d(nc, nc, 1, 1, 0)
|
80 |
-
self.process1 = nn.Sequential(
|
81 |
-
nn.Conv2d(nc, nc, 1, 1, 0),
|
82 |
-
nn.LeakyReLU(0.1, inplace=True),
|
83 |
-
nn.Conv2d(nc, nc, 1, 1, 0))
|
84 |
-
self.process2 = nn.Sequential(
|
85 |
-
nn.Conv2d(nc, nc, 1, 1, 0),
|
86 |
-
nn.LeakyReLU(0.1, inplace=True),
|
87 |
-
nn.Conv2d(nc, nc, 1, 1, 0))
|
88 |
-
|
89 |
-
def forward(self, x):
|
90 |
-
_, _, H, W = x.shape
|
91 |
-
x_freq = torch.fft.rfft2(self.fpre(x), norm='backward')
|
92 |
-
mag = torch.abs(x_freq)
|
93 |
-
pha = torch.angle(x_freq)
|
94 |
-
mag = self.process1(mag)
|
95 |
-
pha = self.process2(pha)
|
96 |
-
real = mag * torch.cos(pha)
|
97 |
-
imag = mag * torch.sin(pha)
|
98 |
-
x_out = torch.complex(real, imag)
|
99 |
-
x_out = torch.fft.irfft2(x_out, s=(H, W), norm='backward')
|
100 |
-
|
101 |
-
return x_out+x
|
102 |
-
|
103 |
-
class ProcessBlock(nn.Module):
|
104 |
-
def __init__(self, in_nc, spatial = True):
|
105 |
-
super(ProcessBlock,self).__init__()
|
106 |
-
self.spatial = spatial
|
107 |
-
self.spatial_process = SpaBlock(in_nc) if spatial else nn.Identity()
|
108 |
-
self.frequency_process = FreBlock(in_nc)
|
109 |
-
self.cat = nn.Conv2d(2*in_nc,in_nc,1,1,0) if spatial else nn.Conv2d(in_nc,in_nc,1,1,0)
|
110 |
-
|
111 |
-
def forward(self, x):
|
112 |
-
xori = x
|
113 |
-
x_freq = self.frequency_process(x)
|
114 |
-
x_spatial = self.spatial_process(x)
|
115 |
-
xcat = torch.cat([x_spatial,x_freq],1)
|
116 |
-
x_out = self.cat(xcat) if self.spatial else self.cat(x_freq)
|
117 |
-
|
118 |
-
return x_out+xori
|
119 |
-
|
120 |
-
class SFNet(nn.Module):
|
121 |
-
|
122 |
-
def __init__(self, nc,n=5):
|
123 |
-
super(SFNet,self).__init__()
|
124 |
-
|
125 |
-
self.list_block = list()
|
126 |
-
for index in range(n):
|
127 |
-
|
128 |
-
self.list_block.append(ProcessBlock(nc,spatial=False))
|
129 |
-
|
130 |
-
self.block = nn.Sequential(*self.list_block)
|
131 |
-
|
132 |
-
def forward(self, x):
|
133 |
-
|
134 |
-
x_ori = x
|
135 |
-
x_out = self.block(x_ori)
|
136 |
-
xout = x_ori + x_out
|
137 |
-
|
138 |
-
return xout
|
139 |
-
|
140 |
-
class AmplitudeNet_skip(nn.Module):
|
141 |
-
def __init__(self, nc,n=1):
|
142 |
-
super(AmplitudeNet_skip,self).__init__()
|
143 |
-
|
144 |
-
self.conv1 = nn.Sequential(
|
145 |
-
nn.Conv2d(3, nc, 1, 1, 0),
|
146 |
-
ProcessBlock(nc),
|
147 |
-
)
|
148 |
-
self.conv2 = ProcessBlock(nc)
|
149 |
-
self.conv3 = ProcessBlock(nc)
|
150 |
-
self.conv4 = nn.Sequential(
|
151 |
-
ProcessBlock(nc * 2),
|
152 |
-
nn.Conv2d(nc * 2, nc, 1, 1, 0),
|
153 |
-
)
|
154 |
-
|
155 |
-
self.conv5 = nn.Sequential(
|
156 |
-
ProcessBlock(nc * 2),
|
157 |
-
nn.Conv2d(nc * 2, nc, 1, 1, 0),
|
158 |
-
)
|
159 |
-
|
160 |
-
self.convout = nn.Sequential(
|
161 |
-
ProcessBlock(nc * 2),
|
162 |
-
nn.Conv2d(nc * 2, 3, 1, 1, 0),
|
163 |
-
)
|
164 |
-
|
165 |
-
def forward(self, x):
|
166 |
-
|
167 |
-
x1 = self.conv1(x)
|
168 |
-
x2 = self.conv2(x1)
|
169 |
-
x3 = self.conv3(x2)
|
170 |
-
x4 = self.conv5(torch.cat((x2, x3), dim=1))
|
171 |
-
xout = self.convout(torch.cat((x1, x4), dim=1))
|
172 |
-
|
173 |
-
return xout
|
174 |
-
|
175 |
-
|
176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
archs/arch_util.py
DELETED
@@ -1,111 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
import torch.nn.init as init
|
4 |
-
import torch.nn.functional as F
|
5 |
-
|
6 |
-
|
7 |
-
def initialize_weights(net_l, scale=1):
|
8 |
-
if not isinstance(net_l, list):
|
9 |
-
net_l = [net_l]
|
10 |
-
for net in net_l:
|
11 |
-
for m in net.modules():
|
12 |
-
if isinstance(m, nn.Conv2d):
|
13 |
-
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
14 |
-
m.weight.data *= scale # for residual block
|
15 |
-
if m.bias is not None:
|
16 |
-
m.bias.data.zero_()
|
17 |
-
elif isinstance(m, nn.Linear):
|
18 |
-
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
19 |
-
m.weight.data *= scale
|
20 |
-
if m.bias is not None:
|
21 |
-
m.bias.data.zero_()
|
22 |
-
elif isinstance(m, nn.BatchNorm2d):
|
23 |
-
init.constant_(m.weight, 1)
|
24 |
-
init.constant_(m.bias.data, 0.0)
|
25 |
-
|
26 |
-
|
27 |
-
def make_layer(block, n_layers):
|
28 |
-
layers = []
|
29 |
-
for _ in range(n_layers):
|
30 |
-
layers.append(block())
|
31 |
-
return nn.Sequential(*layers)
|
32 |
-
|
33 |
-
|
34 |
-
class ResidualBlock_noBN(nn.Module):
|
35 |
-
'''Residual block w/o BN
|
36 |
-
---Conv-ReLU-Conv-+-
|
37 |
-
|________________|
|
38 |
-
'''
|
39 |
-
|
40 |
-
def __init__(self, nf=64):
|
41 |
-
super(ResidualBlock_noBN, self).__init__()
|
42 |
-
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
43 |
-
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
44 |
-
|
45 |
-
# initialization
|
46 |
-
initialize_weights([self.conv1, self.conv2], 0.1)
|
47 |
-
|
48 |
-
def forward(self, x):
|
49 |
-
identity = x
|
50 |
-
out = F.relu(self.conv1(x), inplace=True)
|
51 |
-
out = self.conv2(out)
|
52 |
-
return identity + out
|
53 |
-
|
54 |
-
class ResidualBlock(nn.Module):
|
55 |
-
'''Residual block w/o BN
|
56 |
-
---Conv-ReLU-Conv-+-
|
57 |
-
|________________|
|
58 |
-
'''
|
59 |
-
|
60 |
-
def __init__(self, nf=64):
|
61 |
-
super(ResidualBlock, self).__init__()
|
62 |
-
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
63 |
-
self.bn = nn.BatchNorm2d(nf)
|
64 |
-
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
65 |
-
|
66 |
-
# initialization
|
67 |
-
initialize_weights([self.conv1, self.conv2], 0.1)
|
68 |
-
|
69 |
-
def forward(self, x):
|
70 |
-
identity = x
|
71 |
-
out = F.relu(self.bn(self.conv1(x)), inplace=True)
|
72 |
-
out = self.conv2(out)
|
73 |
-
return identity + out
|
74 |
-
|
75 |
-
class LayerNormFunction(torch.autograd.Function):
|
76 |
-
|
77 |
-
@staticmethod
|
78 |
-
def forward(ctx, x, weight, bias, eps):
|
79 |
-
ctx.eps = eps
|
80 |
-
N, C, H, W = x.size()
|
81 |
-
mu = x.mean(1, keepdim=True)
|
82 |
-
var = (x - mu).pow(2).mean(1, keepdim=True)
|
83 |
-
y = (x - mu) / (var + eps).sqrt()
|
84 |
-
ctx.save_for_backward(y, var, weight)
|
85 |
-
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
|
86 |
-
return y
|
87 |
-
|
88 |
-
@staticmethod
|
89 |
-
def backward(ctx, grad_output):
|
90 |
-
eps = ctx.eps
|
91 |
-
|
92 |
-
N, C, H, W = grad_output.size()
|
93 |
-
y, var, weight = ctx.saved_variables
|
94 |
-
g = grad_output * weight.view(1, C, 1, 1)
|
95 |
-
mean_g = g.mean(dim=1, keepdim=True)
|
96 |
-
|
97 |
-
mean_gy = (g * y).mean(dim=1, keepdim=True)
|
98 |
-
gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
|
99 |
-
return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
|
100 |
-
dim=0), None
|
101 |
-
|
102 |
-
class LayerNorm2d(nn.Module):
|
103 |
-
|
104 |
-
def __init__(self, channels, eps=1e-6):
|
105 |
-
super(LayerNorm2d, self).__init__()
|
106 |
-
self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
|
107 |
-
self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
|
108 |
-
self.eps = eps
|
109 |
-
|
110 |
-
def forward(self, x):
|
111 |
-
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
archs/model.py
DELETED
@@ -1,181 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
import torch.nn.functional as F
|
4 |
-
import functools
|
5 |
-
# import arch_util as arch_util
|
6 |
-
# from NAFBlock import *
|
7 |
-
import kornia
|
8 |
-
import torch.nn.functional as F
|
9 |
-
import torchvision.models
|
10 |
-
|
11 |
-
try:
|
12 |
-
import archs.arch_util as arch_util
|
13 |
-
from archs.NAFBlock import *
|
14 |
-
|
15 |
-
except:
|
16 |
-
import arch_util as arch_util
|
17 |
-
from NAFBlock import *
|
18 |
-
class VGG19(torch.nn.Module):
|
19 |
-
|
20 |
-
def __init__(self, requires_grad=False):
|
21 |
-
super().__init__()
|
22 |
-
vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
|
23 |
-
self.slice1 = torch.nn.Sequential()
|
24 |
-
self.slice2 = torch.nn.Sequential()
|
25 |
-
self.slice3 = torch.nn.Sequential()
|
26 |
-
self.slice4 = torch.nn.Sequential()
|
27 |
-
self.slice5 = torch.nn.Sequential()
|
28 |
-
for x in range(2):
|
29 |
-
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
30 |
-
for x in range(2, 7):
|
31 |
-
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
32 |
-
for x in range(7, 12):
|
33 |
-
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
34 |
-
for x in range(12, 21):
|
35 |
-
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
36 |
-
for x in range(21, 30):
|
37 |
-
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
38 |
-
if not requires_grad:
|
39 |
-
for param in self.parameters():
|
40 |
-
param.requires_grad = False
|
41 |
-
|
42 |
-
def forward(self, X):
|
43 |
-
h_relu1 = self.slice1(X)
|
44 |
-
h_relu2 = self.slice2(h_relu1)
|
45 |
-
h_relu3 = self.slice3(h_relu2)
|
46 |
-
h_relu4 = self.slice4(h_relu3)
|
47 |
-
h_relu5 = self.slice5(h_relu4)
|
48 |
-
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
49 |
-
return out
|
50 |
-
|
51 |
-
class VGGLoss(nn.Module):
|
52 |
-
|
53 |
-
def __init__(self):
|
54 |
-
|
55 |
-
super(VGGLoss, self).__init__()
|
56 |
-
self.vgg = VGG19().cuda()
|
57 |
-
# self.criterion = nn.L1Loss()
|
58 |
-
self.criterion = nn.L1Loss(reduction='sum')
|
59 |
-
self.criterion2 = nn.L1Loss()
|
60 |
-
self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
|
61 |
-
|
62 |
-
def forward(self, x, y):
|
63 |
-
|
64 |
-
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
|
65 |
-
# print(x_vgg.shape, x_vgg.dtype, torch.max(x_vgg), torch.min(x_vgg), y_vgg.shape, y_vgg.dtype, torch.max(y_vgg), torch.min(y_vgg))
|
66 |
-
loss = 0
|
67 |
-
for i in range(len(x_vgg)):
|
68 |
-
# print(x_vgg[i].shape, y_vgg[i].shape, 'hey')
|
69 |
-
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
|
70 |
-
# print(loss, i, 'hey')
|
71 |
-
|
72 |
-
return loss
|
73 |
-
|
74 |
-
|
75 |
-
class FourNet(nn.Module):
|
76 |
-
def __init__(self, nf=64):
|
77 |
-
super(FourNet, self).__init__()
|
78 |
-
|
79 |
-
# AMPLITUDE ENHANCEMENT
|
80 |
-
self.AmpNet = nn.Sequential(
|
81 |
-
AmplitudeNet_skip(8),
|
82 |
-
nn.Sigmoid()
|
83 |
-
)
|
84 |
-
|
85 |
-
self.nf = nf
|
86 |
-
ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
|
87 |
-
|
88 |
-
self.conv_first_1 = nn.Conv2d(3 * 2, nf, 3, 1, 1, bias=True)
|
89 |
-
self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
90 |
-
self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
|
91 |
-
|
92 |
-
self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, 1)
|
93 |
-
self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, 1)
|
94 |
-
|
95 |
-
self.upconv1 = nn.Conv2d(nf*2, nf * 4, 3, 1, 1, bias=True)
|
96 |
-
self.upconv2 = nn.Conv2d(nf*2, nf * 4, 3, 1, 1, bias=True)
|
97 |
-
self.pixel_shuffle = nn.PixelShuffle(2)
|
98 |
-
self.HRconv = nn.Conv2d(nf*2, nf, 3, 1, 1, bias=True)
|
99 |
-
self.conv_last = nn.Conv2d(nf, 3, 3, 1, 1, bias=True)
|
100 |
-
|
101 |
-
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
102 |
-
self.transformer = SFNet(nf, n = 4)
|
103 |
-
self.recon_trunk_light = arch_util.make_layer(ResidualBlock_noBN_f, 6)
|
104 |
-
|
105 |
-
def get_mask(self,dark): # SNR map
|
106 |
-
|
107 |
-
light = kornia.filters.gaussian_blur2d(dark, (5, 5), (1.5, 1.5))
|
108 |
-
dark = dark[:, 0:1, :, :] * 0.299 + dark[:, 1:2, :, :] * 0.587 + dark[:, 2:3, :, :] * 0.114
|
109 |
-
light = light[:, 0:1, :, :] * 0.299 + light[:, 1:2, :, :] * 0.587 + light[:, 2:3, :, :] * 0.114
|
110 |
-
noise = torch.abs(dark - light)
|
111 |
-
|
112 |
-
mask = torch.div(light, noise + 0.0001)
|
113 |
-
|
114 |
-
batch_size = mask.shape[0]
|
115 |
-
height = mask.shape[2]
|
116 |
-
width = mask.shape[3]
|
117 |
-
mask_max = torch.max(mask.view(batch_size, -1), dim=1)[0]
|
118 |
-
mask_max = mask_max.view(batch_size, 1, 1, 1)
|
119 |
-
mask_max = mask_max.repeat(1, 1, height, width)
|
120 |
-
mask = mask * 1.0 / (mask_max + 0.0001)
|
121 |
-
|
122 |
-
mask = torch.clamp(mask, min=0, max=1.0)
|
123 |
-
return mask.float()
|
124 |
-
|
125 |
-
def forward(self, x):
|
126 |
-
|
127 |
-
# AMPLITUDE ENHANCEMENT
|
128 |
-
#--------------------------------------------------------Frequency Stage---------------------------------------------------
|
129 |
-
_, _, H, W = x.shape
|
130 |
-
image_fft = torch.fft.fft2(x, norm='backward')
|
131 |
-
mag_image = torch.abs(image_fft)
|
132 |
-
pha_image = torch.angle(image_fft)
|
133 |
-
curve_amps = self.AmpNet(x)
|
134 |
-
mag_image = mag_image / (curve_amps + 0.00000001) # * d4
|
135 |
-
real_image_enhanced = mag_image * torch.cos(pha_image)
|
136 |
-
imag_image_enhanced = mag_image * torch.sin(pha_image)
|
137 |
-
img_amp_enhanced = torch.fft.ifft2(torch.complex(real_image_enhanced, imag_image_enhanced), s=(H, W),
|
138 |
-
norm='backward').real
|
139 |
-
|
140 |
-
x_center = img_amp_enhanced
|
141 |
-
|
142 |
-
rate = 2 ** 3
|
143 |
-
pad_h = (rate - H % rate) % rate
|
144 |
-
pad_w = (rate - W % rate) % rate
|
145 |
-
if pad_h != 0 or pad_w != 0:
|
146 |
-
x_center = F.pad(x_center, (0, pad_w, 0, pad_h), "reflect")
|
147 |
-
x = F.pad(x, (0, pad_w, 0, pad_h), "reflect")
|
148 |
-
|
149 |
-
#------------------------------------------Spatial Stage---------------------------------------------------------------------
|
150 |
-
|
151 |
-
L1_fea_1 = self.lrelu(self.conv_first_1(torch.cat((x_center,x),dim=1)))
|
152 |
-
L1_fea_2 = self.lrelu(self.conv_first_2(L1_fea_1)) # Encoder
|
153 |
-
L1_fea_3 = self.lrelu(self.conv_first_3(L1_fea_2))
|
154 |
-
|
155 |
-
fea = self.feature_extraction(L1_fea_3)
|
156 |
-
fea_light = self.recon_trunk_light(fea)
|
157 |
-
|
158 |
-
h_feature = fea.shape[2]
|
159 |
-
w_feature = fea.shape[3]
|
160 |
-
mask_image = self.get_mask(x_center) # SNR Map
|
161 |
-
mask = F.interpolate(mask_image, size=[h_feature, w_feature], mode='nearest') # Resize and Normalize SNR map
|
162 |
-
|
163 |
-
fea_unfold = self.transformer(fea)
|
164 |
-
|
165 |
-
channel = fea.shape[1]
|
166 |
-
mask = mask.repeat(1, channel, 1, 1)
|
167 |
-
fea = fea_unfold * (1 - mask) + fea_light * mask # SNR-based Interaction
|
168 |
-
|
169 |
-
out_noise = self.recon_trunk(fea)
|
170 |
-
out_noise = torch.cat([out_noise, L1_fea_3], dim=1)
|
171 |
-
out_noise = self.lrelu(self.pixel_shuffle(self.upconv1(out_noise)))
|
172 |
-
out_noise = torch.cat([out_noise, L1_fea_2], dim=1) # Decoder
|
173 |
-
out_noise = self.lrelu(self.pixel_shuffle(self.upconv2(out_noise)))
|
174 |
-
out_noise = torch.cat([out_noise, L1_fea_1], dim=1)
|
175 |
-
out_noise = self.lrelu(self.HRconv(out_noise))
|
176 |
-
out_noise = self.conv_last(out_noise)
|
177 |
-
out_noise = out_noise + x
|
178 |
-
out_noise = out_noise[:, :, :H, :W]
|
179 |
-
|
180 |
-
|
181 |
-
return out_noise, mag_image, x_center, mask_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|