jessicaNono committed
Commit c69e4df · 1 Parent(s): 4f57a75

library to use the model
colorize.py
ADDED
@@ -0,0 +1,134 @@
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torchvision.models as models
from skimage.color import lab2rgb
import os


class ColorizationNet(nn.Module):
    def __init__(self, input_size=128):
        super(ColorizationNet, self).__init__()
        MIDLEVEL_FEATURE_SIZE = 128

        ## First half: ResNet
        resnet = models.resnet18(num_classes=365)
        # Change first conv layer to accept single-channel (grayscale) input
        resnet.conv1.weight = nn.Parameter(resnet.conv1.weight.sum(dim=1).unsqueeze(1))
        # Extract midlevel features from ResNet-gray
        self.midlevel_resnet = nn.Sequential(*list(resnet.children())[0:6])

        ## Second half: Upsampling
        self.upsample = nn.Sequential(
            nn.Conv2d(MIDLEVEL_FEATURE_SIZE, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2)
        )

    def forward(self, input):
        # Pass input through ResNet-gray to extract features
        midlevel_features = self.midlevel_resnet(input)

        # Upsample to get colors
        output = self.upsample(midlevel_features)
        return output


def to_rgb(grayscale_input, ab_input, save_path, save_name):
    # Unpack the shape: channels, height, width
    C, H, W = grayscale_input.shape

    # Ensure ab_input has the same spatial dimensions as grayscale_input
    ab_input_resized = torch.nn.functional.interpolate(
        ab_input.unsqueeze(0), size=(H, W), mode='bilinear', align_corners=False
    ).squeeze(0)

    # Combine grayscale (L) and ab channels into a single Lab image
    color_image = torch.cat((grayscale_input, ab_input_resized), 0).numpy()

    color_image = color_image.transpose((1, 2, 0))  # HWC layout for matplotlib
    # Rescale from [0, 1] back to Lab ranges: L in [0, 100], ab in [-128, 127]
    color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
    color_image = lab2rgb(color_image.astype(np.float64))
    grayscale_input = grayscale_input.squeeze().numpy()
    if save_path is not None and save_name is not None:
        plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
        plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))


def colorize_single_image(image_path, model, criterion, save_dir, epoch, use_gpu=True):
    model.eval()

    # Load and preprocess the image
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    image = Image.open(image_path).convert("L")  # Convert to grayscale
    input_gray = transform(image).unsqueeze(0)  # Add batch dimension

    # Use GPU if available
    if use_gpu and torch.cuda.is_available():
        input_gray = input_gray.cuda()
        model = model.cuda()

    # Run model
    with torch.no_grad():
        output_ab = model(input_gray)

    # Create save directories for grayscale and colorized images
    os.makedirs(save_dir, exist_ok=True)
    save_paths = {
        'grayscale': os.path.join(save_dir, 'gray/'),
        'colorized': os.path.join(save_dir, 'color/')
    }
    os.makedirs(save_paths['grayscale'], exist_ok=True)
    os.makedirs(save_paths['colorized'], exist_ok=True)

    # Save the colorized image
    save_name = f'colorized-epoch-{epoch}.jpg'
    to_rgb(input_gray[0].cpu(), ab_input=output_ab[0].detach().cpu(), save_path=save_paths, save_name=save_name)

    print(f'Colorized image saved in {save_paths["colorized"]}')


# Load model and run colorization (example usage)
def run_example(image_path, save_dir):
    use_gpu = torch.cuda.is_available()

    model = ColorizationNet()
    model_path = 'colorization_md1.pth'  # Update with the path to your model
    pretrained = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(pretrained)
    model.eval()

    criterion = nn.MSELoss()  # Not used at inference time; kept for the colorize_single_image signature

    with torch.no_grad():
        colorize_single_image(image_path, model, criterion, save_dir, epoch=0, use_gpu=use_gpu)


if __name__ == "__main__":
    # Example of how to use this script as a library
    image_path = 'example_image.jpg'  # Replace with your image path
    save_dir = 'results'  # Replace with your desired save path
    run_example(image_path, save_dir)
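
Since the file is meant to be used as a library, here is a minimal sketch of importing it from another script and fetching the checkpoint with huggingface_hub's hf_hub_download helper. The repo_id and the local image path are hypothetical placeholders; only the checkpoint filename colorization_md1.pth comes from the file above.

# Minimal usage sketch. Assumes colorize.py is importable and that
# 'colorization_md1.pth' is hosted in this model repo; the repo_id below
# is a hypothetical placeholder -- substitute the actual repo id.
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn

from colorize import ColorizationNet, colorize_single_image

weights_path = hf_hub_download(
    repo_id="jessicaNono/colorization",  # hypothetical repo id
    filename="colorization_md1.pth",
)

model = ColorizationNet()
state_dict = torch.load(weights_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

colorize_single_image(
    "my_photo.jpg",           # path to a local grayscale or color photo (placeholder)
    model,
    criterion=nn.MSELoss(),   # unused during inference, required by the signature
    save_dir="results",
    epoch=0,
    use_gpu=torch.cuda.is_available(),
)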