GeorgiosIoannouCoder committed on
Commit
9ab5f71
·
verified ·
1 Parent(s): 041580b

Create inference directory

Browse files
Files changed (1) hide show
  1. inference/real_esrgan.py +306 -0
inference/real_esrgan.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###########################################################################################
2
+ # Filename: real_esrgan.py
3
+ # Description: Upscale images using the trained REALESRGAN model.
4
+ ###########################################################################################
5
+ #
6
+ # Import libraries.
7
+ #
8
+ # Import OpenCV library for image processing.
9
+ import cv2
10
+
11
+ # Import the math module for mathematical operations.
12
+ import math
13
+
14
+ # Import NumPy for numerical operations on arrays.
15
+ import numpy as np
16
+
17
+ # Import the os module for operating system functionalities.
18
+ import os
19
+
20
+ # Import the queue module for implementing queues.
21
+ import queue
22
+
23
+ # Import the threading module for multi-threading support.
24
+ import threading
25
+
26
+ # Import PyTorch for deep learning.
27
+ import torch
28
+
29
+ # Import a utility function for downloading files.
30
+ from basicsr.utils.download_util import load_file_from_url
31
+
32
+ # Import functional module from PyTorch's neural network library.
33
+ from torch.nn import functional as F
34
+
35
+ ###########################################################################################
36
# Absolute path of the directory one level above the folder that holds this
# file; downloaded model weights are stored under ROOT_DIR/weights.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
38
+
39
+
40
+ ###########################################################################################
41
class RealEsrGan:
    """Upscale images with a pretrained Real-ESRGAN style network.

    Typical use is ``output, img_mode = RealEsrGan(...).enhance(img)``.
    The intermediate tensors are kept on the instance (``self.img`` and
    ``self.output``) between the pre-process, process and post-process steps.
    """

    def __init__(
        self,
        scale,  # Upsampling scale factor used in the networks.
        model_path,  # The path to the pretrained model.
        dni_weight=None,  # Performing the interpolation between two networks.
        model=None,  # The network instance that receives the weights.
        pre_pad=10,  # Pad the input images to avoid border artifacts.
        half=False,  # Whether to use half precision during inference or not.
        device=None,  # What device to run inference on. cpu or cuda.
        gpu_id=None,  # ID of GPU to be used if there are more than one GPUs.
    ):
        """Load the checkpoint into ``model`` and move it to the target device.

        Args:
            scale: Upsampling factor of the network; also used to undo the
                padding on the upscaled output.
            model_path: Local checkpoint path, an ``https://`` URL (downloaded
                into ``ROOT_DIR/weights``), or a list of two checkpoint paths
                that are blended with ``dni_weight``.
            dni_weight: Pair of blend weights; required when ``model_path`` is
                a list.
            model: Network object exposing ``load_state_dict``/``eval``/``to``.
                It must be supplied even though the default is ``None``.
            pre_pad: Number of pixels of reflect padding added to the right
                and bottom of every input before inference.
            half: Run the model and inputs in half precision.
            device: Explicit device. When ``None``, CUDA is used if available
                (``cuda:{gpu_id}`` when ``gpu_id`` is given), otherwise CPU.
            gpu_id: Index of the CUDA device to use when ``device`` is ``None``.
        """
        self.scale = scale
        self.model_path = model_path
        self.dni_weight = dni_weight
        self.model = model
        self.pre_pad = pre_pad
        self.half = half
        self.device = device
        self.gpu_id = gpu_id

        self.mod_scale = None

        # Resolve the device only when the caller did not choose one.
        # `gpu_id` is compared against None so that GPU 0 is honoured
        # (a plain truthiness test would treat 0 as "not given").
        if self.device is None:
            if not torch.cuda.is_available():
                self.device = torch.device("cpu")
            elif self.gpu_id is not None:
                self.device = torch.device(f"cuda:{self.gpu_id}")
            else:
                self.device = torch.device("cuda")

        # NOTE(review): torch.load unpickles the checkpoint; only load files
        # from sources you trust.
        if isinstance(self.model_path, list):
            # Two checkpoints: blend them into a single set of weights.
            assert len(self.model_path) == len(
                self.dni_weight
            ), "model_path and dni_weight must have the same length"
            loadnet = self.dni(self.model_path[0], self.model_path[1], self.dni_weight)
        else:
            # Download the model first if the path is a URL.
            if self.model_path.startswith("https://"):
                self.model_path = load_file_from_url(
                    url=self.model_path,
                    model_dir=os.path.join(ROOT_DIR, "weights"),
                    progress=True,
                    file_name=None,
                )
            # Load from the resolved local path, not the original argument,
            # which may still be the URL.
            loadnet = torch.load(self.model_path, map_location=torch.device("cpu"))

        # Use params_ema if available, otherwise use params.
        keyname = "params_ema" if "params_ema" in loadnet else "params"

        # Load the weights, switch to evaluation mode and move to the device.
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)

        if self.half:
            self.model = self.model.half()

    def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"):
        """Blend two checkpoints parameter by parameter (network interpolation).

        Args:
            net_a: Path of the first checkpoint file.
            net_b: Path of the second checkpoint file.
            dni_weight: Two weights; the result for every parameter is
                ``dni_weight[0] * a + dni_weight[1] * b``.
            key: Entry of the checkpoint dict that holds the parameters.
            loc: Device string passed to ``torch.load`` as ``map_location``.

        Returns:
            The checkpoint loaded from ``net_a`` with ``[key]`` replaced by
            the blended parameters. Only ``key`` is blended; any other entry
            (for example ``params_ema``) is returned as loaded from ``net_a``.
        """
        ckpt_a = torch.load(net_a, map_location=torch.device(loc))
        ckpt_b = torch.load(net_b, map_location=torch.device(loc))

        weight_a, weight_b = dni_weight[0], dni_weight[1]
        params_a = ckpt_a[key]
        params_b = ckpt_b[key]

        # Existing keys are overwritten in place, so the container type and
        # key order of the first checkpoint are preserved.
        for name, value_a in params_a.items():
            params_a[name] = weight_a * value_a + weight_b * params_b[name]

        return ckpt_a

    def pre_process(self, img):
        """Turn an HWC array into a padded NCHW tensor stored in ``self.img``.

        Args:
            img: Array indexed as (height, width, channels).

        Side effects:
            Sets ``self.img`` and, for scale 1 or 2, ``self.mod_scale``,
            ``self.mod_pad_h`` and ``self.mod_pad_w``.
        """
        # HWC -> CHW, add the batch dimension and move to the target device.
        tensor = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = tensor.unsqueeze(0).to(self.device)

        if self.half:
            self.img = self.img.half()

        # Reflect-pad the right and bottom edges only; post_process crops the
        # same amount (times scale) from the same edges.
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")

        # Scale 2 and scale 1 need the spatial size to be a multiple of 2 and
        # 4 respectively; other scales leave mod_scale untouched.
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4

        if self.mod_scale is not None:
            _, _, h, w = self.img.size()

            # Amount needed to reach the next multiple of mod_scale
            # (0 when the size is already a multiple).
            self.mod_pad_h = -h % self.mod_scale
            self.mod_pad_w = -w % self.mod_scale

            self.img = F.pad(
                self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect"
            )

    def process(self):
        """Run the model on ``self.img`` and store the result in ``self.output``."""
        # Inference only: disable autograd so no graph is built or retained.
        with torch.no_grad():
            self.output = self.model(self.img)

    def post_process(self):
        """Crop the padding added by ``pre_process`` from ``self.output``.

        Returns:
            The cropped output tensor (also stored in ``self.output``).
        """
        # Undo the divisibility padding first (it was applied last).
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[
                :,
                :,
                0 : h - self.mod_pad_h * self.scale,
                0 : w - self.mod_pad_w * self.scale,
            ]

        # Then undo the border pre-padding.
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[
                :,
                :,
                0 : h - self.pre_pad * self.scale,
                0 : w - self.pre_pad * self.scale,
            ]

        return self.output

    def _upsample_to_bgr(self, img):
        """Run pre_process/process/post_process and return an HWC array in [0, 1] with channels reversed."""
        self.pre_process(img)
        self.process()
        result = self.post_process()
        result = result.data.squeeze().float().cpu().clamp_(0, 1).numpy()

        # Reverse the channel order (RGB -> BGR) and go back to HWC.
        return np.transpose(result[[2, 1, 0], :, :], (1, 2, 0))

    def enhance(self, img, upscale=None, alpha_upsampler="realesrgan"):
        """Upscale an image array and return it in its original channel layout.

        Args:
            img: Image array that is 2-D (gray), 3-channel, or 4-channel with
                the alpha in the last channel. Colour inputs are converted
                with ``cv2.COLOR_BGR2RGB`` before inference.
            upscale: Optional final scale relative to the input size. When it
                differs from ``self.scale`` the network output is resized
                with Lanczos interpolation.
            alpha_upsampler: ``"realesrgan"`` to upscale the alpha channel
                with the network; any other value uses linear resizing.

        Returns:
            Tuple ``(output, img_mode)`` where ``output`` is ``uint8`` (or
            ``uint16`` for inputs detected as 16-bit) and ``img_mode`` is
            ``"L"``, ``"RGB"`` or ``"RGBA"``.
        """
        h_input, w_input = img.shape[0:2]
        img = img.astype(np.float32)

        # Values above 256 are taken to mean a 16-bit image.
        if np.max(img) > 256:
            max_range = 65535
            print("\tInput is a 16-bit image")
        else:
            max_range = 255

        # Normalize the image to the range [0, 1].
        img = img / max_range

        # Identify the image mode and bring the colour data to 3-channel RGB.
        if len(img.shape) == 2:
            img_mode = "L"  # Gray image.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            img_mode = "RGBA"  # Image with alpha channel.
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # The network expects 3 channels, so replicate alpha if it is
            # going to be upscaled by the network as well.
            if alpha_upsampler == "realesrgan":
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = "RGB"  # 3-channel colour image.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Upscale the colour data with the network.
        output_img = self._upsample_to_bgr(img)

        # Convert back to a single channel if the input was grayscale.
        if img_mode == "L":
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # Upscale the alpha channel and re-attach it.
        if img_mode == "RGBA":
            if alpha_upsampler == "realesrgan":
                output_alpha = self._upsample_to_bgr(alpha)
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:
                # Plain linear resize of the alpha channel to the network scale.
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(
                    alpha,
                    (w * self.scale, h * self.scale),
                    interpolation=cv2.INTER_LINEAR,
                )

            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # Map [0, 1] back to the integer range of the input.
        if max_range == 65535:
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        # Resize to the requested final scale if it differs from the network's.
        if upscale is not None and upscale != float(self.scale):
            output = cv2.resize(
                output,
                (
                    int(w_input * upscale),
                    int(h_input * upscale),
                ),
                interpolation=cv2.INTER_LANCZOS4,
            )

        return output, img_mode
304
+
305
+
306
+ ###########################################################################################