cycool29 commited on
Commit
1882b96
·
1 Parent(s): 813cce8
__pycache__/models.cpython-310.pyc ADDED
Binary file (5.83 kB). View file
 
__pycache__/models.cpython-311.pyc ADDED
Binary file (15.1 kB). View file
 
main.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torchvision.transforms import transforms
6
+ from torch.utils.data import DataLoader, random_split, Dataset
7
+ from torchvision.datasets import ImageFolder
8
+ import matplotlib.pyplot as plt
9
+ from models import *
10
+ from scipy.ndimage import gaussian_filter1d
11
+ import numpy as np
12
+
13
+ # Constants
14
+ RANDOM_SEED = 123
15
+ BATCH_SIZE = 32
16
+ NUM_EPOCHS = 100
17
+ LEARNING_RATE = 0.0001
18
+ STEP_SIZE = 10
19
+ GAMMA = 0.5
20
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
+ NUM_PRINT = 100
22
+ NUM_CLASSES = 5
23
+
24
+ # Load and preprocess the data
25
+ data_dir = r"data/train/Task 1"
26
+
27
+ # Define transformation for preprocessing
28
+ preprocess = transforms.Compose(
29
+ [
30
+ transforms.Resize((64, 64)), # Resize images to 64x64
31
+ transforms.ToTensor(), # Convert to tensor
32
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize
33
+ ]
34
+ )
35
+
36
+ augmentation = transforms.Compose(
37
+ [
38
+ transforms.Resize((64, 64)), # Resize images to 64x64
39
+ transforms.RandomHorizontalFlip(p=0.5), # Random horizontal flip
40
+ transforms.RandomRotation(degrees=45), # Random rotation
41
+ transforms.RandomVerticalFlip(p=0.5), # Random vertical flip
42
+ transforms.RandomGrayscale(p=0.1), # Random grayscale
43
+ transforms.ColorJitter(
44
+ brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5
45
+ ), # Random color jitter
46
+ transforms.ToTensor(), # Convert to tensor
47
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize
48
+ ]
49
+ )
50
+
51
+ # Load the dataset using ImageFolder
52
+ original_dataset = ImageFolder(root=data_dir, transform=preprocess)
53
+ augmented_dataset = ImageFolder(root=data_dir, transform=augmentation)
54
+ dataset = original_dataset + augmented_dataset
55
+
56
+ print("Length of dataset: ", len(dataset))
57
+ print("Classes: ", original_dataset.classes)
58
+
59
+
60
+ # Custom dataset class
61
+ class CustomDataset(Dataset):
62
+ def __init__(self, dataset):
63
+ self.data = dataset
64
+
65
+ def __len__(self):
66
+ return len(self.data)
67
+
68
+ def __getitem__(self, idx):
69
+ img, label = self.data[idx]
70
+ return img, label
71
+
72
+
73
+ # Split the dataset into train and validation sets
74
+ train_size = int(0.8 * len(dataset))
75
+ val_size = len(dataset) - train_size
76
+ train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
77
+
78
+ # Create data loaders for the custom dataset
79
+ train_loader = DataLoader(
80
+ CustomDataset(train_dataset), batch_size=BATCH_SIZE, shuffle=True, num_workers=0
81
+ )
82
+ valid_loader = DataLoader(
83
+ CustomDataset(val_dataset), batch_size=BATCH_SIZE, num_workers=0
84
+ )
85
+
86
+ # Initialize model, criterion, optimizer, and scheduler
87
+ model = resnet18(pretrained=False, num_classes=NUM_CLASSES)
88
+ model = model.to(DEVICE)
89
+ criterion = nn.CrossEntropyLoss()
90
+ # Adam optimizer
91
+ optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
92
+ # ReduceLROnPlateau scheduler
93
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
94
+ optimizer, mode="min", factor=0.1, patience=10, verbose=True
95
+ )
96
+
97
+ # Lists to store training and validation loss history
98
+ TRAIN_LOSS_HIST = []
99
+ VAL_LOSS_HIST = []
100
+ AVG_TRAIN_LOSS_HIST = []
101
+ AVG_VAL_LOSS_HIST = []
102
+ TRAIN_ACC_HIST = []
103
+ VAL_ACC_HIST = []
104
+
105
+ # Training loop
106
+ for epoch in range(NUM_EPOCHS):
107
+ model.train(True) # Set model to training mode
108
+ running_loss = 0.0
109
+ total_train = 0
110
+ correct_train = 0
111
+
112
+ for i, (inputs, labels) in enumerate(train_loader, 0):
113
+ inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
114
+ optimizer.zero_grad()
115
+ outputs = model(inputs)
116
+ loss = criterion(outputs, labels)
117
+ loss.backward()
118
+ optimizer.step()
119
+ running_loss += loss.item()
120
+
121
+ if (i + 1) % NUM_PRINT == 0:
122
+ print(
123
+ "[Epoch %d, Batch %d] Loss: %.6f"
124
+ % (epoch + 1, i + 1, running_loss / NUM_PRINT)
125
+ )
126
+ running_loss = 0.0
127
+
128
+ _, predicted = torch.max(outputs, 1)
129
+ total_train += labels.size(0)
130
+ correct_train += (predicted == labels).sum().item()
131
+
132
+ TRAIN_LOSS_HIST.append(loss.item())
133
+
134
+ # Calculate the average training loss for the epoch
135
+ avg_train_loss = running_loss / len(train_loader)
136
+ AVG_TRAIN_LOSS_HIST.append(avg_train_loss)
137
+
138
+ # Print average training loss for the epoch
139
+ print("[Epoch %d] Average Training Loss: %.6f" % (epoch + 1, avg_train_loss))
140
+
141
+ # Learning rate scheduling
142
+ lr_1 = optimizer.param_groups[0]["lr"]
143
+ print("Learning Rate: {:.15f}".format(lr_1))
144
+ scheduler.step(avg_val_loss)
145
+
146
+ # Validation loop
147
+ model.eval() # Set model to evaluation mode
148
+ val_loss = 0.0
149
+ correct_val = 0
150
+ total_val = 0
151
+
152
+ with torch.no_grad():
153
+ for inputs, labels in valid_loader:
154
+ inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
155
+ outputs = model(inputs)
156
+ loss = criterion(outputs, labels)
157
+ val_loss += loss.item()
158
+ # Calculate accuracy
159
+ _, predicted = torch.max(outputs, 1)
160
+ total_val += labels.size(0)
161
+ correct_val += (predicted == labels).sum().item()
162
+
163
+ VAL_LOSS_HIST.append(loss.item())
164
+
165
+ # Calculate the average validation loss for the epoch
166
+ avg_val_loss = val_loss / len(valid_loader)
167
+ AVG_VAL_LOSS_HIST.append(loss.item())
168
+ print("Average Validation Loss: %.6f" % (avg_val_loss))
169
+
170
+ # Calculate the accuracy of validation set
171
+ val_accuracy = correct_val / total_val
172
+ VAL_ACC_HIST.append(val_accuracy)
173
+ print("Validation Accuracy: %.6f" % (val_accuracy))
174
+
175
+ # End of training loop
176
+
177
+ # Save the model
178
+ model_save_path = "model.pth"
179
+ torch.save(model.state_dict(), model_save_path)
180
+ print("Model saved at", model_save_path)
181
+
182
+ print("Generating loss plot...")
183
+ # Make the plot smoother by interpolating the data
184
+ # https://stackoverflow.com/questions/5283649/plot-smooth-line-with-pyplot
185
+ # train_loss_line = gaussian_filter1d(TRAIN_LOSS_HIST, sigma=10)
186
+ # val_loss_line = gaussian_filter1d(VAL_LOSS_HIST, sigma=10)
187
+ # plt.plot(range(1, NUM_EPOCHS + 1), train_loss_line, label='Train Loss')
188
+ # plt.plot(range(1, NUM_EPOCHS + 1), val_loss_line, label='Validation Loss')
189
+ avg_train_loss_line = gaussian_filter1d(AVG_TRAIN_LOSS_HIST, sigma=2)
190
+ avg_val_loss_line = gaussian_filter1d(AVG_VAL_LOSS_HIST, sigma=2)
191
+ train_loss_line = gaussian_filter1d(TRAIN_LOSS_HIST, sigma=2)
192
+ val_loss_line = gaussian_filter1d(VAL_LOSS_HIST, sigma=2)
193
+ train_acc_line = gaussian_filter1d(TRAIN_ACC_HIST, sigma=2)
194
+ val_acc_line = gaussian_filter1d(VAL_ACC_HIST, sigma=2)
195
+ plt.plot(range(1, NUM_EPOCHS + 1), train_loss_line, label="Train Loss")
196
+ plt.plot(range(1, NUM_EPOCHS + 1), val_loss_line, label="Validation Loss")
197
+ plt.xlabel("Epochs")
198
+ plt.ylabel("Loss")
199
+ plt.legend()
200
+ plt.title("Train Loss and Validation Loss")
201
+ plt.savefig("loss_plot.png")
202
+ plt.clf()
203
+ plt.plot(range(1, NUM_EPOCHS + 1), avg_train_loss_line, label="Average Train Loss")
204
+ plt.plot(range(1, NUM_EPOCHS + 1), avg_val_loss_line, label="Average Validation Loss")
205
+ plt.xlabel("Epochs")
206
+ plt.ylabel("Loss")
207
+ plt.legend()
208
+ plt.title("Average Train Loss and Average Validation Loss")
209
+ plt.savefig("avg_loss_plot.png")
210
+ plt.clf()
211
+ plt.plot(range(1, NUM_EPOCHS + 1), train_acc_line, label="Train Accuracy")
212
+ plt.plot(range(1, NUM_EPOCHS + 1), val_acc_line, label="Validation Accuracy")
213
+ plt.xlabel("Epochs")
214
+ plt.ylabel("Accuracy")
215
+ plt.legend()
216
+ plt.title("Train Accuracy and Validation Accuracy")
217
+ plt.savefig("accuracy_plot.png")
models.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #######################################################
2
+ # This file stores all the models used in the project.#
3
+ #######################################################
4
+
5
+ import torch
6
+ from torchvision.models import resnet50
7
+ from torchvision.models import resnet18
8
+
9
+ # resnet50
10
+ class Bottleneck(torch.nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
14
+ super(Bottleneck, self).__init__()
15
+ # hmm,ex 1x1 convolution to reduce channels (intermediate channels)
16
+ self.conv1 = torch.nn.Conv2d(
17
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
18
+ )
19
+ self.batch_norm1 = torch.nn.BatchNorm2d(out_channels)
20
+ # 3x3 convolution with specified stride
21
+ self.conv2 = torch.nn.Conv2d(
22
+ out_channels, out_channels, kernel_size=3, stride=stride, padding=1
23
+ )
24
+ self.batch_norm2 = torch.nn.BatchNorm2d(out_channels)
25
+ # and then leh,1x1 expand back
26
+ self.conv3 = torch.nn.Conv2d(
27
+ out_channels,
28
+ out_channels * self.expansion,
29
+ kernel_size=1,
30
+ stride=1,
31
+ padding=0,
32
+ )
33
+ self.batch_norm3 = torch.nn.BatchNorm2d(out_channels * self.expansion)
34
+
35
+ self.i_downsample = i_downsample
36
+ self.stride = stride
37
+ self.relu = torch.nn.ReLU()
38
+
39
+ ##forward the input x through the network,haiyaa
40
+ def forward(self, x):
41
+ identity = x.clone()
42
+ x = self.relu(self.batch_norm1(self.conv1(x)))
43
+
44
+ x = self.relu(self.batch_norm2(self.conv2(x)))
45
+
46
+ x = self.conv3(x)
47
+ x = self.batch_norm3(x)
48
+
49
+ # downsample if needed
50
+ if self.i_downsample is not None:
51
+ identity = self.i_downsample(identity)
52
+ # add identity
53
+ x += identity
54
+ x = self.relu(x)
55
+
56
+ return x
57
+
58
+
59
+ # we no use this first,but we can just copy this whole class and apply to resnet16 and etc
60
+ class Block(torch.nn.Module):
61
+ expansion = 1
62
+
63
+ def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
64
+ super(Block, self).__init__()
65
+
66
+ self.conv1 = torch.nn.Conv2d(
67
+ in_channels,
68
+ out_channels,
69
+ kernel_size=3,
70
+ padding=1,
71
+ stride=stride,
72
+ bias=False,
73
+ )
74
+ self.batch_norm1 = torch.nn.BatchNorm2d(out_channels)
75
+ self.conv2 = torch.nn.Conv2d(
76
+ out_channels,
77
+ out_channels,
78
+ kernel_size=3,
79
+ padding=1,
80
+ stride=stride,
81
+ bias=False,
82
+ )
83
+ self.batch_norm2 = torch.nn.BatchNorm2d(out_channels)
84
+
85
+ self.i_downsample = i_downsample
86
+ self.stride = stride
87
+ self.relu = torch.nn.ReLU()
88
+
89
+ def forward(self, x):
90
+ identity = x.clone()
91
+
92
+ x = self.relu(self.batch_norm2(self.conv1(x)))
93
+ x = self.batch_norm2(self.conv2(x))
94
+
95
+ if self.i_downsample is not None:
96
+ identity = self.i_downsample(identity)
97
+ print(x.shape)
98
+ print(identity.shape)
99
+ x += identity
100
+ x = self.relu(x)
101
+ return x
102
+
103
+
104
+ class ResNet(torch.nn.Module):
105
+ def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
106
+ super(ResNet, self).__init__()
107
+ self.in_channels = 64
108
+ # intial conv layaer
109
+ self.conv1 = torch.nn.Conv2d(
110
+ num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
111
+ )
112
+ self.batch_norm1 = torch.nn.BatchNorm2d(64)
113
+ self.relu = torch.nn.ReLU()
114
+ self.max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
115
+ # residual block(layers),each block got three three layer,total 4 blocks
116
+ self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64)
117
+ self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)
118
+ self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)
119
+ self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)
120
+
121
+ self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
122
+ self.fc = torch.nn.Linear(512 * ResBlock.expansion, num_classes)
123
+
124
+ def forward(self, x):
125
+ x = self.relu(self.batch_norm1(self.conv1(x)))
126
+ x = self.max_pool(x)
127
+
128
+ x = self.layer1(x)
129
+ x = self.layer2(x)
130
+ x = self.layer3(x)
131
+ x = self.layer4(x)
132
+
133
+ x = self.avgpool(x)
134
+ x = x.reshape(x.shape[0], -1)
135
+ x = self.fc(x)
136
+
137
+ return x
138
+
139
+ def _make_layer(self, ResBlock, blocks, planes, stride=1):
140
+ # plane is the number of output channel
141
+ ii_downsample = None
142
+ layers = []
143
+
144
+ if stride != 1 or self.in_channels != planes * ResBlock.expansion:
145
+ ii_downsample = torch.nn.Sequential(
146
+ torch.nn.Conv2d(
147
+ self.in_channels,
148
+ planes * ResBlock.expansion,
149
+ kernel_size=1,
150
+ stride=stride,
151
+ ),
152
+ torch.nn.BatchNorm2d(planes * ResBlock.expansion),
153
+ )
154
+
155
+ layers.append(
156
+ ResBlock(
157
+ self.in_channels, planes, i_downsample=ii_downsample, stride=stride
158
+ )
159
+ )
160
+ self.in_channels = planes * ResBlock.expansion
161
+
162
+ for i in range(blocks - 1):
163
+ layers.append(ResBlock(self.in_channels, planes))
164
+
165
+ return torch.nn.Sequential(*layers)
166
+
167
+
168
+ ##list here leh is the number of residual block in each layer
169
+ def ResNet50(num_classes, channels=3):
170
+ return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, channels)
171
+
172
+
173
+ # VGG16 model
174
+ class VGG16(torch.nn.Module):
175
+ def __init__(self, num_classes):
176
+ super().__init__()
177
+
178
+ self.block_1 = torch.nn.Sequential(
179
+ torch.nn.Conv2d(
180
+ in_channels=3,
181
+ out_channels=64,
182
+ kernel_size=(3, 3),
183
+ stride=(1, 1),
184
+ padding=1,
185
+ ),
186
+ torch.nn.ReLU(),
187
+ torch.nn.Conv2d(
188
+ in_channels=64,
189
+ out_channels=64,
190
+ kernel_size=(3, 3),
191
+ stride=(1, 1),
192
+ padding=1,
193
+ ),
194
+ torch.nn.ReLU(),
195
+ torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
196
+ )
197
+
198
+ self.block_2 = torch.nn.Sequential(
199
+ torch.nn.Conv2d(
200
+ in_channels=64,
201
+ out_channels=128,
202
+ kernel_size=(3, 3),
203
+ stride=(1, 1),
204
+ padding=1,
205
+ ),
206
+ torch.nn.ReLU(),
207
+ torch.nn.Conv2d(
208
+ in_channels=128,
209
+ out_channels=128,
210
+ kernel_size=(3, 3),
211
+ stride=(1, 1),
212
+ padding=1,
213
+ ),
214
+ torch.nn.ReLU(),
215
+ torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
216
+ )
217
+
218
+ self.block_3 = torch.nn.Sequential(
219
+ torch.nn.Conv2d(
220
+ in_channels=128,
221
+ out_channels=256,
222
+ kernel_size=(3, 3),
223
+ stride=(1, 1),
224
+ padding=1,
225
+ ),
226
+ torch.nn.ReLU(),
227
+ torch.nn.Conv2d(
228
+ in_channels=256,
229
+ out_channels=256,
230
+ kernel_size=(3, 3),
231
+ stride=(1, 1),
232
+ padding=1,
233
+ ),
234
+ torch.nn.ReLU(),
235
+ torch.nn.Conv2d(
236
+ in_channels=256,
237
+ out_channels=256,
238
+ kernel_size=(3, 3),
239
+ stride=(1, 1),
240
+ padding=1,
241
+ ),
242
+ torch.nn.ReLU(),
243
+ torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
244
+ )
245
+
246
+ self.block_4 = torch.nn.Sequential(
247
+ torch.nn.Conv2d(
248
+ in_channels=256,
249
+ out_channels=512,
250
+ kernel_size=(3, 3),
251
+ stride=(1, 1),
252
+ padding=1,
253
+ ),
254
+ torch.nn.ReLU(),
255
+ torch.nn.Conv2d(
256
+ in_channels=512,
257
+ out_channels=512,
258
+ kernel_size=(3, 3),
259
+ stride=(1, 1),
260
+ padding=1,
261
+ ),
262
+ torch.nn.ReLU(),
263
+ torch.nn.Conv2d(
264
+ in_channels=512,
265
+ out_channels=512,
266
+ kernel_size=(3, 3),
267
+ stride=(1, 1),
268
+ padding=1,
269
+ ),
270
+ torch.nn.ReLU(),
271
+ torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
272
+ )
273
+
274
+ self.block_5 = torch.nn.Sequential(
275
+ torch.nn.Conv2d(
276
+ in_channels=512,
277
+ out_channels=512,
278
+ kernel_size=(3, 3),
279
+ stride=(1, 1),
280
+ padding=1,
281
+ ),
282
+ torch.nn.ReLU(),
283
+ torch.nn.Conv2d(
284
+ in_channels=512,
285
+ out_channels=512,
286
+ kernel_size=(3, 3),
287
+ stride=(1, 1),
288
+ padding=1,
289
+ ),
290
+ torch.nn.ReLU(),
291
+ torch.nn.Conv2d(
292
+ in_channels=512,
293
+ out_channels=512,
294
+ kernel_size=(3, 3),
295
+ stride=(1, 1),
296
+ padding=1,
297
+ ),
298
+ torch.nn.ReLU(),
299
+ torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
300
+ )
301
+
302
+ height, width = 3, 3
303
+ self.classifier = torch.nn.Sequential(
304
+ torch.nn.Linear(512 * height * width, 4096),
305
+ torch.nn.ReLU(True),
306
+ torch.nn.Dropout(p=0.5),
307
+ torch.nn.Linear(4096, 4096),
308
+ torch.nn.ReLU(True),
309
+ torch.nn.Dropout(p=0.5),
310
+ torch.nn.Linear(4096, num_classes),
311
+ )
312
+
313
+ for m in self.modules():
314
+ if isinstance(m, torch.torch.nn.Conv2d) or isinstance(
315
+ m, torch.torch.nn.Linear
316
+ ):
317
+ torch.nn.init.kaiming_uniform_(
318
+ m.weight, mode="fan_in", nonlinearity="relu"
319
+ )
320
+ if m.bias is not None:
321
+ m.bias.detach().zero_()
322
+
323
+ self.avgpool = torch.nn.AdaptiveAvgPool2d((height, width))
324
+
325
+ def forward(self, x):
326
+ x = self.block_1(x)
327
+ x = self.block_2(x)
328
+ x = self.block_3(x)
329
+ x = self.block_4(x)
330
+ x = self.block_5(x)
331
+ x = self.avgpool(x)
332
+ x = x.view(x.size(0), -1) # flatten
333
+
334
+ logits = self.classifier(x)
335
+ # probas = F.softmax(logits, dim=1)
336
+
337
+ return logits
338
+
339
+
340
+ # ResNet18 model
predict.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ from models import * # Make sure you import your model correctly from the 'models' module
7
+ from torchmetrics import ConfusionMatrix
8
+ import matplotlib.pyplot as plt
9
+ import pathlib
10
+
11
+ # Define the path to your model checkpoint
12
+ model_checkpoint_path = "model.pth"
13
+
14
+ # Define the path to the image you want to classify
15
+ image_path = "data/test/Task 1/" # Use forward slashes for file paths
16
+
17
+ # Define images variable to recursively list all the data file in the image_path
18
+ images = list(pathlib.Path(image_path).rglob("*.png"))
19
+ classes = os.listdir(image_path)
20
+ print(images)
21
+
22
+ true_classs = []
23
+ predicted_labels = []
24
+
25
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26
+
27
+ NUM_CLASSES = 5 # Update with the correct number of classes
28
+
29
+ # Load your model (change this according to your model definition)
30
+ model = resnet18(pretrained=False, num_classes=NUM_CLASSES)
31
+ model.load_state_dict(torch.load(model_checkpoint_path, map_location=DEVICE)) # Load the model on the same device
32
+ model.eval()
33
+ model = model.to(DEVICE)
34
+
35
+ # Define transformation for preprocessing the input image
36
+ preprocess = transforms.Compose(
37
+ [
38
+ transforms.Resize((64, 64)), # Resize the image to match training input size
39
+ transforms.Grayscale(num_output_channels=3), # Convert the image to grayscale
40
+ transforms.ToTensor(),
41
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize the image
42
+ ]
43
+ )
44
+
45
+ def predict_image(image_path, model, transform):
46
+ model.eval()
47
+ correct_predictions = 0
48
+ total_predictions = len(images)
49
+
50
+ with torch.no_grad():
51
+ for i in images:
52
+ print('---------------------------')
53
+ # Check the true label of the image by checking the sequence of the folder in Task 1
54
+ true_class = classes.index(i.parts[-2])
55
+ print("Image path:", i)
56
+ print("True class:", true_class)
57
+ image = Image.open(i)
58
+ image = transform(image).unsqueeze(0)
59
+ image = image.to(DEVICE)
60
+ output = model(image)
61
+
62
+ # softmax algorithm
63
+ probabilities = torch.softmax(output, dim=1)[0] * 100
64
+ predicted_class = torch.argmax(output, dim=1).item()
65
+
66
+ # Append true and predicted labels to their respective lists
67
+ true_classs.append(true_class)
68
+ predicted_labels.append(predicted_class)
69
+
70
+ # Check if the prediction is correct
71
+ if predicted_class == true_class:
72
+ correct_predictions += 1
73
+
74
+ # Report the prediction
75
+ print("Predicted class:", predicted_class)
76
+ print("Probability:", probabilities[predicted_class].item())
77
+ print("Predicted label:", classes[predicted_class])
78
+ print("Correct predictions:", correct_predictions)
79
+ print("Correct?", "Yes" if predicted_class == true_class else "No")
80
+ print("---------------------------")
81
+
82
+
83
+
84
+ # Calculate accuracy
85
+ accuracy = correct_predictions / total_predictions
86
+ print("Accuracy:", accuracy)
87
+
88
+ # Call the predict_image function
89
+ predict_image(image_path, model, preprocess)
90
+
91
+ # Convert the lists to tensors
92
+ predicted_labels_tensor = torch.tensor(predicted_labels)
93
+ true_classs_tensor = torch.tensor(true_classs)
94
+
95
+ # Create confusion matrix
96
+ conf_matrix = ConfusionMatrix(num_classes=NUM_CLASSES, task='multiclass')
97
+ conf_matrix.update(predicted_labels_tensor, true_classs_tensor)
98
+
99
+ # Plot confusion matrix
100
+ conf_matrix.plot()
101
+ plt.show()