juaben commited on
Commit
95da5be
·
verified ·
1 Parent(s): e32151a

Delete archs

Browse files
Files changed (3) hide show
  1. archs/NAFBlock.py +0 -176
  2. archs/arch_util.py +0 -111
  3. archs/model.py +0 -181
archs/NAFBlock.py DELETED
@@ -1,176 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- # Modules from model
6
- try:
7
- from archs.arch_util import LayerNorm2d
8
- import archs.arch_util as arch_util
9
- except:
10
- from arch_util import LayerNorm2d
11
- import arch_util as arch_util
12
-
13
- # Process Block 4 en SFNet y 5 bloques en AmpNet, con el spatial block aplicado en AmpNet (frequency stage)
14
- # tal y como lo tienen ellos en su github (aunque en el paper es al revés) y no lo aplican el space stage
15
-
16
-
17
- class SimpleGate(nn.Module):
18
- def forward(self, x):
19
- x1, x2 = x.chunk(2, dim=1)
20
- return x1 * x2
21
-
22
- class SpaBlock(nn.Module):
23
- def __init__(self, nc, DW_Expand = 2, FFN_Expand=2, drop_out_rate=0.):
24
- super(SpaBlock, self).__init__()
25
- dw_channel = nc * DW_Expand
26
- self.conv1 = nn.Conv2d(in_channels=nc, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
27
- self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
28
- bias=True) # the dconv
29
- self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=nc, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
30
-
31
- # Simplified Channel Attention
32
- self.sca = nn.Sequential(
33
- nn.AdaptiveAvgPool2d(1),
34
- nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
35
- groups=1, bias=True),
36
- )
37
-
38
- # SimpleGate
39
- self.sg = SimpleGate()
40
-
41
- ffn_channel = FFN_Expand * nc
42
- self.conv4 = nn.Conv2d(in_channels=nc, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
43
- self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=nc, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
44
-
45
- self.norm1 = LayerNorm2d(nc)
46
- self.norm2 = LayerNorm2d(nc)
47
-
48
- self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
49
- self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
50
-
51
- self.beta = nn.Parameter(torch.zeros((1, nc, 1, 1)), requires_grad=True)
52
- self.gamma = nn.Parameter(torch.zeros((1, nc, 1, 1)), requires_grad=True)
53
-
54
- def forward(self, x):
55
-
56
- x = self.norm1(x) # size [B, C, H, W]
57
-
58
- x = self.conv1(x) # size [B, 2*C, H, W]
59
- x = self.conv2(x) # size [B, 2*C, H, W]
60
- x = self.sg(x) # size [B, C, H, W]
61
- x = x * self.sca(x) # size [B, C, H, W]
62
- x = self.conv3(x) # size [B, C, H, W]
63
-
64
- x = self.dropout1(x)
65
-
66
- y = x + x * self.beta # size [B, C, H, W]
67
-
68
- x = self.conv4(self.norm2(y)) # size [B, 2*C, H, W]
69
- x = self.sg(x) # size [B, C, H, W]
70
- x = self.conv5(x) # size [B, C, H, W]
71
-
72
- x = self.dropout2(x)
73
-
74
- return y + x * self.gamma
75
-
76
- class FreBlock(nn.Module):
77
- def __init__(self, nc):
78
- super(FreBlock, self).__init__()
79
- self.fpre = nn.Conv2d(nc, nc, 1, 1, 0)
80
- self.process1 = nn.Sequential(
81
- nn.Conv2d(nc, nc, 1, 1, 0),
82
- nn.LeakyReLU(0.1, inplace=True),
83
- nn.Conv2d(nc, nc, 1, 1, 0))
84
- self.process2 = nn.Sequential(
85
- nn.Conv2d(nc, nc, 1, 1, 0),
86
- nn.LeakyReLU(0.1, inplace=True),
87
- nn.Conv2d(nc, nc, 1, 1, 0))
88
-
89
- def forward(self, x):
90
- _, _, H, W = x.shape
91
- x_freq = torch.fft.rfft2(self.fpre(x), norm='backward')
92
- mag = torch.abs(x_freq)
93
- pha = torch.angle(x_freq)
94
- mag = self.process1(mag)
95
- pha = self.process2(pha)
96
- real = mag * torch.cos(pha)
97
- imag = mag * torch.sin(pha)
98
- x_out = torch.complex(real, imag)
99
- x_out = torch.fft.irfft2(x_out, s=(H, W), norm='backward')
100
-
101
- return x_out+x
102
-
103
- class ProcessBlock(nn.Module):
104
- def __init__(self, in_nc, spatial = True):
105
- super(ProcessBlock,self).__init__()
106
- self.spatial = spatial
107
- self.spatial_process = SpaBlock(in_nc) if spatial else nn.Identity()
108
- self.frequency_process = FreBlock(in_nc)
109
- self.cat = nn.Conv2d(2*in_nc,in_nc,1,1,0) if spatial else nn.Conv2d(in_nc,in_nc,1,1,0)
110
-
111
- def forward(self, x):
112
- xori = x
113
- x_freq = self.frequency_process(x)
114
- x_spatial = self.spatial_process(x)
115
- xcat = torch.cat([x_spatial,x_freq],1)
116
- x_out = self.cat(xcat) if self.spatial else self.cat(x_freq)
117
-
118
- return x_out+xori
119
-
120
- class SFNet(nn.Module):
121
-
122
- def __init__(self, nc,n=5):
123
- super(SFNet,self).__init__()
124
-
125
- self.list_block = list()
126
- for index in range(n):
127
-
128
- self.list_block.append(ProcessBlock(nc,spatial=False))
129
-
130
- self.block = nn.Sequential(*self.list_block)
131
-
132
- def forward(self, x):
133
-
134
- x_ori = x
135
- x_out = self.block(x_ori)
136
- xout = x_ori + x_out
137
-
138
- return xout
139
-
140
- class AmplitudeNet_skip(nn.Module):
141
- def __init__(self, nc,n=1):
142
- super(AmplitudeNet_skip,self).__init__()
143
-
144
- self.conv1 = nn.Sequential(
145
- nn.Conv2d(3, nc, 1, 1, 0),
146
- ProcessBlock(nc),
147
- )
148
- self.conv2 = ProcessBlock(nc)
149
- self.conv3 = ProcessBlock(nc)
150
- self.conv4 = nn.Sequential(
151
- ProcessBlock(nc * 2),
152
- nn.Conv2d(nc * 2, nc, 1, 1, 0),
153
- )
154
-
155
- self.conv5 = nn.Sequential(
156
- ProcessBlock(nc * 2),
157
- nn.Conv2d(nc * 2, nc, 1, 1, 0),
158
- )
159
-
160
- self.convout = nn.Sequential(
161
- ProcessBlock(nc * 2),
162
- nn.Conv2d(nc * 2, 3, 1, 1, 0),
163
- )
164
-
165
- def forward(self, x):
166
-
167
- x1 = self.conv1(x)
168
- x2 = self.conv2(x1)
169
- x3 = self.conv3(x2)
170
- x4 = self.conv5(torch.cat((x2, x3), dim=1))
171
- xout = self.convout(torch.cat((x1, x4), dim=1))
172
-
173
- return xout
174
-
175
-
176
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
archs/arch_util.py DELETED
@@ -1,111 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.init as init
4
- import torch.nn.functional as F
5
-
6
-
7
- def initialize_weights(net_l, scale=1):
8
- if not isinstance(net_l, list):
9
- net_l = [net_l]
10
- for net in net_l:
11
- for m in net.modules():
12
- if isinstance(m, nn.Conv2d):
13
- init.kaiming_normal_(m.weight, a=0, mode='fan_in')
14
- m.weight.data *= scale # for residual block
15
- if m.bias is not None:
16
- m.bias.data.zero_()
17
- elif isinstance(m, nn.Linear):
18
- init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19
- m.weight.data *= scale
20
- if m.bias is not None:
21
- m.bias.data.zero_()
22
- elif isinstance(m, nn.BatchNorm2d):
23
- init.constant_(m.weight, 1)
24
- init.constant_(m.bias.data, 0.0)
25
-
26
-
27
- def make_layer(block, n_layers):
28
- layers = []
29
- for _ in range(n_layers):
30
- layers.append(block())
31
- return nn.Sequential(*layers)
32
-
33
-
34
- class ResidualBlock_noBN(nn.Module):
35
- '''Residual block w/o BN
36
- ---Conv-ReLU-Conv-+-
37
- |________________|
38
- '''
39
-
40
- def __init__(self, nf=64):
41
- super(ResidualBlock_noBN, self).__init__()
42
- self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
43
- self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
44
-
45
- # initialization
46
- initialize_weights([self.conv1, self.conv2], 0.1)
47
-
48
- def forward(self, x):
49
- identity = x
50
- out = F.relu(self.conv1(x), inplace=True)
51
- out = self.conv2(out)
52
- return identity + out
53
-
54
- class ResidualBlock(nn.Module):
55
- '''Residual block w/o BN
56
- ---Conv-ReLU-Conv-+-
57
- |________________|
58
- '''
59
-
60
- def __init__(self, nf=64):
61
- super(ResidualBlock, self).__init__()
62
- self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
63
- self.bn = nn.BatchNorm2d(nf)
64
- self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
65
-
66
- # initialization
67
- initialize_weights([self.conv1, self.conv2], 0.1)
68
-
69
- def forward(self, x):
70
- identity = x
71
- out = F.relu(self.bn(self.conv1(x)), inplace=True)
72
- out = self.conv2(out)
73
- return identity + out
74
-
75
- class LayerNormFunction(torch.autograd.Function):
76
-
77
- @staticmethod
78
- def forward(ctx, x, weight, bias, eps):
79
- ctx.eps = eps
80
- N, C, H, W = x.size()
81
- mu = x.mean(1, keepdim=True)
82
- var = (x - mu).pow(2).mean(1, keepdim=True)
83
- y = (x - mu) / (var + eps).sqrt()
84
- ctx.save_for_backward(y, var, weight)
85
- y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
86
- return y
87
-
88
- @staticmethod
89
- def backward(ctx, grad_output):
90
- eps = ctx.eps
91
-
92
- N, C, H, W = grad_output.size()
93
- y, var, weight = ctx.saved_variables
94
- g = grad_output * weight.view(1, C, 1, 1)
95
- mean_g = g.mean(dim=1, keepdim=True)
96
-
97
- mean_gy = (g * y).mean(dim=1, keepdim=True)
98
- gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
99
- return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
100
- dim=0), None
101
-
102
- class LayerNorm2d(nn.Module):
103
-
104
- def __init__(self, channels, eps=1e-6):
105
- super(LayerNorm2d, self).__init__()
106
- self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
107
- self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
108
- self.eps = eps
109
-
110
- def forward(self, x):
111
- return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
archs/model.py DELETED
@@ -1,181 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import functools
5
- # import arch_util as arch_util
6
- # from NAFBlock import *
7
- import kornia
8
- import torch.nn.functional as F
9
- import torchvision.models
10
-
11
- try:
12
- import archs.arch_util as arch_util
13
- from archs.NAFBlock import *
14
-
15
- except:
16
- import arch_util as arch_util
17
- from NAFBlock import *
18
- class VGG19(torch.nn.Module):
19
-
20
- def __init__(self, requires_grad=False):
21
- super().__init__()
22
- vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
23
- self.slice1 = torch.nn.Sequential()
24
- self.slice2 = torch.nn.Sequential()
25
- self.slice3 = torch.nn.Sequential()
26
- self.slice4 = torch.nn.Sequential()
27
- self.slice5 = torch.nn.Sequential()
28
- for x in range(2):
29
- self.slice1.add_module(str(x), vgg_pretrained_features[x])
30
- for x in range(2, 7):
31
- self.slice2.add_module(str(x), vgg_pretrained_features[x])
32
- for x in range(7, 12):
33
- self.slice3.add_module(str(x), vgg_pretrained_features[x])
34
- for x in range(12, 21):
35
- self.slice4.add_module(str(x), vgg_pretrained_features[x])
36
- for x in range(21, 30):
37
- self.slice5.add_module(str(x), vgg_pretrained_features[x])
38
- if not requires_grad:
39
- for param in self.parameters():
40
- param.requires_grad = False
41
-
42
- def forward(self, X):
43
- h_relu1 = self.slice1(X)
44
- h_relu2 = self.slice2(h_relu1)
45
- h_relu3 = self.slice3(h_relu2)
46
- h_relu4 = self.slice4(h_relu3)
47
- h_relu5 = self.slice5(h_relu4)
48
- out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
49
- return out
50
-
51
- class VGGLoss(nn.Module):
52
-
53
- def __init__(self):
54
-
55
- super(VGGLoss, self).__init__()
56
- self.vgg = VGG19().cuda()
57
- # self.criterion = nn.L1Loss()
58
- self.criterion = nn.L1Loss(reduction='sum')
59
- self.criterion2 = nn.L1Loss()
60
- self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
61
-
62
- def forward(self, x, y):
63
-
64
- x_vgg, y_vgg = self.vgg(x), self.vgg(y)
65
- # print(x_vgg.shape, x_vgg.dtype, torch.max(x_vgg), torch.min(x_vgg), y_vgg.shape, y_vgg.dtype, torch.max(y_vgg), torch.min(y_vgg))
66
- loss = 0
67
- for i in range(len(x_vgg)):
68
- # print(x_vgg[i].shape, y_vgg[i].shape, 'hey')
69
- loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
70
- # print(loss, i, 'hey')
71
-
72
- return loss
73
-
74
-
75
- class FourNet(nn.Module):
76
- def __init__(self, nf=64):
77
- super(FourNet, self).__init__()
78
-
79
- # AMPLITUDE ENHANCEMENT
80
- self.AmpNet = nn.Sequential(
81
- AmplitudeNet_skip(8),
82
- nn.Sigmoid()
83
- )
84
-
85
- self.nf = nf
86
- ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
87
-
88
- self.conv_first_1 = nn.Conv2d(3 * 2, nf, 3, 1, 1, bias=True)
89
- self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
90
- self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
91
-
92
- self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, 1)
93
- self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, 1)
94
-
95
- self.upconv1 = nn.Conv2d(nf*2, nf * 4, 3, 1, 1, bias=True)
96
- self.upconv2 = nn.Conv2d(nf*2, nf * 4, 3, 1, 1, bias=True)
97
- self.pixel_shuffle = nn.PixelShuffle(2)
98
- self.HRconv = nn.Conv2d(nf*2, nf, 3, 1, 1, bias=True)
99
- self.conv_last = nn.Conv2d(nf, 3, 3, 1, 1, bias=True)
100
-
101
- self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
102
- self.transformer = SFNet(nf, n = 4)
103
- self.recon_trunk_light = arch_util.make_layer(ResidualBlock_noBN_f, 6)
104
-
105
- def get_mask(self,dark): # SNR map
106
-
107
- light = kornia.filters.gaussian_blur2d(dark, (5, 5), (1.5, 1.5))
108
- dark = dark[:, 0:1, :, :] * 0.299 + dark[:, 1:2, :, :] * 0.587 + dark[:, 2:3, :, :] * 0.114
109
- light = light[:, 0:1, :, :] * 0.299 + light[:, 1:2, :, :] * 0.587 + light[:, 2:3, :, :] * 0.114
110
- noise = torch.abs(dark - light)
111
-
112
- mask = torch.div(light, noise + 0.0001)
113
-
114
- batch_size = mask.shape[0]
115
- height = mask.shape[2]
116
- width = mask.shape[3]
117
- mask_max = torch.max(mask.view(batch_size, -1), dim=1)[0]
118
- mask_max = mask_max.view(batch_size, 1, 1, 1)
119
- mask_max = mask_max.repeat(1, 1, height, width)
120
- mask = mask * 1.0 / (mask_max + 0.0001)
121
-
122
- mask = torch.clamp(mask, min=0, max=1.0)
123
- return mask.float()
124
-
125
- def forward(self, x):
126
-
127
- # AMPLITUDE ENHANCEMENT
128
- #--------------------------------------------------------Frequency Stage---------------------------------------------------
129
- _, _, H, W = x.shape
130
- image_fft = torch.fft.fft2(x, norm='backward')
131
- mag_image = torch.abs(image_fft)
132
- pha_image = torch.angle(image_fft)
133
- curve_amps = self.AmpNet(x)
134
- mag_image = mag_image / (curve_amps + 0.00000001) # * d4
135
- real_image_enhanced = mag_image * torch.cos(pha_image)
136
- imag_image_enhanced = mag_image * torch.sin(pha_image)
137
- img_amp_enhanced = torch.fft.ifft2(torch.complex(real_image_enhanced, imag_image_enhanced), s=(H, W),
138
- norm='backward').real
139
-
140
- x_center = img_amp_enhanced
141
-
142
- rate = 2 ** 3
143
- pad_h = (rate - H % rate) % rate
144
- pad_w = (rate - W % rate) % rate
145
- if pad_h != 0 or pad_w != 0:
146
- x_center = F.pad(x_center, (0, pad_w, 0, pad_h), "reflect")
147
- x = F.pad(x, (0, pad_w, 0, pad_h), "reflect")
148
-
149
- #------------------------------------------Spatial Stage---------------------------------------------------------------------
150
-
151
- L1_fea_1 = self.lrelu(self.conv_first_1(torch.cat((x_center,x),dim=1)))
152
- L1_fea_2 = self.lrelu(self.conv_first_2(L1_fea_1)) # Encoder
153
- L1_fea_3 = self.lrelu(self.conv_first_3(L1_fea_2))
154
-
155
- fea = self.feature_extraction(L1_fea_3)
156
- fea_light = self.recon_trunk_light(fea)
157
-
158
- h_feature = fea.shape[2]
159
- w_feature = fea.shape[3]
160
- mask_image = self.get_mask(x_center) # SNR Map
161
- mask = F.interpolate(mask_image, size=[h_feature, w_feature], mode='nearest') # Resize and Normalize SNR map
162
-
163
- fea_unfold = self.transformer(fea)
164
-
165
- channel = fea.shape[1]
166
- mask = mask.repeat(1, channel, 1, 1)
167
- fea = fea_unfold * (1 - mask) + fea_light * mask # SNR-based Interaction
168
-
169
- out_noise = self.recon_trunk(fea)
170
- out_noise = torch.cat([out_noise, L1_fea_3], dim=1)
171
- out_noise = self.lrelu(self.pixel_shuffle(self.upconv1(out_noise)))
172
- out_noise = torch.cat([out_noise, L1_fea_2], dim=1) # Decoder
173
- out_noise = self.lrelu(self.pixel_shuffle(self.upconv2(out_noise)))
174
- out_noise = torch.cat([out_noise, L1_fea_1], dim=1)
175
- out_noise = self.lrelu(self.HRconv(out_noise))
176
- out_noise = self.conv_last(out_noise)
177
- out_noise = out_noise + x
178
- out_noise = out_noise[:, :, :H, :W]
179
-
180
-
181
- return out_noise, mag_image, x_center, mask_image