# SpiralSense / ensemble.py
# (Hugging Face file-viewer header, kept as comments so the module parses)
# cycool29's picture
# Update
# 97dcf92
# raw / history / blame
# 7.47 kB
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from data_loader import load_data, load_test_data
from configs import *
import numpy as np
# Release cached, currently unused CUDA allocator memory before any model is
# loaded further down in this script.
torch.cuda.empty_cache()
#
class MLP(nn.Module):
    """Meta learner that maps concatenated base-model outputs to class logits.

    The network expects ``num_classes * num_models`` input features per sample
    and produces ``num_classes`` outputs.
    """

    def __init__(self, num_classes, num_models):
        super().__init__()
        in_features = num_classes * num_models
        activation = dict(negative_slope=0.01, inplace=True)

        # The position of every module inside the Sequential is part of the
        # checkpoint format (state_dict keys are "layers.<index>.*").
        stack = [
            nn.Linear(in_features, 1024),
            nn.LayerNorm(1024),
            nn.LeakyReLU(**activation),
            nn.Dropout(0.8),
            nn.Linear(1024, 2048),
            nn.LeakyReLU(**activation),
            nn.Dropout(0.5),
            nn.Linear(2048, 2048),
            nn.LeakyReLU(**activation),
            nn.Dropout(0.5),
            nn.Linear(2048, num_classes),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten everything after the batch dimension and apply the stack."""
        flat = x.view(x.shape[0], -1)
        return self.layers(flat)
def mlp_meta(num_classes, num_models):
    """Build and return a fresh ``MLP`` meta learner.

    Args:
        num_classes: Number of output classes.
        num_models: Number of base models whose outputs are concatenated.
    """
    return MLP(num_classes, num_models)
# Hyperparameters
input_dim = 3 * 224 * 224  # Modify this based on your input size
hidden_dim = 256
output_dim = NUM_CLASSES

# Checkpoints of the pretrained base models, in the order they are stacked
model_paths = [
    "output/checkpoints/bestsqueezenetSE3.pth",
    "output/checkpoints/EfficientNetB3WithDropout.pth",
    "output/checkpoints/MobileNetV2WithDropout2.pth",
]

# Create the data loaders using the data_loader functions
train_val_dir = COMBINED_DATA_DIR + "1"
train_loader, val_loader = load_data(train_val_dir, preprocess, BATCH_SIZE)
test_loader = load_test_data("data/test/Task 1", preprocess, BATCH_SIZE)
# Define a function to load pretrained models
def load_pretrained_model(path, model):
    """Load a saved state dict into ``model`` and move the model to DEVICE.

    Args:
        path: Path of the checkpoint file holding the state dict.
        model: Freshly constructed module whose parameters are overwritten.

    Returns:
        The result of ``model.to(DEVICE)`` with the loaded weights.
    """
    # map_location lets a checkpoint saved on a GPU be restored when DEVICE is
    # the CPU (or a different GPU); without it torch.load tries to recreate
    # the tensors on the device they were saved from.
    state_dict = torch.load(path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    return model.to(DEVICE)
def rand_bbox(size, lam):
    """Sample a random box for CutMix.

    Args:
        size: Indexable shape; only ``size[2]`` and ``size[3]`` are read.
        lam: Mixing coefficient; the box covers roughly ``1 - lam`` of the
            area before clipping.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` where the ``x`` pair is clipped to
        ``[0, size[2]]`` and the ``y`` pair to ``[0, size[3]]``.
    """
    extent_x, extent_y = size[2], size[3]

    # Side lengths scale with sqrt(1 - lam) so the area scales with 1 - lam.
    side_ratio = np.sqrt(1.0 - lam)
    half_w = np.int_(extent_x * side_ratio) // 2
    half_h = np.int_(extent_y * side_ratio) // 2

    # Uniformly sampled box centre (x is drawn before y).
    center_x = np.random.randint(extent_x)
    center_y = np.random.randint(extent_y)

    x_lo = np.clip(center_x - half_w, 0, extent_x)
    y_lo = np.clip(center_y - half_h, 0, extent_y)
    x_hi = np.clip(center_x + half_w, 0, extent_x)
    y_hi = np.clip(center_y + half_h, 0, extent_y)
    return x_lo, y_lo, x_hi, y_hi
def cutmix_data(input, target, alpha=1.0):
    """Apply CutMix to a batch.

    A box sampled by ``rand_bbox`` is overwritten in every sample with the
    same region taken from another, randomly chosen sample of the batch.

    Note:
        ``input`` is modified in place and the same tensor is returned.

    Args:
        input: Batch tensor indexed as ``input[:, :, a:b, c:d]``.
        target: Labels aligned with the first dimension of ``input``.
        alpha: Beta distribution parameter; ``alpha <= 0`` disables mixing
            (``lam`` is 1, so the sampled box is empty).

    Returns:
        ``(input, targets_a, targets_b, lam)`` where ``targets_b`` are the
        labels of the samples the patches came from and ``lam`` is the
        fraction of each sample left untouched, as a Python float.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # A single permutation pairs every sample with its patch donor.
    rand_index = torch.randperm(input.size(0))
    bbx1, bby1, bbx2, bby2 = (int(v) for v in rand_bbox(input.size(), lam))

    # The indexed right-hand side is materialised as a copy before the
    # assignment, so reading and writing the same tensor is safe here.
    input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]

    # Recompute lam from the box that was actually used, since clipping at
    # the border can shrink it.
    patch_area = (bbx2 - bbx1) * (bby2 - bby1)
    lam = 1 - patch_area / (input.size(-1) * input.size(-2))

    targets_a = target
    targets_b = target[rand_index]
    return input, targets_a, targets_b, lam
def cutmix_criterion(criterion, outputs, targets_a, targets_b, lam):
    """Blend the loss against both CutMix label sets.

    Returns ``lam * criterion(outputs, targets_a)`` plus
    ``(1 - lam) * criterion(outputs, targets_b)``.
    """
    loss_a = criterion(outputs, targets_a)
    loss_b = criterion(outputs, targets_b)
    return lam * loss_a + (1 - lam) * loss_b
# Load pretrained models (load_pretrained_model already moves them to DEVICE)
model1 = load_pretrained_model(
    model_paths[0], SqueezeNet1_0WithSE(num_classes=NUM_CLASSES)
)
model2 = load_pretrained_model(
    model_paths[1], EfficientNetB3WithDropout(num_classes=NUM_CLASSES)
)
model3 = load_pretrained_model(
    model_paths[2], MobileNetV2WithDropout(num_classes=NUM_CLASSES)
)
models = [model1, model2, model3]

# The base models are fixed feature extractors: only the meta learner's
# parameters are handed to the optimizer below. Put them in eval mode so any
# dropout / batch-norm layers behave deterministically and their running
# statistics are not changed by training, validation or test batches, and
# freeze their parameters so no gradients are computed for them.
for base_model in models:
    base_model.eval()
    base_model.requires_grad_(False)

# Create the meta learner
meta_learner_model = mlp_meta(NUM_CLASSES, len(models)).to(DEVICE)
meta_optimizer = torch.optim.Adam(meta_learner_model.parameters(), lr=0.001)
# NOTE: meta_loss_fn is kept for compatibility; the loops below use `criterion`.
meta_loss_fn = torch.nn.CrossEntropyLoss()

# Define the Cosine Annealing Learning Rate Scheduler
scheduler = CosineAnnealingLR(
    meta_optimizer, T_max=700
)  # T_max is counted in scheduler.step() calls (half a cosine period).

# Define the loss function for the meta learner
criterion = nn.CrossEntropyLoss().to(DEVICE)

# Record learning rate
lr_hist = []
# Training loop
# NOTE(review): the original indentation of this script was lost. The learning
# rate is assumed to be recorded and stepped once per batch (lr_hist is later
# plotted against "Iterations") — confirm against the intended schedule.
num_epochs = 160
for epoch in range(num_epochs):
    print("[Epoch: {}]".format(epoch + 1))
    num_batches = len(train_loader)
    print("Total number of batches: {}".format(num_batches))

    # Validation at the end of each epoch switches the meta learner to eval
    # mode; switch back here, otherwise dropout is only active in epoch 1.
    meta_learner_model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, data in enumerate(train_loader, 0):
        print("Batch: {}".format(batch_idx + 1))
        inputs, labels = data
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        inputs, targets_a, targets_b, lam = cutmix_data(inputs, labels, alpha=1)

        # Forward pass through the three pretrained models. They are not
        # optimized, so no autograd graph is needed for them.
        with torch.no_grad():
            features1 = model1(inputs)
            features2 = model2(inputs)
            features3 = model3(inputs)
            # Stack the features from the three models
            stacked_features = torch.cat((features1, features2, features3), dim=1)

        # Forward pass through the meta learner
        meta_output = meta_learner_model(stacked_features)

        # Compute the loss
        loss = cutmix_criterion(criterion, meta_output, targets_a, targets_b, lam)

        # Backpropagation and optimization
        meta_optimizer.zero_grad()
        loss.backward()
        meta_optimizer.step()

        lr_hist.append(meta_optimizer.param_groups[0]["lr"])
        scheduler.step()

        # Accumulate epoch statistics. Accuracy is measured against the
        # original (un-mixed) labels, as before.
        _, predicted = torch.max(meta_output, 1)
        running_loss += loss.item()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Report averages over the whole epoch instead of the last batch only.
    print("Train Loss: {}".format(running_loss / max(num_batches, 1)))
    print("Train Accuracy: {}%".format(100 * correct / max(total, 1)))

    # Validation
    meta_learner_model.eval()
    correct = 0
    total = 0
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            inputs, labels = data
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            features1 = model1(inputs)
            features2 = model2(inputs)
            features3 = model3(inputs)
            stacked_features = torch.cat((features1, features2, features3), dim=1)
            outputs = meta_learner_model(stacked_features)
            loss = criterion(outputs, labels)  # Use the validation loss
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(
        "Validation Loss: {}".format(val_loss / len(val_loader))
    )  # Calculate the average loss
    print("Validation Accuracy: {}%".format(100 * correct / total))

print("Finished Training")
# Evaluate the stacked ensemble on the held-out test loader
print("Testing the ensemble")
meta_learner_model.eval()
correct, total = 0, 0
with torch.no_grad():
    for test_inputs, test_labels in test_loader:
        test_inputs = test_inputs.to(DEVICE)
        test_labels = test_labels.to(DEVICE)

        # Concatenate the base-model outputs in the order model1, model2, model3
        base_outputs = (model1(test_inputs), model2(test_inputs), model3(test_inputs))
        outputs = meta_learner_model(torch.cat(base_outputs, dim=1))

        predicted = torch.max(outputs, 1)[1]
        total += test_labels.size(0)
        correct += (predicted == test_labels).sum().item()

test_accuracy = 100 * correct / total
print(
    "Accuracy of the ensemble network on the test images: {}%".format(test_accuracy)
)
# Save the model first: plt.show() blocks until the window is closed and can
# fail on a machine without a display, which must not cost the trained weights.
torch.save(meta_learner_model.state_dict(), "output/checkpoints/ensemble.pth")

# Plot the learning rate history
plt.plot(lr_hist)
plt.xlabel("Iterations")
plt.ylabel("Learning Rate")
plt.title("Learning Rate History")
plt.show()