Update

- __pycache__/models.cpython-310.pyc  +0 -0
- __pycache__/models.cpython-311.pyc  +0 -0
- main.py     +217 -0
- models.py   +340 -0
- predict.py  +101 -0
__pycache__/models.cpython-310.pyc
ADDED
Binary file (5.83 kB)

__pycache__/models.cpython-311.pyc
ADDED
Binary file (15.1 kB)
main.py
ADDED
@@ -0,0 +1,217 @@
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
from models import *
from scipy.ndimage import gaussian_filter1d
import numpy as np

# Constants
RANDOM_SEED = 123
BATCH_SIZE = 32
NUM_EPOCHS = 100
LEARNING_RATE = 0.0001
STEP_SIZE = 10
GAMMA = 0.5
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_PRINT = 100
NUM_CLASSES = 5

# Load and preprocess the data
data_dir = r"data/train/Task 1"

# Define transformation for preprocessing
preprocess = transforms.Compose(
    [
        transforms.Resize((64, 64)),  # Resize images to 64x64
        transforms.ToTensor(),  # Convert to tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize
    ]
)

# Define transformation for data augmentation
augmentation = transforms.Compose(
    [
        transforms.Resize((64, 64)),  # Resize images to 64x64
        transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip
        transforms.RandomRotation(degrees=45),  # Random rotation
        transforms.RandomVerticalFlip(p=0.5),  # Random vertical flip
        transforms.RandomGrayscale(p=0.1),  # Random grayscale
        transforms.ColorJitter(
            brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5
        ),  # Random color jitter
        transforms.ToTensor(),  # Convert to tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize
    ]
)

# Load the dataset using ImageFolder; concatenating the plain and augmented
# views doubles the number of training samples
original_dataset = ImageFolder(root=data_dir, transform=preprocess)
augmented_dataset = ImageFolder(root=data_dir, transform=augmentation)
dataset = original_dataset + augmented_dataset

print("Length of dataset: ", len(dataset))
print("Classes: ", original_dataset.classes)


# Custom dataset class that wraps an existing dataset
class CustomDataset(Dataset):
    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]
        return img, label


# Split the dataset into train and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create data loaders for the custom dataset
train_loader = DataLoader(
    CustomDataset(train_dataset), batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)
valid_loader = DataLoader(
    CustomDataset(val_dataset), batch_size=BATCH_SIZE, num_workers=0
)

# Initialize model, criterion, optimizer, and scheduler
model = resnet18(pretrained=False, num_classes=NUM_CLASSES)
model = model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
# Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# ReduceLROnPlateau scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=10, verbose=True
)

# Lists to store training and validation loss/accuracy history
TRAIN_LOSS_HIST = []
VAL_LOSS_HIST = []
AVG_TRAIN_LOSS_HIST = []
AVG_VAL_LOSS_HIST = []
TRAIN_ACC_HIST = []
VAL_ACC_HIST = []

# Training loop
for epoch in range(NUM_EPOCHS):
    model.train(True)  # Set model to training mode
    running_loss = 0.0  # loss accumulated since the last progress print
    epoch_loss = 0.0  # loss accumulated over the whole epoch
    total_train = 0
    correct_train = 0

    for i, (inputs, labels) in enumerate(train_loader, 0):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        epoch_loss += loss.item()

        if (i + 1) % NUM_PRINT == 0:
            print(
                "[Epoch %d, Batch %d] Loss: %.6f"
                % (epoch + 1, i + 1, running_loss / NUM_PRINT)
            )
            running_loss = 0.0

        _, predicted = torch.max(outputs, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    TRAIN_LOSS_HIST.append(loss.item())

    # Calculate the average training loss and accuracy for the epoch
    avg_train_loss = epoch_loss / len(train_loader)
    AVG_TRAIN_LOSS_HIST.append(avg_train_loss)
    train_accuracy = correct_train / total_train
    TRAIN_ACC_HIST.append(train_accuracy)

    # Print average training loss and current learning rate for the epoch
    print("[Epoch %d] Average Training Loss: %.6f" % (epoch + 1, avg_train_loss))
    lr_1 = optimizer.param_groups[0]["lr"]
    print("Learning Rate: {:.15f}".format(lr_1))

    # Validation loop
    model.eval()  # Set model to evaluation mode
    val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    VAL_LOSS_HIST.append(loss.item())

    # Calculate the average validation loss for the epoch
    avg_val_loss = val_loss / len(valid_loader)
    AVG_VAL_LOSS_HIST.append(avg_val_loss)
    print("Average Validation Loss: %.6f" % (avg_val_loss))

    # Learning rate scheduling steps on the validation loss
    scheduler.step(avg_val_loss)

    # Calculate the accuracy of validation set
    val_accuracy = correct_val / total_val
    VAL_ACC_HIST.append(val_accuracy)
    print("Validation Accuracy: %.6f" % (val_accuracy))

# End of training loop

# Save the model
model_save_path = "model.pth"
torch.save(model.state_dict(), model_save_path)
print("Model saved at", model_save_path)

print("Generating loss plot...")
# Make the plots smoother by filtering the data
# https://stackoverflow.com/questions/5283649/plot-smooth-line-with-pyplot
# train_loss_line = gaussian_filter1d(TRAIN_LOSS_HIST, sigma=10)
# val_loss_line = gaussian_filter1d(VAL_LOSS_HIST, sigma=10)
# plt.plot(range(1, NUM_EPOCHS + 1), train_loss_line, label='Train Loss')
# plt.plot(range(1, NUM_EPOCHS + 1), val_loss_line, label='Validation Loss')
avg_train_loss_line = gaussian_filter1d(AVG_TRAIN_LOSS_HIST, sigma=2)
avg_val_loss_line = gaussian_filter1d(AVG_VAL_LOSS_HIST, sigma=2)
train_loss_line = gaussian_filter1d(TRAIN_LOSS_HIST, sigma=2)
val_loss_line = gaussian_filter1d(VAL_LOSS_HIST, sigma=2)
train_acc_line = gaussian_filter1d(TRAIN_ACC_HIST, sigma=2)
val_acc_line = gaussian_filter1d(VAL_ACC_HIST, sigma=2)
plt.plot(range(1, NUM_EPOCHS + 1), train_loss_line, label="Train Loss")
plt.plot(range(1, NUM_EPOCHS + 1), val_loss_line, label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.title("Train Loss and Validation Loss")
plt.savefig("loss_plot.png")
plt.clf()
plt.plot(range(1, NUM_EPOCHS + 1), avg_train_loss_line, label="Average Train Loss")
plt.plot(range(1, NUM_EPOCHS + 1), avg_val_loss_line, label="Average Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.title("Average Train Loss and Average Validation Loss")
plt.savefig("avg_loss_plot.png")
plt.clf()
plt.plot(range(1, NUM_EPOCHS + 1), train_acc_line, label="Train Accuracy")
plt.plot(range(1, NUM_EPOCHS + 1), val_acc_line, label="Validation Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.title("Train Accuracy and Validation Accuracy")
plt.savefig("accuracy_plot.png")
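Note: RANDOM_SEED is declared in main.py but never applied, so the random split and training are not reproducible between runs. A minimal sketch of how the seed could be wired in; the helper name and its placement near the top of main.py are assumptions, not part of the commit:

import random

import numpy as np
import torch
from torch.utils.data import random_split


def seed_everything(seed: int = 123) -> None:
    """Seed the Python, NumPy, and PyTorch RNGs so runs are repeatable."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)            # CPU RNG
    torch.cuda.manual_seed_all(seed)   # all GPU RNGs (no-op without CUDA)


seed_everything(123)  # RANDOM_SEED from main.py

# random_split can also be made deterministic explicitly:
# train_dataset, val_dataset = random_split(
#     dataset, [train_size, val_size],
#     generator=torch.Generator().manual_seed(123),
# )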
models.py
ADDED
@@ -0,0 +1,340 @@
#######################################################
# This file stores all the models used in the project.#
#######################################################

import torch
from torchvision.models import resnet50
from torchvision.models import resnet18


# Bottleneck block used by ResNet50
class Bottleneck(torch.nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Bottleneck, self).__init__()
        # 1x1 convolution to reduce channels (intermediate channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.batch_norm1 = torch.nn.BatchNorm2d(out_channels)
        # 3x3 convolution with the specified stride
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.batch_norm2 = torch.nn.BatchNorm2d(out_channels)
        # 1x1 convolution to expand the channels back
        self.conv3 = torch.nn.Conv2d(
            out_channels,
            out_channels * self.expansion,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.batch_norm3 = torch.nn.BatchNorm2d(out_channels * self.expansion)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = torch.nn.ReLU()

    # Forward the input x through the block
    def forward(self, x):
        identity = x.clone()
        x = self.relu(self.batch_norm1(self.conv1(x)))

        x = self.relu(self.batch_norm2(self.conv2(x)))

        x = self.conv3(x)
        x = self.batch_norm3(x)

        # Downsample the identity if needed
        if self.i_downsample is not None:
            identity = self.i_downsample(identity)
        # Add the identity (skip connection)
        x += identity
        x = self.relu(x)

        return x


# Basic residual block; not used yet, but it can be reused for ResNet18/34-style variants
class Block(torch.nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Block, self).__init__()

        self.conv1 = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            stride=stride,
            bias=False,
        )
        self.batch_norm1 = torch.nn.BatchNorm2d(out_channels)
        # The second 3x3 convolution keeps stride 1 so the block only downsamples once
        self.conv2 = torch.nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=False,
        )
        self.batch_norm2 = torch.nn.BatchNorm2d(out_channels)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        identity = x.clone()

        x = self.relu(self.batch_norm1(self.conv1(x)))
        x = self.batch_norm2(self.conv2(x))

        if self.i_downsample is not None:
            identity = self.i_downsample(identity)
        x += identity
        x = self.relu(x)
        return x


class ResNet(torch.nn.Module):
    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
        super(ResNet, self).__init__()
        self.in_channels = 64
        # Initial conv layer
        self.conv1 = torch.nn.Conv2d(
            num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.batch_norm1 = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU()
        self.max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages of residual blocks; layer_list gives the block count per stage
        self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64)
        self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)
        self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)
        self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)

        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(512 * ResBlock.expansion, num_classes)

    def forward(self, x):
        x = self.relu(self.batch_norm1(self.conv1(x)))
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)

        return x

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        # planes is the number of output channels of the stage (before expansion)
        ii_downsample = None
        layers = []

        if stride != 1 or self.in_channels != planes * ResBlock.expansion:
            ii_downsample = torch.nn.Sequential(
                torch.nn.Conv2d(
                    self.in_channels,
                    planes * ResBlock.expansion,
                    kernel_size=1,
                    stride=stride,
                ),
                torch.nn.BatchNorm2d(planes * ResBlock.expansion),
            )

        layers.append(
            ResBlock(
                self.in_channels, planes, i_downsample=ii_downsample, stride=stride
            )
        )
        self.in_channels = planes * ResBlock.expansion

        for i in range(blocks - 1):
            layers.append(ResBlock(self.in_channels, planes))

        return torch.nn.Sequential(*layers)


# The list gives the number of residual blocks in each of the four stages
def ResNet50(num_classes, channels=3):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, channels)


# VGG16 model
class VGG16(torch.nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.block_1 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=3,
                out_channels=64,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                in_channels=64,
                out_channels=64,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
        )

        self.block_2 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                in_channels=128,
                out_channels=128,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
        )

        self.block_3 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=128,
                out_channels=256,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                in_channels=256,
                out_channels=256,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                in_channels=256,
                out_channels=256,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
        )

        self.block_4 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=256,
                out_channels=512,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                in_channels=512,
                out_channels=512,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                in_channels=512,
                out_channels=512,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
        )

        self.block_5 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=512,
                out_channels=512,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                in_channels=512,
                out_channels=512,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                in_channels=512,
                out_channels=512,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
        )

        height, width = 3, 3
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(512 * height * width, 4096),
            torch.nn.ReLU(True),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(True),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(4096, num_classes),
        )

        # Kaiming initialization for conv and linear layers
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(
                    m.weight, mode="fan_in", nonlinearity="relu"
                )
                if m.bias is not None:
                    m.bias.detach().zero_()

        self.avgpool = torch.nn.AdaptiveAvgPool2d((height, width))

    def forward(self, x):
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.block_4(x)
        x = self.block_5(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten

        logits = self.classifier(x)
        # probas = F.softmax(logits, dim=1)

        return logits


# ResNet18 model
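As a quick consistency check against the 64x64 preprocessing used in main.py, the custom models defined above can be instantiated on a dummy batch and their output shapes inspected. This sketch is illustrative only and not part of the committed files:

import torch
from models import ResNet50, VGG16

# A dummy batch of 2 RGB images at the 64x64 training resolution
dummy = torch.randn(2, 3, 64, 64)

resnet = ResNet50(num_classes=5)
vgg = VGG16(num_classes=5)

with torch.no_grad():
    print(resnet(dummy).shape)  # expected: torch.Size([2, 5])
    print(vgg(dummy).shape)     # expected: torch.Size([2, 5])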
predict.py
ADDED
@@ -0,0 +1,101 @@
import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from models import *  # Make sure you import your model correctly from the 'models' module
from torchmetrics import ConfusionMatrix
import matplotlib.pyplot as plt
import pathlib

# Define the path to your model checkpoint
model_checkpoint_path = "model.pth"

# Define the path to the images you want to classify
image_path = "data/test/Task 1/"  # Use forward slashes for file paths

# Recursively list all the image files under image_path
images = list(pathlib.Path(image_path).rglob("*.png"))
# Sort the class folders so the indices match the ordering ImageFolder used during training
classes = sorted(os.listdir(image_path))
print(images)

true_classes = []
predicted_labels = []

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

NUM_CLASSES = 5  # Update with the correct number of classes

# Load your model (change this according to your model definition)
model = resnet18(pretrained=False, num_classes=NUM_CLASSES)
model.load_state_dict(
    torch.load(model_checkpoint_path, map_location=DEVICE)
)  # Load the checkpoint onto the same device
model.eval()
model = model.to(DEVICE)

# Define transformation for preprocessing the input image
preprocess = transforms.Compose(
    [
        transforms.Resize((64, 64)),  # Resize the image to match training input size
        transforms.Grayscale(num_output_channels=3),  # Convert the image to 3-channel grayscale
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize the image
    ]
)


def predict_image(image_path, model, transform):
    model.eval()
    correct_predictions = 0
    total_predictions = len(images)

    with torch.no_grad():
        for i in images:
            print("---------------------------")
            # The true label is the image's parent folder within Task 1
            true_class = classes.index(i.parts[-2])
            print("Image path:", i)
            print("True class:", true_class)
            image = Image.open(i)
            image = transform(image).unsqueeze(0)
            image = image.to(DEVICE)
            output = model(image)

            # Softmax over the logits gives per-class probabilities (in %)
            probabilities = torch.softmax(output, dim=1)[0] * 100
            predicted_class = torch.argmax(output, dim=1).item()

            # Append true and predicted labels to their respective lists
            true_classes.append(true_class)
            predicted_labels.append(predicted_class)

            # Check if the prediction is correct
            if predicted_class == true_class:
                correct_predictions += 1

            # Report the prediction
            print("Predicted class:", predicted_class)
            print("Probability:", probabilities[predicted_class].item())
            print("Predicted label:", classes[predicted_class])
            print("Correct predictions:", correct_predictions)
            print("Correct?", "Yes" if predicted_class == true_class else "No")
            print("---------------------------")

    # Calculate accuracy
    accuracy = correct_predictions / total_predictions
    print("Accuracy:", accuracy)


# Call the predict_image function
predict_image(image_path, model, preprocess)

# Convert the lists to tensors
predicted_labels_tensor = torch.tensor(predicted_labels)
true_classes_tensor = torch.tensor(true_classes)

# Create confusion matrix
conf_matrix = ConfusionMatrix(num_classes=NUM_CLASSES, task="multiclass")
conf_matrix.update(predicted_labels_tensor, true_classes_tensor)

# Plot confusion matrix
conf_matrix.plot()
plt.show()
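Since a Space runs headless, plt.show() at the end of predict.py has no display to draw on; saving the figure to disk is one alternative. A minimal sketch under the assumption that the installed torchmetrics version returns the Matplotlib figure and axes from plot(); the label tensors below are placeholders standing in for predicted_labels_tensor and true_classes_tensor:

import torch
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for environments without a display
from torchmetrics import ConfusionMatrix

# Placeholder predictions and targets for illustration only
preds = torch.tensor([0, 2, 1, 3, 4, 2])
target = torch.tensor([0, 2, 2, 3, 4, 1])

conf_matrix = ConfusionMatrix(task="multiclass", num_classes=5)
conf_matrix.update(preds, target)

fig, ax = conf_matrix.plot()  # assumes plot() returns (figure, axes)
fig.savefig("confusion_matrix.png", dpi=150, bbox_inches="tight")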