juaben commited on
Commit
eb72f83
·
verified ·
1 Parent(s): 95da5be

Model uploaded

Browse files
Files changed (1) hide show
  1. model/flol.py +128 -0
model/flol.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import functools
5
+ import kornia
6
+
7
+ from utils.utils import *
8
+
9
+
10
+ class FLOL(nn.Module):
11
+ def __init__(self, nf=64):
12
+ super(FLOL, self).__init__()
13
+
14
+ # AMPLITUDE ENHANCEMENT
15
+ self.AmpNet = nn.Sequential(
16
+ AmplitudeNet_skip(8),
17
+ nn.Sigmoid()
18
+ )
19
+
20
+ self.nf = nf
21
+ ResidualBlock_noBN_f = functools.partial(ResidualBlock_noBN, nf=nf)
22
+
23
+ self.conv_first_1 = nn.Conv2d(3 * 2, nf, 3, 1, 1, bias=True)
24
+ self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
25
+ self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
26
+
27
+ self.feature_extraction = make_layer(ResidualBlock_noBN_f, 1)
28
+ self.recon_trunk = make_layer(ResidualBlock_noBN_f, 1)
29
+
30
+ self.upconv1 = nn.Conv2d(nf*2, nf * 4, 3, 1, 1, bias=True)
31
+ self.upconv2 = nn.Conv2d(nf*2, nf * 4, 3, 1, 1, bias=True)
32
+ self.pixel_shuffle = nn.PixelShuffle(2)
33
+ self.HRconv = nn.Conv2d(nf*2, nf, 3, 1, 1, bias=True)
34
+ self.conv_last = nn.Conv2d(nf, 3, 3, 1, 1, bias=True)
35
+
36
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
37
+ self.transformer = SFNet(nf, n = 4)
38
+ self.recon_trunk_light = make_layer(ResidualBlock_noBN_f, 6)
39
+
40
+ def get_mask(self,dark): # SNR map
41
+
42
+ light = kornia.filters.gaussian_blur2d(dark, (5, 5), (1.5, 1.5))
43
+ dark = dark[:, 0:1, :, :] * 0.299 + dark[:, 1:2, :, :] * 0.587 + dark[:, 2:3, :, :] * 0.114
44
+ light = light[:, 0:1, :, :] * 0.299 + light[:, 1:2, :, :] * 0.587 + light[:, 2:3, :, :] * 0.114
45
+ noise = torch.abs(dark - light)
46
+
47
+ mask = torch.div(light, noise + 0.0001)
48
+
49
+ batch_size = mask.shape[0]
50
+ height = mask.shape[2]
51
+ width = mask.shape[3]
52
+ mask_max = torch.max(mask.view(batch_size, -1), dim=1)[0]
53
+ mask_max = mask_max.view(batch_size, 1, 1, 1)
54
+ mask_max = mask_max.repeat(1, 1, height, width)
55
+ mask = mask * 1.0 / (mask_max + 0.0001)
56
+
57
+ mask = torch.clamp(mask, min=0, max=1.0)
58
+ return mask.float()
59
+
60
+ def forward(self, x, side=False):
61
+
62
+ # AMPLITUDE ENHANCEMENT
63
+ #--------------------------------------------------------Frequency Stage---------------------------------------------------
64
+ _, _, H, W = x.shape
65
+ image_fft = torch.fft.fft2(x, norm='backward')
66
+ mag_image = torch.abs(image_fft)
67
+ pha_image = torch.angle(image_fft)
68
+
69
+ curve_amps = self.AmpNet(x)
70
+
71
+ mag_image = mag_image / (curve_amps + 0.00000001) # * d4
72
+ real_image_enhanced = mag_image * torch.cos(pha_image)
73
+ imag_image_enhanced = mag_image * torch.sin(pha_image)
74
+ img_amp_enhanced = torch.fft.ifft2(torch.complex(real_image_enhanced, imag_image_enhanced), s=(H, W),
75
+ norm='backward').real
76
+
77
+ x_center = img_amp_enhanced
78
+
79
+ rate = 2 ** 3
80
+ pad_h = (rate - H % rate) % rate
81
+ pad_w = (rate - W % rate) % rate
82
+ if pad_h != 0 or pad_w != 0:
83
+ x_center = F.pad(x_center, (0, pad_w, 0, pad_h), "reflect")
84
+ x = F.pad(x, (0, pad_w, 0, pad_h), "reflect")
85
+
86
+ #------------------------------------------Spatial Stage---------------------------------------------------------------------
87
+
88
+ L1_fea_1 = self.lrelu(self.conv_first_1(torch.cat((x_center,x),dim=1)))
89
+ L1_fea_2 = self.lrelu(self.conv_first_2(L1_fea_1)) # Encoder
90
+ L1_fea_3 = self.lrelu(self.conv_first_3(L1_fea_2))
91
+
92
+ fea = self.feature_extraction(L1_fea_3)
93
+ fea_light = self.recon_trunk_light(fea)
94
+
95
+ h_feature = fea.shape[2]
96
+ w_feature = fea.shape[3]
97
+ mask_image = self.get_mask(x_center) # SNR Map
98
+ mask = F.interpolate(mask_image, size=[h_feature, w_feature], mode='nearest') # Resize and Normalize SNR map
99
+
100
+ fea_unfold = self.transformer(fea)
101
+
102
+ channel = fea.shape[1]
103
+ mask = mask.repeat(1, channel, 1, 1)
104
+ fea = fea_unfold * (1 - mask) + fea_light * mask # SNR-based Interaction
105
+
106
+ out_noise = self.recon_trunk(fea)
107
+ out_noise = torch.cat([out_noise, L1_fea_3], dim=1)
108
+ out_noise = self.lrelu(self.pixel_shuffle(self.upconv1(out_noise)))
109
+ out_noise = torch.cat([out_noise, L1_fea_2], dim=1) # Decoder
110
+ out_noise = self.lrelu(self.pixel_shuffle(self.upconv2(out_noise)))
111
+ out_noise = torch.cat([out_noise, L1_fea_1], dim=1)
112
+ out_noise = self.lrelu(self.HRconv(out_noise))
113
+ out_noise = self.conv_last(out_noise)
114
+ out_noise = out_noise + x
115
+ out_noise = out_noise[:, :, :H, :W]
116
+
117
+ if side:
118
+ return out_noise, x_center #, mag_image, x_center, mask_image
119
+ else:
120
+ return out_noise
121
+
122
+
123
+ ##############################################################################
124
+
125
+ def create_model():
126
+
127
+ net = FLOL(nf=16)
128
+ return net