danifei committed on
Commit
d960e2d
·
verified ·
1 Parent(s): d4baede

add archs folder

Browse files
archs/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .nafnet_utils.arch_model import NAFNet
2
+ from .network import Network
3
+
4
+ __all__ = ['NAFNet','Network']
archs/arch_util.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.init as init
4
+ import torch.nn.functional as F
5
+
6
+ try:
7
+ from .nafnet_utils.arch_util import LayerNorm2d
8
+ from .nafnet_utils.arch_model import SimpleGate
9
+ except:
10
+ from nafnet_utils.arch_util import LayerNorm2d
11
+ from nafnet_utils.arch_model import SimpleGate
12
+
13
+ '''
14
+ https://github.com/wangchx67/FourLLIE.git
15
+ '''
16
+
17
def initialize_weights(net_l, scale=1):
    """Kaiming-initialise the Conv2d/Linear layers of one or more networks.

    Conv2d and Linear weights are drawn with ``kaiming_normal_`` (fan-in mode)
    and then multiplied by ``scale``; their biases are zeroed. BatchNorm2d
    layers get weight 1 and bias 0.

    Args:
        net_l: A single ``nn.Module`` or a list/tuple of modules.
        scale: Multiplier applied to the freshly initialised weights
            (the residual block in this file passes 0.1).
    """
    if not isinstance(net_l, (list, tuple)):
        net_l = [net_l]
    with torch.no_grad():
        for net in net_l:
            for m in net.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                    m.weight.mul_(scale)  # for residual block
                    if m.bias is not None:
                        m.bias.zero_()
                elif isinstance(m, nn.BatchNorm2d):
                    # With affine=False both parameters are None; skip them
                    # instead of crashing inside init.constant_.
                    if m.weight is not None:
                        init.constant_(m.weight, 1)
                    if m.bias is not None:
                        init.constant_(m.bias, 0.0)
35
+
36
+
37
def make_layer(block, n_layers):
    """Stack ``n_layers`` fresh instances of ``block`` into an ``nn.Sequential``.

    Args:
        block: Zero-argument callable returning a new ``nn.Module``.
        n_layers: Number of instances to create.

    Returns:
        nn.Sequential: The instances, in creation order.
    """
    return nn.Sequential(*[block() for _ in range(n_layers)])
42
+
43
+
44
class ResidualBlock_noBN(nn.Module):
    '''Residual block without batch normalisation.

    ---Conv-ReLU-Conv-+-
     |________________|

    Args:
        nf: Channel count of the input, hidden and output feature maps.
    '''

    def __init__(self, nf=64):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True)
        self.conv1 = nn.Conv2d(nf, nf, **conv_kwargs)
        self.conv2 = nn.Conv2d(nf, nf, **conv_kwargs)

        # Both convolutions are re-initialised with weights scaled by 0.1.
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        residual = self.conv2(F.relu(self.conv1(x), inplace=True))
        return x + residual
63
+
64
class SpaBlock(nn.Module):
    """Spatial residual block: two (3x3 conv -> LeakyReLU(0.1)) stages plus a skip.

    Args:
        nc: Channel count, identical for input and output.
    """

    def __init__(self, nc):
        super().__init__()
        stages = []
        for _ in range(2):
            stages.append(nn.Conv2d(nc, nc, kernel_size=3, stride=1, padding=1))
            stages.append(nn.LeakyReLU(0.1, inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        processed = self.block(x)
        return x + processed
75
+
76
class FreBlock(nn.Module):
    """Frequency-domain residual block.

    The input goes through a 1x1 convolution and a real 2-D FFT. The magnitude
    and phase of the spectrum are each processed by their own
    1x1 conv -> LeakyReLU(0.1) -> 1x1 conv stack, recombined into a complex
    spectrum, transformed back to the input's spatial size and added to the
    input.

    Args:
        nc: Channel count, identical for input and output.
    """

    def __init__(self, nc):
        super().__init__()
        self.fpre = nn.Conv2d(nc, nc, 1, 1, 0)
        self.process1 = self._pointwise_stack(nc)  # magnitude path
        self.process2 = self._pointwise_stack(nc)  # phase path

    @staticmethod
    def _pointwise_stack(nc):
        """Build one 1x1 conv -> LeakyReLU(0.1) -> 1x1 conv stack."""
        return nn.Sequential(
            nn.Conv2d(nc, nc, 1, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(nc, nc, 1, 1, 0),
        )

    def forward(self, x):
        _, _, height, width = x.shape
        spectrum = torch.fft.rfft2(self.fpre(x), norm='backward')
        magnitude = self.process1(torch.abs(spectrum))
        phase = self.process2(torch.angle(spectrum))
        recombined = torch.complex(
            magnitude * torch.cos(phase),
            magnitude * torch.sin(phase),
        )
        # s=(height, width) restores the original size, including odd widths.
        restored = torch.fft.irfft2(recombined, s=(height, width), norm='backward')

        return restored + x
102
+
103
class ProcessBlock(nn.Module):
    """Fuse a frequency branch (and optionally a spatial branch) with a skip.

    Args:
        in_nc: Channel count, identical for input and output.
        spatial: When True a ``SpaBlock`` runs next to the ``FreBlock`` and the
            two results are concatenated before the 1x1 fusion conv. When
            False only the frequency result feeds the fusion conv.
    """

    def __init__(self, in_nc, spatial=True):
        super().__init__()
        self.spatial = spatial
        self.spatial_process = SpaBlock(in_nc) if spatial else nn.Identity()
        self.frequency_process = FreBlock(in_nc)
        fusion_in = 2 * in_nc if spatial else in_nc
        self.cat = nn.Conv2d(fusion_in, in_nc, 1, 1, 0)

    def forward(self, x):
        x_freq = self.frequency_process(x)
        if self.spatial:
            # Channel order (spatial first, then frequency) is what self.cat
            # was built for.
            fusion_input = torch.cat([self.spatial_process(x), x_freq], 1)
        else:
            # No concatenation is needed here; building one would only
            # allocate a tensor that is never used.
            fusion_input = x_freq

        return self.cat(fusion_input) + x
119
+
120
class Attention_Light(nn.Module):
    """Three (1x1 conv -> ProcessBlock) stages followed by a sigmoid.

    Channel progression: ``img_channels -> width // 2 -> width -> width``.

    Args:
        img_channels: Channel count of the input.
        width: Channel count of the output.
        spatial: Forwarded to every ``ProcessBlock``.
    """

    def __init__(self, img_channels=3, width=16, spatial=False):
        super().__init__()
        half_width = width // 2
        channel_plan = ((img_channels, half_width), (half_width, width), (width, width))

        stages = []
        for c_in, c_out in channel_plan:
            stages.append(nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=1,
                                    padding=0, stride=1, groups=1, bias=True))
            stages.append(ProcessBlock(in_nc=c_out, spatial=spatial))
        stages.append(nn.Sigmoid())
        self.block = nn.Sequential(*stages)

    def forward(self, input):
        return self.block(input)
135
+
136
class Branch(nn.Module):
    '''
    Branch that holds only the dilated depth-wise convolution path.

    Layout: optional depth-wise 3x3 conv -> 1x1 expansion conv to
    ``DW_Expand * c`` channels -> depth-wise 3x3 conv with the given dilation.

    Args:
        c: Input channel count.
        DW_Expand: Channel expansion factor of the 1x1 conv.
        dilation: Dilation (and padding) of the final depth-wise conv.
        extra_depth_wise: Insert the optional leading depth-wise conv.
    '''
    def __init__(self, c, DW_Expand, dilation = 1, extra_depth_wise = False):
        super().__init__()
        self.dw_channel = DW_Expand * c
        expanded = self.dw_channel

        if extra_depth_wise:
            # optional extra depth-wise conv
            pre = nn.Conv2d(c, c, kernel_size=3, padding=1, stride=1, groups=c, bias=True, dilation=1)
        else:
            pre = nn.Identity()
        expand = nn.Conv2d(in_channels=c, out_channels=expanded, kernel_size=1, padding=0,
                           stride=1, groups=1, bias=True, dilation=1)
        # padding == dilation keeps the spatial size for a 3x3 kernel
        dilated = nn.Conv2d(in_channels=expanded, out_channels=expanded, kernel_size=3,
                            padding=dilation, stride=1, groups=expanded, bias=True,
                            dilation=dilation)
        self.branch = nn.Sequential(pre, expand, dilated)

    def forward(self, input):
        return self.branch(input)
151
+
152
class EBlock(nn.Module):
    '''
    Two-step residual block whose first step sums several dilated ``Branch`` outputs.

    Step 1: norm -> sum of branches -> SimpleGate -> channel attention
    (global average pool + 1x1 conv) -> 1x1 conv, added to the input scaled
    by ``beta``.
    Step 2: norm -> 1x1 conv -> SimpleGate -> 1x1 conv, added scaled by
    ``gamma``.
    ``beta`` and ``gamma`` start at zero, so a freshly built block returns its
    input unchanged.

    Args:
        c: Channel count, identical for input and output.
        DW_Expand: Channel expansion factor inside each branch.
        FFN_Expand: Channel expansion factor of the second step.
        dilations: Non-empty sequence of dilation rates, one ``Branch`` each.
        extra_depth_wise: Forwarded to every ``Branch``.

    Raises:
        ValueError: If ``dilations`` is empty.
    '''

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, dilations = (1,), extra_depth_wise = False):
        super().__init__()
        if len(dilations) == 0:
            # Without a branch forward() would try to gate the integer 0.
            raise ValueError('dilations must contain at least one dilation rate.')

        # one branch per dilation rate; their outputs are summed in forward()
        self.branches = nn.ModuleList()
        for dilation in dilations:
            self.branches.append(Branch(c, DW_Expand, dilation = dilation, extra_depth_wise=extra_depth_wise))

        self.dw_channel = DW_Expand * c
        gated_channel = self.dw_channel // 2  # SimpleGate halves the channel count
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=gated_channel, out_channels=gated_channel, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True, dilation = 1),
        )
        self.sg1 = SimpleGate()
        self.sg2 = SimpleGate()
        self.conv3 = nn.Conv2d(in_channels=gated_channel, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        # first step: multi-dilation branches with channel attention
        x = self.norm1(inp)
        z = 0
        for branch in self.branches:
            z += branch(x)

        z = self.sg1(z)
        x = self.sca(z) * z
        x = self.conv3(x)
        y = inp + self.beta * x
        # second step: gated feed-forward
        x = self.conv4(self.norm2(y))  # size [B, FFN_Expand*C, H, W]
        x = self.sg2(x)  # halves the channel count
        x = self.conv5(x)  # size [B, C, H, W]

        return y + x * self.gamma
203
+
204
+ #----------------------------------------------------------------------------------------------
205
if __name__ == '__main__':
    # Quick complexity report (MACs / parameters) for a single EBlock.
    img_channel = 3
    width = 32

    enc_blks = [1, 2, 3]
    middle_blk_num = 3
    dec_blks = [3, 1, 1]
    dilations = [1, 4, 9]
    extra_depth_wise = False

    # net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
    #              enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
    net = EBlock(c=img_channel, dilations=dilations, extra_depth_wise=extra_depth_wise)
    inp_shape = (3, 256, 256)

    # Imported here so ptflops is only needed when the file is run as a script.
    from ptflops import get_model_complexity_info

    macs, params = get_model_complexity_info(
        net, inp_shape, verbose=False, print_per_layer_stat=True)

    print(macs, params)
archs/arch_util_freq.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.init as init
4
+ import torch.nn.functional as F
5
+
6
+ try:
7
+ from .nafnet_utils.arch_util import LayerNorm2d
8
+ from .nafnet_utils.arch_model import SimpleGate
9
+ except:
10
+ from nafnet_utils.arch_util import LayerNorm2d
11
+ from nafnet_utils.arch_model import SimpleGate
12
+
13
+ '''
14
+ https://github.com/wangchx67/FourLLIE.git
15
+ '''
16
+
17
def initialize_weights(net_l, scale=1):
    """Kaiming-initialise the Conv2d/Linear layers of one or more networks.

    Conv2d and Linear weights are drawn with ``kaiming_normal_`` (fan-in mode)
    and then multiplied by ``scale``; their biases are zeroed. BatchNorm2d
    layers get weight 1 and bias 0.

    Args:
        net_l: A single ``nn.Module`` or a list/tuple of modules.
        scale: Multiplier applied to the freshly initialised weights.
    """
    if not isinstance(net_l, (list, tuple)):
        net_l = [net_l]
    with torch.no_grad():
        for net in net_l:
            for m in net.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                    m.weight.mul_(scale)  # for residual block
                    if m.bias is not None:
                        m.bias.zero_()
                elif isinstance(m, nn.BatchNorm2d):
                    # With affine=False both parameters are None; skip them
                    # instead of crashing inside init.constant_.
                    if m.weight is not None:
                        init.constant_(m.weight, 1)
                    if m.bias is not None:
                        init.constant_(m.bias, 0.0)
35
+
36
class FreNAFBlock(nn.Module):
    """Frequency block that reprocesses only the magnitude spectrum.

    The input is moved to the frequency domain with a real 2-D FFT, its
    magnitude goes through a 1x1 conv -> LeakyReLU(0.1) -> 1x1 conv stack, the
    original phase is kept, and the result is transformed back to the input's
    spatial size. No residual is added here.

    Args:
        nc: Channel count, identical for input and output.
        expand: Hidden-channel multiplier of the 1x1 conv stack.
    """

    def __init__(self, nc, expand = 2):
        super().__init__()
        hidden = expand * nc
        self.process1 = nn.Sequential(
            nn.Conv2d(nc, hidden, 1, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(hidden, nc, 1, 1, 0),
        )

    def forward(self, x):
        _, _, height, width = x.shape
        spectrum = torch.fft.rfft2(x, norm='backward')
        phase = torch.angle(spectrum)
        magnitude = self.process1(torch.abs(spectrum))
        recombined = torch.complex(
            magnitude * torch.cos(phase),
            magnitude * torch.sin(phase),
        )
        return torch.fft.irfft2(recombined, s=(height, width), norm='backward')
56
+
57
+ # ------------------------------------------------------------------------------------------------
58
+
59
class Branch(nn.Module):
    '''
    Branch that holds only the dilated depth-wise convolution path.

    Layout: optional depth-wise 3x3 conv -> 1x1 expansion conv to
    ``DW_Expand * c`` channels -> depth-wise 3x3 conv with the given dilation.

    Args:
        c: Input channel count.
        DW_Expand: Channel expansion factor of the 1x1 conv.
        dilation: Dilation (and padding) of the final depth-wise conv.
        extra_depth_wise: Insert the optional leading depth-wise conv.
    '''
    def __init__(self, c, DW_Expand, dilation = 1, extra_depth_wise = False):
        super().__init__()
        self.dw_channel = DW_Expand * c
        expanded = self.dw_channel

        if extra_depth_wise:
            # optional extra depth-wise conv
            pre = nn.Conv2d(c, c, kernel_size=3, padding=1, stride=1, groups=c, bias=True, dilation=1)
        else:
            pre = nn.Identity()
        expand = nn.Conv2d(in_channels=c, out_channels=expanded, kernel_size=1, padding=0,
                           stride=1, groups=1, bias=True, dilation=1)
        # padding == dilation keeps the spatial size for a 3x3 kernel
        dilated = nn.Conv2d(in_channels=expanded, out_channels=expanded, kernel_size=3,
                            padding=dilation, stride=1, groups=expanded, bias=True,
                            dilation=dilation)
        self.branch = nn.Sequential(pre, expand, dilated)

    def forward(self, input):
        return self.branch(input)
74
+
75
class EBlock_freq(nn.Module):
    '''
    Two-step residual block: dilated ``Branch`` sum, then frequency modulation.

    Step 1: norm -> sum of branches -> SimpleGate -> channel attention
    (global average pool + 1x1 conv) -> 1x1 conv, added to the input scaled
    by ``beta``.
    Step 2: the normalised step-1 result goes through ``FreNAFBlock``; its
    output multiplies the step-1 result and is added back scaled by ``gamma``.
    ``beta`` and ``gamma`` start at zero, so a freshly built block returns its
    input unchanged.

    Args:
        c: Channel count, identical for input and output.
        DW_Expand: Channel expansion factor inside each branch.
        dilations: Non-empty sequence of dilation rates, one ``Branch`` each.
        extra_depth_wise: Forwarded to every ``Branch``.

    Raises:
        ValueError: If ``dilations`` is empty.
    '''

    def __init__(self, c, DW_Expand=2, dilations = (1,), extra_depth_wise = False):
        super().__init__()
        if len(dilations) == 0:
            # Without a branch forward() would try to gate the integer 0.
            raise ValueError('dilations must contain at least one dilation rate.')

        # one branch per dilation rate; their outputs are summed in forward()
        self.branches = nn.ModuleList()
        for dilation in dilations:
            self.branches.append(Branch(c, DW_Expand, dilation = dilation, extra_depth_wise=extra_depth_wise))

        self.dw_channel = DW_Expand * c
        gated_channel = self.dw_channel // 2  # SimpleGate halves the channel count
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=gated_channel, out_channels=gated_channel, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True, dilation = 1),
        )
        self.sg1 = SimpleGate()
        self.conv3 = nn.Conv2d(in_channels=gated_channel, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
        # second step

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)
        self.freq = FreNAFBlock(nc = c, expand=2)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        # first step: multi-dilation branches with channel attention
        x = self.norm1(inp)
        z = 0
        for branch in self.branches:
            z += branch(x)

        z = self.sg1(z)
        x = self.sca(z) * z
        x = self.conv3(x)
        y = inp + self.beta * x
        # second step: frequency modulation of the step-1 result
        x_step2 = self.norm2(y)  # size [B, C, H, W]
        x_freq = self.freq(x_step2)  # size [B, C, H, W]
        x = y * x_freq

        return y + x * self.gamma
123
+
124
+ #----------------------------------------------------------------------------------------------
125
if __name__ == '__main__':
    # Quick complexity report (MACs / parameters) for a single EBlock_freq.
    img_channel = 128
    width = 32

    enc_blks = [1, 2, 3]
    middle_blk_num = 3
    dec_blks = [3, 1, 1]
    dilations = [1, 4, 9]
    extra_depth_wise = True

    # net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
    #              enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
    # This module defines EBlock_freq; the name EBlock is neither defined nor
    # imported here, so using it raised NameError.
    net = EBlock_freq(c = img_channel,
                      dilations = dilations,
                      extra_depth_wise=extra_depth_wise)

    inp_shape = (128, 32, 32)

    # Imported here so ptflops is only needed when the file is run as a script.
    from ptflops import get_model_complexity_info

    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)

    print(macs, params)
archs/nafnet_utils/arch_model.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ try:
6
+ from .arch_util import LayerNorm2d
7
+ from .local_arch import Local_Base
8
+ except:
9
+ from arch_util import LayerNorm2d
10
+ from local_arch import Local_Base
11
+
12
+
13
class SimpleGate(nn.Module):
    """Split the channel axis into two chunks and multiply them element-wise."""

    def forward(self, x):
        first_half, second_half = torch.chunk(x, 2, dim=1)
        return first_half * second_half
17
+
18
class NAFBlock(nn.Module):
    """Two-step residual block built from 1x1 convs, a depth-wise conv and SimpleGate.

    Step 1: norm -> 1x1 conv -> depth-wise 3x3 conv -> SimpleGate -> simplified
    channel attention -> 1x1 conv -> dropout, added to the input scaled by
    ``beta``.
    Step 2: norm -> 1x1 conv -> SimpleGate -> 1x1 conv -> dropout, added
    scaled by ``gamma``.
    ``beta`` and ``gamma`` start at zero.

    Args:
        c: Channel count, identical for input and output.
        DW_Expand: Channel expansion factor of step 1.
        FFN_Expand: Channel expansion factor of step 2.
        drop_out_rate: Dropout probability; values <= 0 disable dropout.
    """

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()

        def pointwise(c_in, c_out):
            # 1x1 convolution with bias
            return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=1,
                             padding=0, stride=1, groups=1, bias=True)

        def make_dropout():
            if drop_out_rate > 0.:
                return nn.Dropout(drop_out_rate)
            return nn.Identity()

        dw_channel = c * DW_Expand
        gated_channel = dw_channel // 2  # SimpleGate halves the channel count

        self.conv1 = pointwise(c, dw_channel)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3,
                               padding=1, stride=1, groups=dw_channel, bias=True)  # the dconv
        self.conv3 = pointwise(gated_channel, c)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            pointwise(gated_channel, gated_channel),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = pointwise(c, ffn_channel)
        self.conv5 = pointwise(ffn_channel // 2, c)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = make_dropout()
        self.dropout2 = make_dropout()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        # step 1: gated depth-wise conv with channel attention
        feat = self.conv2(self.conv1(self.norm1(inp)))
        feat = self.sg(feat)
        feat = self.conv3(feat * self.sca(feat))
        y = inp + self.dropout1(feat) * self.beta

        # step 2: gated feed-forward
        ffn = self.conv5(self.sg(self.conv4(self.norm2(y))))
        return y + self.dropout2(ffn) * self.gamma
72
+
73
+
74
class NAFNet(nn.Module):
    """U-shaped encoder/decoder network assembled from ``NAFBlock`` stages.

    Every encoder stage is followed by a stride-2 conv that doubles the
    channels; every decoder stage is preceded by a 1x1 conv + PixelShuffle(2)
    that halves them. Encoder outputs are added to the matching decoder
    inputs, and the network output is added to the (padded) input.

    Args:
        img_channel: Channel count of the input and output image.
        width: Channel count after the intro convolution.
        middle_blk_num: Number of ``NAFBlock`` instances at the bottleneck.
        enc_blk_nums: Blocks per encoder stage (any sequence of ints).
        dec_blk_nums: Blocks per decoder stage (any sequence of ints).
    """

    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=(), dec_blk_nums=()):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        # Placeholder, replaced by an nn.Sequential below; kept so the
        # sub-module registration order stays as it was.
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(*[NAFBlock(chan) for _ in range(num)])
            )
            self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
            chan = chan * 2

        self.middle_blks = nn.Sequential(
            *[NAFBlock(chan) for _ in range(middle_blk_num)]
        )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2),
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(*[NAFBlock(chan) for _ in range(num)])
            )

        # Spatial sizes must be divisible by 2 ** (number of down-samplings).
        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        _, _, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        # Decoder stages consume the encoder outputs deepest-first.
        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        x = x + inp

        # Crop away the padding added by check_image_size.
        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Zero-pad the bottom/right so H and W are multiples of ``padder_size``."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), value = 0)
        return x
154
+
155
+
156
class NAFNetLocal(Local_Base, NAFNet):
    """``NAFNet`` combined with ``Local_Base`` and converted at construction time.

    After both base initialisers run, the model is put in eval mode and
    ``convert`` is called under ``torch.no_grad()`` with a base size of 1.5x
    the training height and width.

    Args:
        *args: Positional arguments forwarded to ``NAFNet``.
        train_size: ``(N, C, H, W)`` size used for the conversion.
        fast_imp: Forwarded to ``convert``.
        **kwargs: Keyword arguments forwarded to ``NAFNet``.
    """

    def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
        Local_Base.__init__(self)
        NAFNet.__init__(self, *args, **kwargs)

        _, _, train_h, train_w = train_size
        base_size = (int(train_h * 1.5), int(train_w * 1.5))

        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
167
+
168
class FreBlock(nn.Module):
    """Frequency-domain residual block.

    The input goes through a 1x1 convolution and a real 2-D FFT. The magnitude
    and phase of the spectrum are each processed by their own
    1x1 conv -> LeakyReLU(0.1) -> 1x1 conv stack, recombined into a complex
    spectrum, transformed back to the input's spatial size and added to the
    input.

    Args:
        nc: Channel count, identical for input and output.
    """

    def __init__(self, nc):
        super().__init__()
        self.fpre = nn.Conv2d(nc, nc, 1, 1, 0)
        self.process1 = self._pointwise_stack(nc)  # magnitude path
        self.process2 = self._pointwise_stack(nc)  # phase path

    @staticmethod
    def _pointwise_stack(nc):
        """Build one 1x1 conv -> LeakyReLU(0.1) -> 1x1 conv stack."""
        return nn.Sequential(
            nn.Conv2d(nc, nc, 1, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(nc, nc, 1, 1, 0),
        )

    def forward(self, x):
        _, _, height, width = x.shape
        spectrum = torch.fft.rfft2(self.fpre(x), norm='backward')
        magnitude = self.process1(torch.abs(spectrum))
        phase = self.process2(torch.angle(spectrum))
        recombined = torch.complex(
            magnitude * torch.cos(phase),
            magnitude * torch.sin(phase),
        )
        # s=(height, width) restores the original size, including odd widths.
        restored = torch.fft.irfft2(recombined, s=(height, width), norm='backward')

        return restored + x
194
+
195
+ # class FPA(nn.Module):
196
+
197
+ # def __init__(self,nc):
198
+ # super(FPA, self).__init__()
199
+ # self.process_mag = nn.Sequential(
200
+ # nn.Conv2d(nc, nc, 1, 1, 0),
201
+ # nn.LeakyReLU(0.1, inplace=True),
202
+ # nn.Conv2d(nc, nc, 1, 1, 0),
203
+ # nn.LeakyReLU(0.1, inplace=True),
204
+ # nn.Conv2d(nc, nc, 1, 1, 0))
205
+ # self.process_pha = nn.Sequential(
206
+ # nn.Conv2d(nc, nc, 1, 1, 0),
207
+ # nn.LeakyReLU(0.1, inplace=True),
208
+ # nn.Conv2d(nc, nc, 1, 1, 0),
209
+ # nn.LeakyReLU(0.1, inplace=True),
210
+ # nn.Conv2d(nc, nc, 1, 1, 0))
211
+
212
+ # def forward(self, input):
213
+ # _, _, H, W = input.shape
214
+ # x_freq = torch.fft.rfft2(input, norm='backward')
215
+ # mag = torch.abs(x_freq)
216
+ # pha = torch.angle(x_freq)
217
+ # mag = mag + self.process_mag(mag)
218
+ # pha = pha + self.process_pha(pha)
219
+ # real = mag * torch.cos(pha)
220
+ # imag = mag * torch.sin(pha)
221
+ # x_out = torch.complex(real, imag)
222
+ # x_out = torch.fft.irfft2(x_out, s=(H, W), norm='backward')
223
+ # return x_out
224
+
225
+
226
+ # class FBlock(nn.Module):
227
+
228
+ # def __init__(self, c, DW_Expand=2, FFN_Expand=2, dilations = [1], extra_depth_wise = False):
229
+ # super(FBlock, self).__init__()
230
+
231
+ # self.branches = nn.ModuleList()
232
+ # for dilation in dilations:
233
+ # self.branches.append(Branch_v2(c, DW_Expand, dilation = dilation, extra_depth_wise=extra_depth_wise))
234
+
235
+ # assert len(dilations) == len(self.branches)
236
+ # self.dw_channel = DW_Expand * c
237
+ # self.sca = nn.Sequential(
238
+ # nn.AdaptiveAvgPool2d(1),
239
+ # nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=self.dw_channel // 2, kernel_size=1, padding=0, stride=1,
240
+ # groups=1, bias=True, dilation = 1),
241
+ # )
242
+ # self.sg1 = SimpleGate()
243
+ # self.conv3 = nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
244
+
245
+
246
+
247
+ # self.norm1 = LayerNorm2d(c)
248
+ # self.norm2 = LayerNorm2d(c)
249
+
250
+ # ffn_channel = FFN_Expand * c
251
+ # self.conv_fpr_intro = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
252
+ # self.fpa = FPA(nc = ffn_channel)
253
+ # self.conv_fpr_out = nn.Conv2d(in_channels=ffn_channel, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation = 1)
254
+
255
+ # self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
256
+ # self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
257
+
258
+ def forward(self, inp):
259
+
260
+ y = inp
261
+ x = self.norm1(inp)
262
+ z=0
263
+ for branch in self.branches:
264
+ z += branch(x)
265
+
266
+ z = self.sg1(z)
267
+ x = self.sca(z) * z
268
+ x = self.conv3(x)
269
+ y = inp + self.beta * x
270
+ #Frequency pixel residue
271
+ x = self.conv_fpr_intro(self.norm2(y)) # size [B, C, H, W]
272
+ x = self.fpa(x) # size [B, C, H, W]
273
+ x = self.conv_fpr_out(x)
274
+
275
+ return y + x * self.gamma
276
+
277
if __name__ == '__main__':
    # Quick complexity report (MACs / parameters) for a NAFNet.
    img_channel = 3
    width = 32

    enc_blks = [1, 2, 3]
    middle_blk_num = 3
    dec_blks = [3, 1, 1]
    dilations = [1, 4, 9]
    extra_depth_wise = False

    # The script used to instantiate EBlock_v2, a name that is neither defined
    # nor imported in this module (NameError). NAFNet is the model this file
    # provides and the one the settings above describe.
    # NOTE(review): dilations / extra_depth_wise are not NAFNet arguments and
    # are left unused; confirm whether another block was intended.
    net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
                 enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)

    inp_shape = (3, 256, 256)

    # Imported here so ptflops is only needed when the file is run as a script.
    from ptflops import get_model_complexity_info

    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)

    print(macs, params)
archs/nafnet_utils/arch_util.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn as nn
4
+ from torch.nn import functional as F
5
+ from torch.nn import init as init
6
+ from torch.nn.modules.batchnorm import _BatchNorm
7
+
8
+
9
+ # try:
10
+ # from basicsr.models.ops.dcn import (ModulatedDeformConvPack,
11
+ # modulated_deform_conv)
12
+ # except ImportError:
13
+ # # print('Cannot import dcn. Ignore this warning if dcn is not used. '
14
+ # # 'Otherwise install BasicSR with compiling dcn.')
15
+ #
16
+
17
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Conv2d and Linear weights are drawn with ``kaiming_normal_`` and multiplied
    by ``scale``; batch-norm weights are set to 1. Every existing bias is
    filled with ``bias_fill``.

    Args:
        module_list (list[nn.Module] | tuple[nn.Module] | nn.Module): Modules
            to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, (list, tuple)):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                # With affine=False the weight is None as well as the bias.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
46
+
47
+
48
def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.
        **kwarg: Keyword arguments passed to every ``basic_block`` call.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    return nn.Sequential(*[basic_block(**kwarg) for _ in range(num_basic_block)])
62
+
63
+
64
class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if pytorch_init:
            return
        # Custom init: weights scaled by 0.1.
        default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        residual = self.conv2(hidden)
        return x + residual * self.res_scale
93
+
94
+
95
class Upsample(nn.Sequential):
    """Upsample module.

    For ``scale = 2^n`` the module is ``n`` repetitions of a 3x3 conv to
    ``4 * num_feat`` channels followed by ``PixelShuffle(2)``; for
    ``scale = 3`` it is one 3x3 conv to ``9 * num_feat`` channels followed by
    ``PixelShuffle(3)``.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a positive power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # scale > 0 keeps 0 out of the power-of-two test (0 & -1 == 0), so it
        # gets the descriptive error below instead of a math domain error.
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # bit_length() - 1 is the exact log2 of a power of two.
            for _ in range(int(scale).bit_length() - 1):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
116
+
117
+
118
def flow_warp(x,
              flow,
              interp_mode='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or feature map with optical flow.

    A pixel-coordinate grid is built, ``flow`` is added to it, the result is
    normalised to [-1, 1] by dividing by ``size - 1`` and passed to
    ``F.grid_sample``.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2); the last dimension holds
            the (x, y) offsets that are added to the pixel grid.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use the True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid; explicit 'ij' indexing is the behaviour the unqualified
    # call had, without the deprecation warning
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, h).type_as(x),
        torch.arange(0, w).type_as(x),
        indexing='ij')
    grid = torch.stack((grid_x, grid_y), 2).float()  # (h, w, 2), last dim is (x, y)
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1,1]; max(..., 1) avoids dividing by zero for size 1
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(
        x,
        vgrid_scaled,
        mode=interp_mode,
        padding_mode=padding_mode,
        align_corners=align_corners)

    # TODO, what if align_corners=False
    return output
161
+
162
+
163
def resize_flow(flow,
                size_type,
                sizes,
                interp_mode='bilinear',
                align_corners=False):
    """Resize a flow according to ratio or shape.

    The flow values are rescaled with the same factors as the spatial size:
    channel 0 by the width ratio and channel 1 by the height ratio. The input
    tensor is not modified.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.

    Raises:
        ValueError: If ``size_type`` is neither 'ratio' nor 'shape'.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    elif size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    else:
        raise ValueError(
            f'Size type should be ratio or shape, but got type {size_type}.')

    # Work on a copy so the caller's tensor keeps its values.
    scaled = flow.clone()
    scaled[:, 0, :, :] *= output_w / flow_w
    scaled[:, 1, :, :] *= output_h / flow_h
    return F.interpolate(
        input=scaled,
        size=(output_h, output_w),
        mode=interp_mode,
        align_corners=align_corners)
207
+
208
+
209
+ # TODO: may write a cpp file
210
def pixel_unshuffle(x, scale):
    """Fold every ``scale x scale`` spatial block into the channel axis.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw). Both hh and hw
            must be divisible by ``scale``.
        scale (int): Downsample ratio.

    Returns:
        Tensor: Feature of shape (b, c * scale**2, hh // scale, hw // scale).
    """
    batch, channels, in_h, in_w = x.size()
    assert in_h % scale == 0 and in_w % scale == 0
    out_h, out_w = in_h // scale, in_w // scale

    # Split each spatial axis into (coarse position, offset inside the block).
    blocks = x.view(batch, channels, out_h, scale, out_w, scale)
    # Bring both in-block offsets next to the channel axis, then merge them
    # into it.
    blocks = blocks.permute(0, 1, 3, 5, 2, 4)
    return blocks.reshape(batch, channels * (scale**2), out_h, out_w)
227
+
228
+
229
+ # class DCNv2Pack(ModulatedDeformConvPack):
230
+ # """Modulated deformable conv for deformable alignment.
231
+ #
232
+ # Different from the official DCNv2Pack, which generates offsets and masks
233
+ # from the preceding features, this DCNv2Pack takes another different
234
+ # features to generate offsets and masks.
235
+ #
236
+ # Ref:
237
+ # Delving Deep into Deformable Alignment in Video Super-Resolution.
238
+ # """
239
+ #
240
+ # def forward(self, x, feat):
241
+ # out = self.conv_offset(feat)
242
+ # o1, o2, mask = torch.chunk(out, 3, dim=1)
243
+ # offset = torch.cat((o1, o2), dim=1)
244
+ # mask = torch.sigmoid(mask)
245
+ #
246
+ # offset_absmean = torch.mean(torch.abs(offset))
247
+ # if offset_absmean > 50:
248
+ # logger = get_root_logger()
249
+ # logger.warning(
250
+ # f'Offset abs mean is {offset_absmean}, larger than 50.')
251
+ #
252
+ # return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
253
+ # self.stride, self.padding, self.dilation,
254
+ # self.groups, self.deformable_groups)
255
+
256
+
257
class LayerNormFunction(torch.autograd.Function):
    """Layer normalization over the channel axis with a hand-written backward.

    For a 4-D input the mean and (biased) variance are taken over dim 1
    independently at every (n, h, w) position; the normalized value is then
    scaled and shifted per channel by ``weight`` and ``bias``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        """Normalize ``x`` over dim 1 and apply the per-channel affine.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            x (Tensor): 4-D input; dim 1 is the axis that gets normalized.
            weight (Tensor): Per-channel scale with C elements.
            bias (Tensor): Per-channel shift with C elements.
            eps (float): Constant added to the variance before the sqrt.

        Returns:
            Tensor: Normalized tensor with the same shape as ``x``.
        """
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        # Save the normalized value (before the affine) - backward needs it.
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (x, weight, bias, eps); eps gets ``None``."""
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        # ``saved_tensors`` replaces the deprecated ``saved_variables``.
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        # weight/bias gradients: reduce over every axis except the channel one.
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
        return gx, grad_weight, grad_bias, None
283
+
284
class LayerNorm2d(nn.Module):
    """Module wrapper around ``LayerNormFunction`` with learnable affine terms.

    Args:
        channels (int): Number of elements in the ``weight`` and ``bias``
            parameters.
        eps (float): Value handed to ``LayerNormFunction`` as ``eps``.
            Default: 1e-6.
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        # Scale starts at one and shift at zero, so the affine step is the
        # identity right after construction.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        """Apply ``LayerNormFunction`` to ``x`` with this module's parameters."""
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
294
+
295
+ # handle multiple input
296
class MySequential(nn.Sequential):
    """Sequential container whose stages may take and return several values.

    The positional arguments arrive as a tuple. Whenever the running value is
    a plain ``tuple`` it is unpacked into the next stage's arguments; any
    other value (including tuple subclasses) is passed as a single argument.
    """

    def forward(self, *inputs):
        """Run every registered stage in order and return the last result."""
        for stage in self._modules.values():
            # Exact type check on purpose: only a plain tuple is unpacked.
            inputs = stage(*inputs) if type(inputs) is tuple else stage(inputs)
        return inputs
304
+
305
+ import time
306
def measure_inference_speed(model, data, max_iter=200, log_interval=50):
    """Time repeated forward passes of ``model`` and report frames per second.

    The model is switched to eval mode (and left there). The first five
    iterations are treated as warm-up and excluded from the average.

    Args:
        model (nn.Module): Model to benchmark; it is called as
            ``model(*data)`` under ``torch.no_grad()``.
        data (tuple): Positional arguments for the model.
        max_iter (int): Total number of forward passes, warm-up included.
            Default: 200.
        log_interval (int): Print the running average every this many
            iterations. Default: 50.

    Returns:
        float: Average forward passes per second over the timed iterations,
        or 0 when ``max_iter`` does not exceed the warm-up count.
    """
    model.eval()

    # the first several iterations may be very slow so skip them
    num_warmup = 5
    pure_inf_time = 0
    fps = 0

    # CUDA work is queued asynchronously, so the device has to be drained
    # around the timed region. Without CUDA there is nothing to wait for and
    # torch.cuda.synchronize() would raise.
    use_cuda = torch.cuda.is_available()

    for i in range(max_iter):

        if use_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            model(*data)

        if use_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i < num_warmup:
            continue

        pure_inf_time += elapsed
        if (i + 1) % log_interval == 0 and pure_inf_time > 0:
            fps = (i + 1 - num_warmup) / pure_inf_time
            print(
                f'Done image [{i + 1:<3}/ {max_iter}], '
                f'fps: {fps:.1f} img / s, '
                f'times per image: {1000 / fps:.1f} ms / img',
                flush=True)

    # Overall figure; skipped when every iteration was warm-up so that a
    # small max_iter returns 0 instead of dividing by zero.
    timed_iters = max_iter - num_warmup
    if timed_iters > 0 and pure_inf_time > 0:
        fps = timed_iters / pure_inf_time
        print(
            f'Overall fps: {fps:.1f} img / s, '
            f'times per image: {1000 / fps:.1f} ms / img',
            flush=True)
    return fps
archs/nafnet_utils/local_arch.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
class AvgPool2d(nn.Module):
    """Stride-1 local average pooling computed from 2-D cumulative sums.

    The window is either given directly through ``kernel_size`` or derived on
    the first forward pass from ``base_size`` and ``train_size`` (the window
    is scaled by the ratio between the incoming spatial size and
    ``train_size``) and then cached on the module. When the window covers the
    whole input the layer reduces to global average pooling.

    Args:
        kernel_size (int | Sequence[int] | None): Window as ``(kh, kw)`` or a
            single int used for both axes. ``None`` defers to ``base_size``.
        base_size (int | Sequence[int] | None): Window size at ``train_size``
            resolution; only read while ``kernel_size`` is ``None``.
        auto_pad (bool): Replicate-pad the result back to the input's spatial
            size. Default: True.
        fast_imp (bool): Use the sub-sampled approximation (not numerically
            equivalent to the exact path). Default: False.
        train_size (Sequence[int] | None): Reference shape; its last two
            entries are read as height and width. Required with ``base_size``.
    """

    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )

    def forward(self, x):
        """Average-pool ``x`` (N, C, H, W) with the configured window.

        Returns:
            Tensor: ``(N, C, 1, 1)`` when the window covers the whole input;
            otherwise the pooled map, padded back to ``(N, C, H, W)`` when
            ``auto_pad`` is set.

        Raises:
            TypeError: If neither ``kernel_size`` nor ``base_size`` was given.
        """
        if self.kernel_size is None and self.base_size:
            # First call: derive the window from the reference configuration
            # and cache it, so later calls reuse the same kernel.
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        kernel_size = self.kernel_size
        if kernel_size is None:
            raise TypeError(
                'AvgPool2d needs either kernel_size or base_size to be set.')
        if isinstance(kernel_size, int):
            # Accept a single int for a square window.
            kernel_size = (kernel_size, kernel_size)
        kh, kw = kernel_size[0], kernel_size[1]

        if kh >= x.size(-2) and kw >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # Non-equivalent implementation but faster
            h, w = x.shape[2:]
            # Largest sub-sampling step that divides the size (1 always does).
            r1 = [r for r in self.rs if h % r == 0][0]
            r2 = [r for r in self.rs if w % r == 0][0]
            # reduction_constraint
            r1 = min(self.max_r1, r1)
            r2 = min(self.max_r2, r2)
            s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
            n, c, h, w = s.shape
            k1, k2 = min(h - 1, kh // r1), min(w - 1, kw // r2)
            out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
            out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            n, c, h, w = x.shape
            # Summed-area table: any window sum becomes four look-ups.
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1, k2 = min(h, kh), min(w, kw)
            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        if self.auto_pad:
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            # Split the missing rows/columns between both sides, with the
            # extra one (if any) going to the right/bottom.
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode='replicate')

        return out
73
+
74
def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    """Recursively swap each ``nn.AdaptiveAvgPool2d`` for a local ``AvgPool2d``.

    The module tree is modified in place; nothing is returned.

    Args:
        model (nn.Module): Root of the module tree to rewrite.
        base_size: Forwarded to ``AvgPool2d`` as ``base_size``.
        train_size: Forwarded to ``AvgPool2d`` as ``train_size``.
        fast_imp (bool): Forwarded to ``AvgPool2d`` as ``fast_imp``.
        **kwargs: Only passed along to the recursive calls.
    """
    for child_name, child in model.named_children():
        has_children = next(child.children(), None) is not None
        if has_children:
            # compound module, go inside it
            replace_layers(child, base_size, train_size, fast_imp, **kwargs)

        if not isinstance(child, nn.AdaptiveAvgPool2d):
            continue

        local_pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
        # Only global pooling (a 1x1 output) has a local counterpart here.
        assert child.output_size == 1
        setattr(model, child_name, local_pool)
84
+
85
+
86
+
87
class Local_Base():
    """Mixin adding ``convert``, which swaps global pooling for local pooling.

    ``convert`` calls ``replace_layers(self, ...)`` and ``self.forward``, so
    the class it is mixed into must provide both a module tree and a
    ``forward`` method.
    """

    def convert(self, *args, train_size, **kwargs):
        """Replace the pooling layers, then run one dry forward pass.

        Args:
            *args: Positional arguments forwarded to ``replace_layers``.
            train_size: Shape of the random probe tensor; also forwarded to
                ``replace_layers``.
            **kwargs: Keyword arguments forwarded to ``replace_layers``.
        """
        replace_layers(self, *args, train_size=train_size, **kwargs)
        # The probe is created on the default device; its values are
        # irrelevant and the output of the pass is discarded.
        with torch.no_grad():
            self.forward(torch.rand(train_size))
archs/network.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import functools
5
+ try:
6
+ from .arch_util import EBlock, Attention_Light
7
+ from .arch_util_freq import EBlock_freq
8
+ except:
9
+ from arch_util import EBlock, Attention_Light
10
+ from arch_util_freq import EBlock_freq
11
+
12
+
13
class Network(nn.Module):
    """U-shaped encoder/decoder built from ``EBlock`` stages with a global residual.

    Each encoder level runs a stack of ``EBlock`` and then halves the spatial
    size with a stride-2 convolution while doubling the channels; each decoder
    level upsamples with a 1x1 convolution + ``PixelShuffle(2)``, adds the
    matching encoder output and runs another ``EBlock`` stack. The output of
    the final convolution is added to the input.

    Args:
        img_channel (int): Channels of the input and output image. Default: 3.
        width (int): Channels after the first convolution. Default: 16.
        middle_blk_num (int): Number of ``EBlock`` at the bottleneck.
        enc_blk_nums (list[int]): ``EBlock`` count for every encoder level.
        dec_blk_nums (list[int]): ``EBlock`` count for every decoder level.
        dilations (list[int]): Passed to the encoder and bottleneck ``EBlock``
            as ``dilations``; decoder blocks use ``EBlock``'s own default.
        extra_depth_wise (bool): Passed to every ``EBlock``.
    """

    # NOTE: the list defaults below are only read, never mutated, in this class.
    def __init__(self, img_channel=3,
                 width=16,
                 middle_blk_num=1,
                 enc_blk_nums=[],
                 dec_blk_nums=[],
                 dilations = [1],
                 extra_depth_wise = False):
        super(Network, self).__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        # Placeholder, replaced below; kept so the order in which sub-modules
        # are registered (and therefore parameter order) stays unchanged.
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[EBlock(chan, dilations = dilations, extra_depth_wise=extra_depth_wise) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[EBlock(chan, dilations = dilations, extra_depth_wise=extra_depth_wise) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    # 2x channels then PixelShuffle(2): net effect is half the
                    # channels at twice the resolution.
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[EBlock(chan, extra_depth_wise=extra_depth_wise) for _ in range(num)]
                )
            )

        # Spatial sizes must be a multiple of this for every down/up pair to
        # round-trip to the same shape.
        self.padder_size = 2 ** len(self.encoders)

    def forward(self, input):
        """Run the network on an (N, C, H, W) batch.

        The input is zero-padded on the bottom/right to a multiple of
        ``padder_size`` and the result is cropped back, so any H and W are
        accepted.

        Returns:
            Tensor: Output with the same shape as ``input``.
        """
        _, _, H, W = input.shape
        padded = self._pad_to_multiple(input)

        x = self.intro(padded)

        encs = []
        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        # Walk back up, pairing each decoder with the encoder output of the
        # same resolution (deepest first).
        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        x = x + padded

        return x[:, :, :H, :W]

    def _pad_to_multiple(self, x):
        """Zero-pad H and W (bottom/right) up to a multiple of ``padder_size``."""
        _, _, h, w = x.size()
        pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        if pad_h == 0 and pad_w == 0:
            return x
        return F.pad(x, (0, pad_w, 0, pad_h))
110
+
111
+
112
if __name__ == '__main__':
    # Manual smoke test: build a small Network, print what ptflops reports
    # for it, then push one random batch through the model.

    img_channel, width = 3, 32

    enc_blks, middle_blk_num, dec_blks = [1, 2, 3], 3, [3, 1, 1]
    residual_layers = 2  # defined but not passed to Network
    dilations = [1, 4]

    net_kwargs = dict(
        img_channel=img_channel,
        width=width,
        middle_blk_num=middle_blk_num,
        enc_blk_nums=enc_blks,
        dec_blk_nums=dec_blks,
        dilations=dilations,
    )
    net = Network(**net_kwargs)

    inp_shape = (3, 256, 256)

    # Imported here so the module itself does not depend on ptflops.
    from ptflops import get_model_complexity_info

    macs, params = get_model_complexity_info(
        net, inp_shape, verbose=False, print_per_layer_stat=False)
    print(macs, params)

    # One forward pass on a single random image of the profiled shape.
    inp = torch.randn(1, *inp_shape)
    out = net(inp)
142
+
143
+