import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from skimage.color import lab2rgb
from torchvision import transforms


class ColorizationNet(nn.Module):
    def __init__(self, input_size=128):
        super(ColorizationNet, self).__init__()
        MIDLEVEL_FEATURE_SIZE = 128
        # Build a ResNet-18 trunk; num_classes=365 matches the Places365-style
        # head of the checkpoint loaded in run_example (the head itself is
        # discarded below, so only the early layers matter).
        resnet = models.resnet18(num_classes=365)

        # Adapt the first conv layer to single-channel (grayscale) input by
        # summing its RGB filter weights across the input-channel dimension.
        resnet.conv1.weight = nn.Parameter(resnet.conv1.weight.sum(dim=1).unsqueeze(1))

        # Keep everything through layer2 (conv1, bn1, relu, maxpool, layer1,
        # layer2): a midlevel extractor that outputs MIDLEVEL_FEATURE_SIZE
        # channels at 1/8 of the input resolution.
        self.midlevel_resnet = nn.Sequential(*list(resnet.children())[0:6])
        # Upsampling head: three 2x upsampling stages undo the trunk's 8x
        # downsampling, and the final conv predicts the two ab chrominance
        # channels.
        self.upsample = nn.Sequential(
            nn.Conv2d(MIDLEVEL_FEATURE_SIZE, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2)
        )
    def forward(self, input):
        # (N, 1, H, W) grayscale input -> (N, 128, H/8, W/8) midlevel features
        midlevel_features = self.midlevel_resnet(input)
        # -> (N, 2, H, W) predicted ab channels (exactly matching the input
        # size when H and W are divisible by 8)
        output = self.upsample(midlevel_features)
        return output
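# Quick shape check (illustrative, not part of the original script): for a
# 224x224 grayscale input the trunk yields 128x28x28 features and the head
# returns the two ab channels at full resolution.
#   net = ColorizationNet()
#   net(torch.randn(1, 1, 224, 224)).shape   # torch.Size([1, 2, 224, 224])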
def to_rgb(grayscale_input, ab_input, save_path, save_name):
    """Combine L and ab channels into an RGB image and save both versions."""
    C, H, W = grayscale_input.shape

    # The predicted ab map can be slightly smaller than the input when the
    # input dimensions are not divisible by 8, so resize it to match.
    ab_input_resized = torch.nn.functional.interpolate(
        ab_input.unsqueeze(0), size=(H, W), mode='bilinear', align_corners=False
    ).squeeze(0)

    # Stack L and ab into a single Lab image, channels-last for skimage.
    color_image = torch.cat((grayscale_input, ab_input_resized), 0).numpy()
    color_image = color_image.transpose((1, 2, 0))

    # Undo the [0, 1] normalization: L* spans [0, 100] and a*/b* roughly
    # [-128, 127] in the Lab color space.
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
    color_image = lab2rgb(color_image.astype(np.float64))

    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None:
        plt.imsave(arr=grayscale_input, fname=os.path.join(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname=os.path.join(save_path['colorized'], save_name))
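# Sketch of the inverse preparation step (an assumption about the training
# pipeline, not part of the original code): converting an RGB image to the
# normalized L channel the model consumes. The helper name is hypothetical.
def rgb_to_normalized_l(image):
    """Return the CIE L* channel of an RGB PIL image, scaled to [0, 1]."""
    from skimage.color import rgb2lab  # counterpart of lab2rgb above
    lab = rgb2lab(np.asarray(image.convert('RGB'), dtype=np.float64) / 255.0)
    # L* is in [0, 100]; divide by 100 to mirror the *100 rescaling in to_rgb.
    return torch.from_numpy(lab[:, :, 0] / 100.0).float().unsqueeze(0)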
def colorize_single_image(image_path, model, criterion, save_dir, epoch, use_gpu=True):
    """Colorize one grayscale image and save the gray and color outputs."""
    # criterion is unused at inference time; it is kept in the signature for
    # compatibility with the training loop's call convention.
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # PIL's "L" mode (Rec. 601 luma) serves as an approximation of the Lab
    # L* channel; ToTensor scales it to [0, 1] as the model expects.
    image = Image.open(image_path).convert("L")
    input_gray = transform(image).unsqueeze(0)

    if use_gpu and torch.cuda.is_available():
        input_gray = input_gray.cuda()
        model = model.cuda()

    with torch.no_grad():
        output_ab = model(input_gray)

    os.makedirs(save_dir, exist_ok=True)
    save_paths = {
        'grayscale': os.path.join(save_dir, 'gray'),
        'colorized': os.path.join(save_dir, 'color')
    }
    os.makedirs(save_paths['grayscale'], exist_ok=True)
    os.makedirs(save_paths['colorized'], exist_ok=True)

    save_name = f'colorized-epoch-{epoch}.jpg'
    to_rgb(input_gray[0].cpu(), ab_input=output_ab[0].detach().cpu(), save_path=save_paths, save_name=save_name)

    print(f'Colorized image saved in {save_paths["colorized"]}')
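# Illustrative batch usage (a hypothetical 'inputs' folder assumed to
# contain only image files):
#   for name in sorted(os.listdir('inputs')):
#       colorize_single_image(os.path.join('inputs', name), model,
#                             nn.MSELoss(), save_dir='results', epoch=0)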
def run_example(image_path, save_dir):
    use_gpu = torch.cuda.is_available()

    model = ColorizationNet()
    # Path to a trained checkpoint (a state_dict saved with torch.save);
    # map_location='cpu' makes loading work on CUDA-less machines too.
    model_path = 'colorization_md1.pth'
    pretrained = torch.load(model_path, map_location='cpu')
    model.load_state_dict(pretrained)
    model.eval()

    criterion = nn.MSELoss()

    with torch.no_grad():
        colorize_single_image(image_path, model, criterion, save_dir, epoch=0, use_gpu=use_gpu)


if __name__ == "__main__":
    image_path = 'example_image.jpg'
    save_dir = 'results'
    run_example(image_path, save_dir)